diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000..2b65f6fe3cc80
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,2 @@
+*.bat text eol=crlf
+*.cmd text eol=crlf
diff --git a/.gitignore b/.gitignore
index 34939e3a97aaa..9757054a50f9e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,18 +5,22 @@
*.ipr
*.iml
*.iws
+*.pyc
.idea/
.idea_modules/
-sbt/*.jar
+build/*.jar
.settings
.cache
+cache
.generated-mima*
-/build/
work/
out/
.DS_Store
third_party/libmesos.so
third_party/libmesos.dylib
+build/apache-maven*
+build/zinc*
+build/scala*
conf/java-opts
conf/*.sh
conf/*.cmd
@@ -49,9 +53,12 @@ dependency-reduced-pom.xml
checkpoint
derby.log
dist/
-spark-*-bin.tar.gz
+dev/create-release/*txt
+dev/create-release/*final
+spark-*-bin-*.tgz
unit-tests.log
/lib/
+ec2/lib/
rat-results.txt
scalastyle.txt
scalastyle-output.xml
diff --git a/.rat-excludes b/.rat-excludes
index b14ad53720f32..769defbac11b7 100644
--- a/.rat-excludes
+++ b/.rat-excludes
@@ -1,5 +1,6 @@
target
.gitignore
+.gitattributes
.project
.classpath
.mima-excludes
@@ -43,11 +44,13 @@ SparkImports.scala
SparkJLineCompletion.scala
SparkJLineReader.scala
SparkMemberHandlers.scala
+SparkReplReporter.scala
sbt
sbt-launch-lib.bash
plugins.sbt
work
.*\.q
+.*\.qv
golden
test.out/*
.*iml
@@ -61,3 +64,4 @@ dist/*
logs
.*scalastyle-output.xml
.*dependency-reduced-pom.xml
+known_translations
diff --git a/LICENSE b/LICENSE
index a7eee041129cb..0a42d389e4c3c 100644
--- a/LICENSE
+++ b/LICENSE
@@ -646,7 +646,8 @@ THE SOFTWARE.
========================================================================
For Scala Interpreter classes (all .scala files in repl/src/main/scala
-except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala):
+except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala),
+and for SerializableMapWrapper in JavaUtils.scala:
========================================================================
Copyright (c) 2002-2013 EPFL
@@ -712,18 +713,6 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-========================================================================
-For colt:
-========================================================================
-
-Copyright (c) 1999 CERN - European Organization for Nuclear Research.
-Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose is hereby granted without fee, provided that the above copyright notice appear in all copies and that both that copyright notice and this permission notice appear in supporting documentation. CERN makes no representations about the suitability of this software for any purpose. It is provided "as is" without expressed or implied warranty.
-
-Packages hep.aida.*
-
-Written by Pavel Binko, Dino Ferrero Merlino, Wolfgang Hoschek, Tony Johnson, Andreas Pfeiffer, and others. Check the FreeHEP home page for more info. Permission to use and/or redistribute this work is granted under the terms of the LGPL License, with the exception that any usage related to military applications is expressly forbidden. The software and documentation made available under the terms of this license are provided with no warranty.
-
-
========================================================================
For SnapTree:
========================================================================
@@ -766,7 +755,7 @@ SUCH DAMAGE.
========================================================================
-For Timsort (core/src/main/java/org/apache/spark/util/collection/Sorter.java):
+For Timsort (core/src/main/java/org/apache/spark/util/collection/TimSort.java):
========================================================================
Copyright (C) 2008 The Android Open Source Project
@@ -783,6 +772,25 @@ See the License for the specific language governing permissions and
limitations under the License.
+========================================================================
+For LimitedInputStream
+ (network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java):
+========================================================================
+Copyright (C) 2007 The Guava Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+
========================================================================
BSD-style licenses
========================================================================
diff --git a/README.md b/README.md
index dbf53dcd76b2d..af02339578195 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,8 @@ and Spark Streaming for stream processing.
## Online Documentation
You can find the latest Spark documentation, including a programming
-guide, on the [project web page](http://spark.apache.org/documentation.html).
+guide, on the [project web page](http://spark.apache.org/documentation.html)
+and [project wiki](https://cwiki.apache.org/confluence/display/SPARK).
This README file only contains basic setup instructions.
## Building Spark
@@ -25,7 +26,7 @@ To build Spark and its example programs, run:
(You do not need to do this if you downloaded a pre-built package.)
More detailed documentation is available from the project site, at
-["Building Spark with Maven"](http://spark.apache.org/docs/latest/building-with-maven.html).
+["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html).
## Interactive Scala Shell
@@ -84,7 +85,7 @@ storage systems. Because the protocols have changed in different versions of
Hadoop, you must build Spark against the same version that your cluster runs.
Please refer to the build documentation at
-["Specifying the Hadoop Version"](http://spark.apache.org/docs/latest/building-spark.html#specifying-the-hadoop-version)
+["Specifying the Hadoop Version"](http://spark.apache.org/docs/latest/building-with-maven.html#specifying-the-hadoop-version)
for detailed guidance on building for a particular distribution of Hadoop, including
building for particular Hive and Hive Thriftserver distributions. See also
["Third Party Hadoop Distributions"](http://spark.apache.org/docs/latest/hadoop-third-party-distributions.html)
diff --git a/assembly/pom.xml b/assembly/pom.xml
index 31a01e4d8e1de..1bb5a671f5390 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -21,7 +21,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent</artifactId>
-    <version>1.2.0-SNAPSHOT</version>
+    <version>1.3.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
@@ -43,12 +43,6 @@
-
-
- com.google.guava
- guava
- compile
- org.apache.sparkspark-core_${scala.binary.version}
@@ -66,22 +60,22 @@
org.apache.spark
- spark-repl_${scala.binary.version}
+ spark-streaming_${scala.binary.version}${project.version}org.apache.spark
- spark-streaming_${scala.binary.version}
+ spark-graphx_${scala.binary.version}${project.version}org.apache.spark
- spark-graphx_${scala.binary.version}
+ spark-sql_${scala.binary.version}${project.version}org.apache.spark
- spark-sql_${scala.binary.version}
+ spark-repl_${scala.binary.version}${project.version}
@@ -133,20 +127,6 @@
shade
-
-
- com.google
- org.spark-project.guava
-
- com.google.common.**
-
-
- com/google/common/base/Absent*
- com/google/common/base/Optional*
- com/google/common/base/Present*
-
-
-
@@ -169,16 +149,6 @@
-
- yarn-alpha
-
-
- org.apache.spark
- spark-yarn-alpha_${scala.binary.version}
- ${project.version}
-
-
- yarn
@@ -197,6 +167,11 @@
spark-hive_${scala.binary.version}${project.version}
+
+
+
+ hive-thriftserver
+ org.apache.sparkspark-hive-thriftserver_${scala.binary.version}
@@ -359,5 +334,25 @@
+
+
+
+ hadoop-provided
+
+ provided
+
+
+
+ hive-provided
+
+ provided
+
+
+
+ parquet-provided
+
+ provided
+
+
diff --git a/bagel/pom.xml b/bagel/pom.xml
index 93db0d5efda5f..510e92640eff8 100644
--- a/bagel/pom.xml
+++ b/bagel/pom.xml
@@ -21,7 +21,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent</artifactId>
-    <version>1.2.0-SNAPSHOT</version>
+    <version>1.3.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
@@ -40,15 +40,6 @@
spark-core_${scala.binary.version}${project.version}
-
- org.eclipse.jetty
- jetty-server
-
-
- org.scalatest
- scalatest_${scala.binary.version}
- test
- org.scalacheckscalacheck_${scala.binary.version}
@@ -58,11 +49,5 @@
target/scala-${scala.binary.version}/classestarget/scala-${scala.binary.version}/test-classes
-
-
- org.scalatest
- scalatest-maven-plugin
-
-
diff --git a/bagel/src/test/resources/log4j.properties b/bagel/src/test/resources/log4j.properties
index 789869f72e3b0..853ef0ed2986f 100644
--- a/bagel/src/test/resources/log4j.properties
+++ b/bagel/src/test/resources/log4j.properties
@@ -15,10 +15,10 @@
# limitations under the License.
#
-# Set everything to be logged to the file bagel/target/unit-tests.log
+# Set everything to be logged to the file target/unit-tests.log
log4j.rootCategory=INFO, file
log4j.appender.file=org.apache.log4j.FileAppender
-log4j.appender.file.append=false
+log4j.appender.file.append=true
log4j.appender.file.file=target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
diff --git a/bin/beeline.cmd b/bin/beeline.cmd
new file mode 100644
index 0000000000000..8293f311029dd
--- /dev/null
+++ b/bin/beeline.cmd
@@ -0,0 +1,21 @@
+@echo off
+
+rem
+rem Licensed to the Apache Software Foundation (ASF) under one or more
+rem contributor license agreements. See the NOTICE file distributed with
+rem this work for additional information regarding copyright ownership.
+rem The ASF licenses this file to You under the Apache License, Version 2.0
+rem (the "License"); you may not use this file except in compliance with
+rem the License. You may obtain a copy of the License at
+rem
+rem http://www.apache.org/licenses/LICENSE-2.0
+rem
+rem Unless required by applicable law or agreed to in writing, software
+rem distributed under the License is distributed on an "AS IS" BASIS,
+rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+rem See the License for the specific language governing permissions and
+rem limitations under the License.
+rem
+
+set SPARK_HOME=%~dp0..
+cmd /V /E /C %SPARK_HOME%\bin\spark-class.cmd org.apache.hive.beeline.BeeLine %*
diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd
index 3cd0579aea8d3..088f993954d9e 100644
--- a/bin/compute-classpath.cmd
+++ b/bin/compute-classpath.cmd
@@ -1,117 +1,124 @@
-@echo off
-
-rem
-rem Licensed to the Apache Software Foundation (ASF) under one or more
-rem contributor license agreements. See the NOTICE file distributed with
-rem this work for additional information regarding copyright ownership.
-rem The ASF licenses this file to You under the Apache License, Version 2.0
-rem (the "License"); you may not use this file except in compliance with
-rem the License. You may obtain a copy of the License at
-rem
-rem http://www.apache.org/licenses/LICENSE-2.0
-rem
-rem Unless required by applicable law or agreed to in writing, software
-rem distributed under the License is distributed on an "AS IS" BASIS,
-rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-rem See the License for the specific language governing permissions and
-rem limitations under the License.
-rem
-
-rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run"
-rem script and the ExecutorRunner in standalone cluster mode.
-
-rem If we're called from spark-class2.cmd, it already set enabledelayedexpansion and setting
-rem it here would stop us from affecting its copy of the CLASSPATH variable; otherwise we
-rem need to set it here because we use !datanucleus_jars! below.
-if "%DONT_PRINT_CLASSPATH%"=="1" goto skip_delayed_expansion
-setlocal enabledelayedexpansion
-:skip_delayed_expansion
-
-set SCALA_VERSION=2.10
-
-rem Figure out where the Spark framework is installed
-set FWDIR=%~dp0..\
-
-rem Load environment variables from conf\spark-env.cmd, if it exists
-if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd"
-
-rem Build up classpath
-set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH%
-
-if not "x%SPARK_CONF_DIR%"=="x" (
- set CLASSPATH=%CLASSPATH%;%SPARK_CONF_DIR%
-) else (
- set CLASSPATH=%CLASSPATH%;%FWDIR%conf
-)
-
-if exist "%FWDIR%RELEASE" (
- for %%d in ("%FWDIR%lib\spark-assembly*.jar") do (
- set ASSEMBLY_JAR=%%d
- )
-) else (
- for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*.jar") do (
- set ASSEMBLY_JAR=%%d
- )
-)
-
-set CLASSPATH=%CLASSPATH%;%ASSEMBLY_JAR%
-
-rem When Hive support is needed, Datanucleus jars must be included on the classpath.
-rem Datanucleus jars do not work if only included in the uber jar as plugin.xml metadata is lost.
-rem Both sbt and maven will populate "lib_managed/jars/" with the datanucleus jars when Spark is
-rem built with Hive, so look for them there.
-if exist "%FWDIR%RELEASE" (
- set datanucleus_dir=%FWDIR%lib
-) else (
- set datanucleus_dir=%FWDIR%lib_managed\jars
-)
-set "datanucleus_jars="
-for %%d in ("%datanucleus_dir%\datanucleus-*.jar") do (
- set datanucleus_jars=!datanucleus_jars!;%%d
-)
-set CLASSPATH=%CLASSPATH%;%datanucleus_jars%
-
-set SPARK_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%tools\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\classes
-
-set SPARK_TEST_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\test-classes
-
-if "x%SPARK_TESTING%"=="x1" (
- rem Add test clases to path - note, add SPARK_CLASSES and SPARK_TEST_CLASSES before CLASSPATH
- rem so that local compilation takes precedence over assembled jar
- set CLASSPATH=%SPARK_CLASSES%;%SPARK_TEST_CLASSES%;%CLASSPATH%
-)
-
-rem Add hadoop conf dir - else FileSystem.*, etc fail
-rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts
-rem the configurtion files.
-if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir
- set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR%
-:no_hadoop_conf_dir
-
-if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir
- set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR%
-:no_yarn_conf_dir
-
-rem A bit of a hack to allow calling this script within run2.cmd without seeing output
-if "%DONT_PRINT_CLASSPATH%"=="1" goto exit
-
-echo %CLASSPATH%
-
-:exit
+@echo off
+
+rem
+rem Licensed to the Apache Software Foundation (ASF) under one or more
+rem contributor license agreements. See the NOTICE file distributed with
+rem this work for additional information regarding copyright ownership.
+rem The ASF licenses this file to You under the Apache License, Version 2.0
+rem (the "License"); you may not use this file except in compliance with
+rem the License. You may obtain a copy of the License at
+rem
+rem http://www.apache.org/licenses/LICENSE-2.0
+rem
+rem Unless required by applicable law or agreed to in writing, software
+rem distributed under the License is distributed on an "AS IS" BASIS,
+rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+rem See the License for the specific language governing permissions and
+rem limitations under the License.
+rem
+
+rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run"
+rem script and the ExecutorRunner in standalone cluster mode.
+
+rem If we're called from spark-class2.cmd, it already set enabledelayedexpansion and setting
+rem it here would stop us from affecting its copy of the CLASSPATH variable; otherwise we
+rem need to set it here because we use !datanucleus_jars! below.
+if "%DONT_PRINT_CLASSPATH%"=="1" goto skip_delayed_expansion
+setlocal enabledelayedexpansion
+:skip_delayed_expansion
+
+set SCALA_VERSION=2.10
+
+rem Figure out where the Spark framework is installed
+set FWDIR=%~dp0..\
+
+rem Load environment variables from conf\spark-env.cmd, if it exists
+if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd"
+
+rem Build up classpath
+set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH%
+
+if not "x%SPARK_CONF_DIR%"=="x" (
+ set CLASSPATH=%CLASSPATH%;%SPARK_CONF_DIR%
+) else (
+ set CLASSPATH=%CLASSPATH%;%FWDIR%conf
+)
+
+if exist "%FWDIR%RELEASE" (
+ for %%d in ("%FWDIR%lib\spark-assembly*.jar") do (
+ set ASSEMBLY_JAR=%%d
+ )
+) else (
+ for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*.jar") do (
+ set ASSEMBLY_JAR=%%d
+ )
+)
+
+set CLASSPATH=%CLASSPATH%;%ASSEMBLY_JAR%
+
+rem When Hive support is needed, Datanucleus jars must be included on the classpath.
+rem Datanucleus jars do not work if only included in the uber jar as plugin.xml metadata is lost.
+rem Both sbt and maven will populate "lib_managed/jars/" with the datanucleus jars when Spark is
+rem built with Hive, so look for them there.
+if exist "%FWDIR%RELEASE" (
+ set datanucleus_dir=%FWDIR%lib
+) else (
+ set datanucleus_dir=%FWDIR%lib_managed\jars
+)
+set "datanucleus_jars="
+for %%d in ("%datanucleus_dir%\datanucleus-*.jar") do (
+ set datanucleus_jars=!datanucleus_jars!;%%d
+)
+set CLASSPATH=%CLASSPATH%;%datanucleus_jars%
+
+set SPARK_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%tools\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\classes
+
+set SPARK_TEST_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\test-classes
+
+if "x%SPARK_TESTING%"=="x1" (
+  rem Add test classes to path - note, add SPARK_CLASSES and SPARK_TEST_CLASSES before CLASSPATH
+ rem so that local compilation takes precedence over assembled jar
+ set CLASSPATH=%SPARK_CLASSES%;%SPARK_TEST_CLASSES%;%CLASSPATH%
+)
+
+rem Add hadoop conf dir - else FileSystem.*, etc fail
+rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts
+rem the configuration files.
+if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir
+ set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR%
+:no_hadoop_conf_dir
+
+if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir
+ set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR%
+:no_yarn_conf_dir
+
+rem To allow for distributions to append needed libraries to the classpath (e.g. when
+rem using the "hadoop-provided" profile to build Spark), check SPARK_DIST_CLASSPATH and
+rem append it to the final classpath.
+if not "x%SPARK_DIST_CLASSPATH%"=="x" (
+ set CLASSPATH=%CLASSPATH%;%SPARK_DIST_CLASSPATH%
+)
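+
+rem For example, a distribution that relies on an external Hadoop installation might set
+rem SPARK_DIST_CLASSPATH in conf\spark-env.cmd to the output of "hadoop classpath".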
+
+rem A bit of a hack to allow calling this script within run2.cmd without seeing output
+if "%DONT_PRINT_CLASSPATH%"=="1" goto exit
+
+echo %CLASSPATH%
+
+:exit
diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh
index 905bbaf99b374..a8c344b1ca594 100755
--- a/bin/compute-classpath.sh
+++ b/bin/compute-classpath.sh
@@ -20,14 +20,16 @@
# This script computes Spark's classpath and prints it to stdout; it's used by both the "run"
# script and the ExecutorRunner in standalone cluster mode.
-SCALA_VERSION=2.10
-
# Figure out where Spark is installed
FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
. "$FWDIR"/bin/load-spark-env.sh
-CLASSPATH="$SPARK_CLASSPATH:$SPARK_SUBMIT_CLASSPATH"
+if [ -n "$SPARK_CLASSPATH" ]; then
+ CLASSPATH="$SPARK_CLASSPATH:$SPARK_SUBMIT_CLASSPATH"
+else
+ CLASSPATH="$SPARK_SUBMIT_CLASSPATH"
+fi
# Build up classpath
if [ -n "$SPARK_CONF_DIR" ]; then
@@ -36,7 +38,7 @@ else
CLASSPATH="$CLASSPATH:$FWDIR/conf"
fi
-ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SCALA_VERSION"
+ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SPARK_SCALA_VERSION"
if [ -n "$JAVA_HOME" ]; then
JAR_CMD="$JAVA_HOME/bin/jar"
@@ -48,19 +50,21 @@ fi
if [ -n "$SPARK_PREPEND_CLASSES" ]; then
echo "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark"\
"classes ahead of assembly." >&2
- CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/classes"
+ # Spark classes
+ CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SPARK_SCALA_VERSION/classes"
+ # Jars for shaded deps in their original form (copied here during build)
CLASSPATH="$CLASSPATH:$FWDIR/core/target/jars/*"
- CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/tools/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SCALA_VERSION/classes"
fi
# Use spark-assembly jar from either RELEASE or assembly directory
@@ -70,22 +74,25 @@ else
assembly_folder="$ASSEMBLY_DIR"
fi
-num_jars="$(ls "$assembly_folder" | grep "spark-assembly.*hadoop.*\.jar" | wc -l)"
-if [ "$num_jars" -eq "0" ]; then
- echo "Failed to find Spark assembly in $assembly_folder"
- echo "You need to build Spark before running this program."
- exit 1
-fi
+num_jars=0
+
+for f in ${assembly_folder}/spark-assembly*hadoop*.jar; do
+ if [[ ! -e "$f" ]]; then
+ echo "Failed to find Spark assembly in $assembly_folder" 1>&2
+ echo "You need to build Spark before running this program." 1>&2
+ exit 1
+ fi
+ ASSEMBLY_JAR="$f"
+ num_jars=$((num_jars+1))
+done
+
if [ "$num_jars" -gt "1" ]; then
- jars_list=$(ls "$assembly_folder" | grep "spark-assembly.*hadoop.*.jar")
- echo "Found multiple Spark assembly jars in $assembly_folder:"
- echo "$jars_list"
- echo "Please remove all but one jar."
+ echo "Found multiple Spark assembly jars in $assembly_folder:" 1>&2
+ ls ${assembly_folder}/spark-assembly*hadoop*.jar 1>&2
+ echo "Please remove all but one jar." 1>&2
exit 1
fi
-ASSEMBLY_JAR="$(ls "$assembly_folder"/spark-assembly*hadoop*.jar 2>/dev/null)"
-
# Verify that versions of java used to build the jars and run Spark are compatible
jar_error_check=$("$JAR_CMD" -tf "$ASSEMBLY_JAR" nonexistent/class/path 2>&1)
if [[ "$jar_error_check" =~ "invalid CEN header" ]]; then
@@ -110,7 +117,7 @@ else
datanucleus_dir="$FWDIR"/lib_managed/jars
fi
-datanucleus_jars="$(find "$datanucleus_dir" 2>/dev/null | grep "datanucleus-.*\\.jar")"
+datanucleus_jars="$(find "$datanucleus_dir" 2>/dev/null | grep "datanucleus-.*\\.jar$")"
datanucleus_jars="$(echo "$datanucleus_jars" | tr "\n" : | sed s/:$//g)"
if [ -n "$datanucleus_jars" ]; then
@@ -123,15 +130,15 @@ fi
# Add test classes if we're running from SBT or Maven with SPARK_TESTING set to 1
if [[ $SPARK_TESTING == 1 ]]; then
- CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/test-classes"
fi
# Add hadoop conf dir if given -- otherwise FileSystem.*, etc fail !
@@ -144,4 +151,11 @@ if [ -n "$YARN_CONF_DIR" ]; then
CLASSPATH="$CLASSPATH:$YARN_CONF_DIR"
fi
+# To allow for distributions to append needed libraries to the classpath (e.g. when
+# using the "hadoop-provided" profile to build Spark), check SPARK_DIST_CLASSPATH and
+# append it to the final classpath.
+if [ -n "$SPARK_DIST_CLASSPATH" ]; then
+ CLASSPATH="$CLASSPATH:$SPARK_DIST_CLASSPATH"
+fi
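+# For example, a "hadoop-provided" build might set this in conf/spark-env.sh:
+#   export SPARK_DIST_CLASSPATH="$(hadoop classpath)"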
+
echo "$CLASSPATH"
diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh
index 6d4231b204595..356b3d49b2ffe 100644
--- a/bin/load-spark-env.sh
+++ b/bin/load-spark-env.sh
@@ -36,3 +36,23 @@ if [ -z "$SPARK_ENV_LOADED" ]; then
set +a
fi
fi
+
+# Setting SPARK_SCALA_VERSION if not already set.
+
+if [ -z "$SPARK_SCALA_VERSION" ]; then
+
+ ASSEMBLY_DIR2="$FWDIR/assembly/target/scala-2.11"
+ ASSEMBLY_DIR1="$FWDIR/assembly/target/scala-2.10"
+
+ if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then
+ echo -e "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." 1>&2
+ echo -e 'Either clean one of them or, export SPARK_SCALA_VERSION=2.11 in spark-env.sh.' 1>&2
+ exit 1
+ fi
+
+ if [ -d "$ASSEMBLY_DIR2" ]; then
+ export SPARK_SCALA_VERSION="2.11"
+ else
+ export SPARK_SCALA_VERSION="2.10"
+ fi
+fi
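+
+# To bypass the detection above, e.g. when only the Scala 2.11 build is wanted, one can set
+#   export SPARK_SCALA_VERSION=2.11
+# in conf/spark-env.sh before calling any of the bin/ scripts.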
diff --git a/bin/pyspark b/bin/pyspark
index 96f30a260a09e..0b4f695dd06dd 100755
--- a/bin/pyspark
+++ b/bin/pyspark
@@ -25,7 +25,7 @@ export SPARK_HOME="$FWDIR"
source "$FWDIR/bin/utils.sh"
-SCALA_VERSION=2.10
+source "$FWDIR"/bin/load-spark-env.sh
function usage() {
echo "Usage: ./bin/pyspark [options]" 1>&2
@@ -40,7 +40,7 @@ fi
# Exit if the user hasn't compiled Spark
if [ ! -f "$FWDIR/RELEASE" ]; then
# Exit if the user hasn't compiled Spark
- ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar >& /dev/null
+ ls "$FWDIR"/assembly/target/scala-$SPARK_SCALA_VERSION/spark-assembly*hadoop*.jar >& /dev/null
if [[ $? != 0 ]]; then
echo "Failed to find Spark assembly in $FWDIR/assembly/target" 1>&2
echo "You need to build Spark before running this program" 1>&2
@@ -48,8 +48,6 @@ if [ ! -f "$FWDIR/RELEASE" ]; then
fi
fi
-. "$FWDIR"/bin/load-spark-env.sh
-
# In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython`
# executable, while the worker would still be launched using PYSPARK_PYTHON.
#
@@ -134,7 +132,5 @@ if [[ "$1" =~ \.py$ ]]; then
gatherSparkSubmitOpts "$@"
exec "$FWDIR"/bin/spark-submit "${SUBMISSION_OPTS[@]}" "$primary" "${APPLICATION_OPTS[@]}"
else
- # PySpark shell requires special handling downstream
- export PYSPARK_SHELL=1
exec "$PYSPARK_DRIVER_PYTHON" $PYSPARK_DRIVER_PYTHON_OPTS
fi
diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd
index a0e66abcc26c9..a542ec80b49d6 100644
--- a/bin/pyspark2.cmd
+++ b/bin/pyspark2.cmd
@@ -59,7 +59,11 @@ for /f %%i in ('echo %1^| findstr /R "\.py"') do (
)
if [%PYTHON_FILE%] == [] (
- %PYSPARK_PYTHON%
+ if [%IPYTHON%] == [1] (
+ ipython %IPYTHON_OPTS%
+ ) else (
+ %PYSPARK_PYTHON%
+ )
) else (
echo.
echo WARNING: Running python applications through ./bin/pyspark.cmd is deprecated as of Spark 1.0.
diff --git a/bin/run-example b/bin/run-example
index 34dd71c71880e..c567acf9a6b5c 100755
--- a/bin/run-example
+++ b/bin/run-example
@@ -17,12 +17,12 @@
# limitations under the License.
#
-SCALA_VERSION=2.10
-
FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
export SPARK_HOME="$FWDIR"
EXAMPLES_DIR="$FWDIR"/examples
+. "$FWDIR"/bin/load-spark-env.sh
+
if [ -n "$1" ]; then
EXAMPLE_CLASS="$1"
shift
@@ -35,17 +35,32 @@ else
fi
if [ -f "$FWDIR/RELEASE" ]; then
- export SPARK_EXAMPLES_JAR="`ls "$FWDIR"/lib/spark-examples-*hadoop*.jar`"
-elif [ -e "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/spark-examples-*hadoop*.jar ]; then
- export SPARK_EXAMPLES_JAR="`ls "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/spark-examples-*hadoop*.jar`"
+ JAR_PATH="${FWDIR}/lib"
+else
+ JAR_PATH="${EXAMPLES_DIR}/target/scala-${SPARK_SCALA_VERSION}"
fi
-if [[ -z "$SPARK_EXAMPLES_JAR" ]]; then
- echo "Failed to find Spark examples assembly in $FWDIR/lib or $FWDIR/examples/target" 1>&2
- echo "You need to build Spark before running this program" 1>&2
+JAR_COUNT=0
+
+for f in ${JAR_PATH}/spark-examples-*hadoop*.jar; do
+ if [[ ! -e "$f" ]]; then
+ echo "Failed to find Spark examples assembly in $FWDIR/lib or $FWDIR/examples/target" 1>&2
+ echo "You need to build Spark before running this program" 1>&2
+ exit 1
+ fi
+ SPARK_EXAMPLES_JAR="$f"
+ JAR_COUNT=$((JAR_COUNT+1))
+done
+
+if [ "$JAR_COUNT" -gt "1" ]; then
+ echo "Found multiple Spark examples assembly jars in ${JAR_PATH}" 1>&2
+ ls ${JAR_PATH}/spark-examples-*hadoop*.jar 1>&2
+ echo "Please remove all but one jar." 1>&2
exit 1
fi
+export SPARK_EXAMPLES_JAR
+
EXAMPLE_MASTER=${MASTER:-"local[*]"}
if [[ ! $EXAMPLE_CLASS == org.apache.spark.examples* ]]; then
diff --git a/bin/spark-class b/bin/spark-class
index 91d858bc063d0..2f0441bb3c1c2 100755
--- a/bin/spark-class
+++ b/bin/spark-class
@@ -24,13 +24,12 @@ case "`uname`" in
CYGWIN*) cygwin=true;;
esac
-SCALA_VERSION=2.10
-
# Figure out where Spark is installed
FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
# Export this as SPARK_HOME
export SPARK_HOME="$FWDIR"
+export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"$SPARK_HOME/conf"}"
. "$FWDIR"/bin/load-spark-env.sh
@@ -73,6 +72,8 @@ case "$1" in
'org.apache.spark.executor.MesosExecutorBackend')
OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
OUR_JAVA_MEM=${SPARK_EXECUTOR_MEMORY:-$DEFAULT_MEM}
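+    # Put the PySpark sources and the bundled Py4J on PYTHONPATH for the Mesos executor backend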
+ export PYTHONPATH="$FWDIR/python:$PYTHONPATH"
+ export PYTHONPATH="$FWDIR/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH"
;;
# Spark submit uses SPARK_JAVA_OPTS + SPARK_SUBMIT_OPTS +
@@ -81,7 +82,11 @@ case "$1" in
OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_SUBMIT_OPTS"
OUR_JAVA_MEM=${SPARK_DRIVER_MEMORY:-$DEFAULT_MEM}
if [ -n "$SPARK_SUBMIT_LIBRARY_PATH" ]; then
- OUR_JAVA_OPTS="$OUR_JAVA_OPTS -Djava.library.path=$SPARK_SUBMIT_LIBRARY_PATH"
+ if [[ $OSTYPE == darwin* ]]; then
+ export DYLD_LIBRARY_PATH="$SPARK_SUBMIT_LIBRARY_PATH:$DYLD_LIBRARY_PATH"
+ else
+ export LD_LIBRARY_PATH="$SPARK_SUBMIT_LIBRARY_PATH:$LD_LIBRARY_PATH"
+ fi
fi
if [ -n "$SPARK_SUBMIT_DRIVER_MEMORY" ]; then
OUR_JAVA_MEM="$SPARK_SUBMIT_DRIVER_MEMORY"
@@ -116,17 +121,17 @@ fi
JAVA_OPTS="$JAVA_OPTS -Xms$OUR_JAVA_MEM -Xmx$OUR_JAVA_MEM"
# Load extra JAVA_OPTS from conf/java-opts, if it exists
-if [ -e "$FWDIR/conf/java-opts" ] ; then
- JAVA_OPTS="$JAVA_OPTS `cat "$FWDIR"/conf/java-opts`"
+if [ -e "$SPARK_CONF_DIR/java-opts" ] ; then
+ JAVA_OPTS="$JAVA_OPTS `cat "$SPARK_CONF_DIR"/java-opts`"
fi
# Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in CommandUtils.scala!
TOOLS_DIR="$FWDIR"/tools
SPARK_TOOLS_JAR=""
-if [ -e "$TOOLS_DIR"/target/scala-$SCALA_VERSION/spark-tools*[0-9Tg].jar ]; then
+if [ -e "$TOOLS_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-tools*[0-9Tg].jar ]; then
# Use the JAR from the SBT build
- export SPARK_TOOLS_JAR="`ls "$TOOLS_DIR"/target/scala-$SCALA_VERSION/spark-tools*[0-9Tg].jar`"
+ export SPARK_TOOLS_JAR="`ls "$TOOLS_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-tools*[0-9Tg].jar`"
fi
if [ -e "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar ]; then
# Use the JAR from the Maven build
@@ -145,8 +150,8 @@ fi
if [[ "$1" =~ org.apache.spark.tools.* ]]; then
if test -z "$SPARK_TOOLS_JAR"; then
- echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SCALA_VERSION/" 1>&2
- echo "You need to build Spark before running $1." 1>&2
+ echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/" 1>&2
+ echo "You need to run \"build/sbt tools/package\" before running $1." 1>&2
exit 1
fi
CLASSPATH="$CLASSPATH:$SPARK_TOOLS_JAR"
diff --git a/bin/spark-shell b/bin/spark-shell
index 4a0670fc6c8aa..cca5aa0676123 100755
--- a/bin/spark-shell
+++ b/bin/spark-shell
@@ -45,6 +45,13 @@ source "$FWDIR"/bin/utils.sh
SUBMIT_USAGE_FUNCTION=usage
gatherSparkSubmitOpts "$@"
+# SPARK-4161: scala does not assume use of the java classpath,
+# so we need to add the "-Dscala.usejavacp=true" flag manually. We
+# do this specifically for the Spark shell because the scala REPL
+# has its own class loader, and any additional classpath specified
+# through spark.driver.extraClassPath is not automatically propagated.
+SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Dscala.usejavacp=true"
+
function main() {
if $cygwin; then
# Workaround for issue involving JLine and Cygwin
diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd
index 2ee60b4e2a2b3..1d1a40da315eb 100644
--- a/bin/spark-shell2.cmd
+++ b/bin/spark-shell2.cmd
@@ -19,4 +19,23 @@ rem
set SPARK_HOME=%~dp0..
-cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd --class org.apache.spark.repl.Main %* spark-shell
+echo "%*" | findstr " --help -h" >nul
+if %ERRORLEVEL% equ 0 (
+ call :usage
+ exit /b 0
+)
+
+call %SPARK_HOME%\bin\windows-utils.cmd %*
+if %ERRORLEVEL% equ 1 (
+ call :usage
+ exit /b 1
+)
+
+cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd --class org.apache.spark.repl.Main %SUBMISSION_OPTS% spark-shell %APPLICATION_OPTS%
+
+exit /b 0
+
+:usage
+echo "Usage: .\bin\spark-shell.cmd [options]" >&2
+%SPARK_HOME%\bin\spark-submit --help 2>&1 | findstr /V "Usage" 1>&2
+exit /b 0
diff --git a/bin/spark-sql b/bin/spark-sql
index 63d00437d508d..3b6cc420fea81 100755
--- a/bin/spark-sql
+++ b/bin/spark-sql
@@ -23,6 +23,8 @@
# Enter posix mode for bash
set -o posix
+# NOTE: This exact class name is matched downstream by SparkSubmit.
+# Any changes need to be reflected there.
CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver"
# Figure out where Spark is installed
diff --git a/bin/spark-submit b/bin/spark-submit
index c557311b4b20e..3e5cbdbb24394 100755
--- a/bin/spark-submit
+++ b/bin/spark-submit
@@ -22,6 +22,9 @@
export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)"
ORIG_ARGS=("$@")
+# Set COLUMNS for progress bar
+export COLUMNS=`tput cols`
+
while (($#)); do
if [ "$1" = "--deploy-mode" ]; then
SPARK_SUBMIT_DEPLOY_MODE=$2
@@ -35,11 +38,19 @@ while (($#)); do
export SPARK_SUBMIT_CLASSPATH=$2
elif [ "$1" = "--driver-java-options" ]; then
export SPARK_SUBMIT_OPTS=$2
+ elif [ "$1" = "--master" ]; then
+ export MASTER=$2
fi
shift
done
-DEFAULT_PROPERTIES_FILE="$SPARK_HOME/conf/spark-defaults.conf"
+if [ -z "$SPARK_CONF_DIR" ]; then
+ export SPARK_CONF_DIR="$SPARK_HOME/conf"
+fi
+DEFAULT_PROPERTIES_FILE="$SPARK_CONF_DIR/spark-defaults.conf"
+if [ "$MASTER" == "yarn-cluster" ]; then
+ SPARK_SUBMIT_DEPLOY_MODE=cluster
+fi
export SPARK_SUBMIT_DEPLOY_MODE=${SPARK_SUBMIT_DEPLOY_MODE:-"client"}
export SPARK_SUBMIT_PROPERTIES_FILE=${SPARK_SUBMIT_PROPERTIES_FILE:-"$DEFAULT_PROPERTIES_FILE"}
diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd
index cf6046d1547ad..12244a9cb04fb 100644
--- a/bin/spark-submit2.cmd
+++ b/bin/spark-submit2.cmd
@@ -24,7 +24,11 @@ set ORIG_ARGS=%*
rem Reset the values of all variables used
set SPARK_SUBMIT_DEPLOY_MODE=client
-set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_HOME%\conf\spark-defaults.conf
+
+if not defined SPARK_CONF_DIR (
+ set SPARK_CONF_DIR=%SPARK_HOME%\conf
+)
+set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_CONF_DIR%\spark-defaults.conf
set SPARK_SUBMIT_DRIVER_MEMORY=
set SPARK_SUBMIT_LIBRARY_PATH=
set SPARK_SUBMIT_CLASSPATH=
@@ -45,11 +49,17 @@ if [%1] == [] goto continue
set SPARK_SUBMIT_CLASSPATH=%2
) else if [%1] == [--driver-java-options] (
set SPARK_SUBMIT_OPTS=%2
+ ) else if [%1] == [--master] (
+ set MASTER=%2
)
shift
goto loop
:continue
+if [%MASTER%] == [yarn-cluster] (
+ set SPARK_SUBMIT_DEPLOY_MODE=cluster
+)
+
rem For client mode, the driver will be launched in the same JVM that launches
rem SparkSubmit, so we may need to read the properties file for any extra class
rem paths, library paths, java options and memory early on. Otherwise, it will
diff --git a/bin/utils.sh b/bin/utils.sh
index 22ea2b9a6d586..2241200082018 100755
--- a/bin/utils.sh
+++ b/bin/utils.sh
@@ -26,14 +26,14 @@ function gatherSparkSubmitOpts() {
exit 1
fi
- # NOTE: If you add or remove spark-sumbmit options,
+ # NOTE: If you add or remove spark-submit options,
# modify NOT ONLY this script but also SparkSubmitArgument.scala
SUBMISSION_OPTS=()
APPLICATION_OPTS=()
while (($#)); do
case "$1" in
- --master | --deploy-mode | --class | --name | --jars | --py-files | --files | \
- --conf | --properties-file | --driver-memory | --driver-java-options | \
+ --master | --deploy-mode | --class | --name | --jars | --packages | --py-files | --files | \
+ --conf | --repositories | --properties-file | --driver-memory | --driver-java-options | \
--driver-library-path | --driver-class-path | --executor-memory | --driver-cores | \
--total-executor-cores | --executor-cores | --queue | --num-executors | --archives)
if [[ $# -lt 2 ]]; then
diff --git a/bin/windows-utils.cmd b/bin/windows-utils.cmd
new file mode 100644
index 0000000000000..567b8733f7f77
--- /dev/null
+++ b/bin/windows-utils.cmd
@@ -0,0 +1,59 @@
+rem
+rem Licensed to the Apache Software Foundation (ASF) under one or more
+rem contributor license agreements. See the NOTICE file distributed with
+rem this work for additional information regarding copyright ownership.
+rem The ASF licenses this file to You under the Apache License, Version 2.0
+rem (the "License"); you may not use this file except in compliance with
+rem the License. You may obtain a copy of the License at
+rem
+rem http://www.apache.org/licenses/LICENSE-2.0
+rem
+rem Unless required by applicable law or agreed to in writing, software
+rem distributed under the License is distributed on an "AS IS" BASIS,
+rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+rem See the License for the specific language governing permissions and
+rem limitations under the License.
+rem
+
+rem Gather all spark-submit options into SUBMISSION_OPTS
+
+set SUBMISSION_OPTS=
+set APPLICATION_OPTS=
+
+rem NOTE: If you add or remove spark-submit options,
+rem modify NOT ONLY this script but also SparkSubmitArgument.scala
+
+:OptsLoop
+if "x%1"=="x" (
+ goto :OptsLoopEnd
+)
+
+SET opts="\<--master\> \<--deploy-mode\> \<--class\> \<--name\> \<--jars\> \<--py-files\> \<--files\>"
+SET opts="%opts:~1,-1% \<--conf\> \<--properties-file\> \<--driver-memory\> \<--driver-java-options\>"
+SET opts="%opts:~1,-1% \<--driver-library-path\> \<--driver-class-path\> \<--executor-memory\>"
+SET opts="%opts:~1,-1% \<--driver-cores\> \<--total-executor-cores\> \<--executor-cores\> \<--queue\>"
+SET opts="%opts:~1,-1% \<--num-executors\> \<--archives\> \<--packages\> \<--repositories\>"
+
+echo %1 | findstr %opts% >nul
+if %ERRORLEVEL% equ 0 (
+ if "x%2"=="x" (
+ echo "%1" requires an argument. >&2
+ exit /b 1
+ )
+ set SUBMISSION_OPTS=%SUBMISSION_OPTS% %1 %2
+ shift
+ shift
+ goto :OptsLoop
+)
+echo %1 | findstr "\<--verbose\> \<-v\> \<--supervise\>" >nul
+if %ERRORLEVEL% equ 0 (
+ set SUBMISSION_OPTS=%SUBMISSION_OPTS% %1
+ shift
+ goto :OptsLoop
+)
+set APPLICATION_OPTS=%APPLICATION_OPTS% %1
+shift
+goto :OptsLoop
+
+:OptsLoopEnd
+exit /b 0
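+
+rem Note that SUBMISSION_OPTS and APPLICATION_OPTS are deliberately left set so that the
+rem calling script (e.g. spark-shell2.cmd) can read them after this script returns.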
diff --git a/build/mvn b/build/mvn
new file mode 100755
index 0000000000000..a87c5a26230c8
--- /dev/null
+++ b/build/mvn
@@ -0,0 +1,149 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Determine the directory containing this script
+_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
+# Preserve the calling directory
+_CALLING_DIR="$(pwd)"
+
+# Installs any application tarball given a URL, the expected tarball name,
+# and, optionally, a checkable binary path to determine if the binary has
+# already been installed
+## Arg1 - URL
+## Arg2 - Tarball Name
+## Arg3 - Checkable Binary
+install_app() {
+ local remote_tarball="$1/$2"
+ local local_tarball="${_DIR}/$2"
+ local binary="${_DIR}/$3"
+
+ # setup `curl` and `wget` silent options if we're running on Jenkins
+ local curl_opts=""
+ local wget_opts=""
+ if [ -n "$AMPLAB_JENKINS" ]; then
+ curl_opts="-s"
+ wget_opts="--quiet"
+ else
+ curl_opts="--progress-bar"
+ wget_opts="--progress=bar:force"
+ fi
+
+ if [ -z "$3" -o ! -f "$binary" ]; then
+ # check if we already have the tarball
+ # check if we have curl installed
+ # download application
+ [ ! -f "${local_tarball}" ] && [ $(command -v curl) ] && \
+ echo "exec: curl ${curl_opts} ${remote_tarball}" && \
+ curl ${curl_opts} "${remote_tarball}" > "${local_tarball}"
+  # if the file still doesn't exist, let's try `wget` and cross our fingers
+ [ ! -f "${local_tarball}" ] && [ $(command -v wget) ] && \
+ echo "exec: wget ${wget_opts} ${remote_tarball}" && \
+ wget ${wget_opts} -O "${local_tarball}" "${remote_tarball}"
+ # if both were unsuccessful, exit
+ [ ! -f "${local_tarball}" ] && \
+ echo -n "ERROR: Cannot download $2 with cURL or wget; " && \
+ echo "please install manually and try again." && \
+ exit 2
+ cd "${_DIR}" && tar -xzf "$2"
+ rm -rf "$local_tarball"
+ fi
+}
+
+# Install maven under the build/ folder
+install_mvn() {
+ install_app \
+ "http://archive.apache.org/dist/maven/maven-3/3.2.5/binaries" \
+ "apache-maven-3.2.5-bin.tar.gz" \
+ "apache-maven-3.2.5/bin/mvn"
+ MVN_BIN="${_DIR}/apache-maven-3.2.5/bin/mvn"
+}
+
+# Install zinc under the build/ folder
+install_zinc() {
+ local zinc_path="zinc-0.3.5.3/bin/zinc"
+ [ ! -f "${zinc_path}" ] && ZINC_INSTALL_FLAG=1
+ install_app \
+ "http://downloads.typesafe.com/zinc/0.3.5.3" \
+ "zinc-0.3.5.3.tgz" \
+ "${zinc_path}"
+ ZINC_BIN="${_DIR}/${zinc_path}"
+}
+
+# Determine the Scala version from the root pom.xml file, set the Scala URL,
+# and, with that, download the specific version of Scala necessary under
+# the build/ folder
+install_scala() {
+ # determine the Scala version used in Spark
+ local scala_version=`grep "scala.version" "${_DIR}/../pom.xml" | \
+ head -1 | cut -f2 -d'>' | cut -f1 -d'<'`
+ local scala_bin="${_DIR}/scala-${scala_version}/bin/scala"
+
+ install_app \
+ "http://downloads.typesafe.com/scala/${scala_version}" \
+ "scala-${scala_version}.tgz" \
+ "scala-${scala_version}/bin/scala"
+
+ SCALA_COMPILER="$(cd "$(dirname ${scala_bin})/../lib" && pwd)/scala-compiler.jar"
+ SCALA_LIBRARY="$(cd "$(dirname ${scala_bin})/../lib" && pwd)/scala-library.jar"
+}
+
+# Determines if a given application is already installed. If not, will attempt
+# to install
+## Arg1 - application name
+## Arg2 - Alternate path to local install under build/ dir
+check_and_install_app() {
+ # create the local environment variable in uppercase
+ local app_bin="`echo $1 | awk '{print toupper(\$0)}'`_BIN"
+ # some black magic to set the generated app variable (i.e. MVN_BIN) into the
+ # environment
+ eval "${app_bin}=`which $1 2>/dev/null`"
+
+ if [ -z "`which $1 2>/dev/null`" ]; then
+ install_$1
+ fi
+}
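+
+# For example, check_and_install_app "mvn" (called below) leaves MVN_BIN pointing at the
+# system mvn when one is on the PATH, and otherwise falls back to install_mvn to download
+# Maven under build/.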
+
+# Set up healthy defaults for the Zinc port if none were provided from
+# the environment
+ZINC_PORT=${ZINC_PORT:-"3030"}
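+# The default can be overridden per invocation, e.g.: ZINC_PORT=3035 build/mvn -DskipTests package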
+
+# Check and install all applications necessary to build Spark
+check_and_install_app "mvn"
+
+# Install the proper version of Scala and Zinc for the build
+install_zinc
+install_scala
+
+# Reset the current working directory
+cd "${_CALLING_DIR}"
+
+# Now that zinc is ensured to be installed, check its status and, if it's
+# not running or was just installed, start it
+if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`${ZINC_BIN} -status`" ]; then
+ ${ZINC_BIN} -shutdown
+ ${ZINC_BIN} -start -port ${ZINC_PORT} \
+ -scala-compiler "${SCALA_COMPILER}" \
+ -scala-library "${SCALA_LIBRARY}" &>/dev/null
+fi
+
+# Set any `mvn` options if not already present
+export MAVEN_OPTS=${MAVEN_OPTS:-"-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m"}
+
+# Last, call the `mvn` command as usual
+${MVN_BIN} "$@"
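+
+# Example usage: build/mvn -DskipTests clean package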
diff --git a/build/sbt b/build/sbt
new file mode 100755
index 0000000000000..28ebb64f7197c
--- /dev/null
+++ b/build/sbt
@@ -0,0 +1,128 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# When creating new tests for Spark SQL Hive, the HADOOP_CLASSPATH must contain the hive jars so
+# that we can run Hive to generate the golden answer. This is not required for normal development
+# or testing.
+for i in "$HIVE_HOME"/lib/*
+do HADOOP_CLASSPATH="$HADOOP_CLASSPATH:$i"
+done
+export HADOOP_CLASSPATH
+
+realpath () {
+(
+ TARGET_FILE="$1"
+
+ cd "$(dirname "$TARGET_FILE")"
+ TARGET_FILE="$(basename "$TARGET_FILE")"
+
+ COUNT=0
+ while [ -L "$TARGET_FILE" -a $COUNT -lt 100 ]
+ do
+ TARGET_FILE="$(readlink "$TARGET_FILE")"
+ cd $(dirname "$TARGET_FILE")
+ TARGET_FILE="$(basename $TARGET_FILE)"
+ COUNT=$(($COUNT + 1))
+ done
+
+ echo "$(pwd -P)/"$TARGET_FILE""
+)
+}
+
+. "$(dirname "$(realpath "$0")")"/sbt-launch-lib.bash
+
+
+declare -r noshare_opts="-Dsbt.global.base=project/.sbtboot -Dsbt.boot.directory=project/.boot -Dsbt.ivy.home=project/.ivy"
+declare -r sbt_opts_file=".sbtopts"
+declare -r etc_sbt_opts_file="/etc/sbt/sbtopts"
+
+usage() {
+  cat <<EOM
+  -sbt-dir   <path>  path to global settings/plugins directory (default: ~/.sbt)
+  -sbt-boot  <path>  path to shared boot directory (default: ~/.sbt/boot in 0.11 series)
+  -ivy       <path>  path to local Ivy repository (default: ~/.ivy2)
+  -mem    <integer>  set memory options (default: $sbt_mem, which is $(get_mem_opts $sbt_mem))
+  -no-share          use all local caches; no sharing
+  -no-global         uses global caches, but does not use global ~/.sbt directory.
+  -jvm-debug <port>  Turn on JVM debugging, open at the given port.
+  -batch             Disable interactive mode
+
+  # sbt version (default: from project/build.properties if present, else latest release)
+  -sbt-version  <version>   use the specified version of sbt
+  -sbt-jar      <path>      use the specified jar as the sbt launcher
+  -sbt-rc                   use an RC version of sbt
+  -sbt-snapshot             use a snapshot version of sbt
+
+  # java version (default: java from PATH, currently $(java -version 2>&1 | grep version))
+  -java-home    <path>      alternate JAVA_HOME
+
+ # jvm options and output control
+ JAVA_OPTS environment variable, if unset uses "$java_opts"
+ SBT_OPTS environment variable, if unset uses "$default_sbt_opts"
+ .sbtopts if this file exists in the current directory, it is
+ prepended to the runner args
+ /etc/sbt/sbtopts if this file exists, it is prepended to the runner args
+ -Dkey=val pass -Dkey=val directly to the java runtime
+ -J-X pass option -X directly to the java runtime
+ (-J is stripped)
+ -S-X add -X to sbt's scalacOptions (-S is stripped)
+ -PmavenProfiles Enable a maven profile for the build.
+
+In the case of duplicated or conflicting options, the order above
+shows precedence: JAVA_OPTS lowest, command line options highest.
+EOM
+}
+
+process_my_args () {
+ while [[ $# -gt 0 ]]; do
+ case "$1" in
+ -no-colors) addJava "-Dsbt.log.noformat=true" && shift ;;
+ -no-share) addJava "$noshare_opts" && shift ;;
+ -no-global) addJava "-Dsbt.global.base=$(pwd)/project/.sbtboot" && shift ;;
+ -sbt-boot) require_arg path "$1" "$2" && addJava "-Dsbt.boot.directory=$2" && shift 2 ;;
+ -sbt-dir) require_arg path "$1" "$2" && addJava "-Dsbt.global.base=$2" && shift 2 ;;
+ -debug-inc) addJava "-Dxsbt.inc.debug=true" && shift ;;
+    -batch) exec </dev/null && shift ;;
-  if hash curl 2>/dev/null; then
+ if [ $(command -v curl) ]; then
(curl --silent ${URL1} > "${JAR_DL}" || curl --silent ${URL2} > "${JAR_DL}") && mv "${JAR_DL}" "${JAR}"
- elif hash wget 2>/dev/null; then
+ elif [ $(command -v wget) ]; then
(wget --quiet ${URL1} -O "${JAR_DL}" || wget --quiet ${URL2} -O "${JAR_DL}") && mv "${JAR_DL}" "${JAR}"
else
printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n"
@@ -104,7 +104,7 @@ addResidual () {
residual_args=( "${residual_args[@]}" "$1" )
}
addDebugger () {
- addJava "-Xdebug -Xrunjdwp:transport=dt_socket,server=y,suspend=n,address=$1"
+ addJava "-agentlib:jdwp=transport=dt_socket,server=y,suspend=n,address=$1"
}
# a ham-fisted attempt to move some memory settings in concert
@@ -124,7 +124,8 @@ require_arg () {
local opt="$2"
local arg="$3"
if [[ -z "$arg" ]] || [[ "${arg:0:1}" == "-" ]]; then
- die "$opt requires <$type> argument"
+ echo "$opt requires <$type> argument" 1>&2
+ exit 1
fi
}
@@ -149,7 +150,7 @@ process_args () {
-java-home) require_arg path "$1" "$2" && java_cmd="$2/bin/java" && export JAVA_HOME=$2 && shift 2 ;;
-D*) addJava "$1" && shift ;;
- -J*) addJava "${1:2}" && shift ;;
+ -J*) addJava "${1:2}" && shift ;;
-P*) enableProfile "$1" && shift ;;
*) addResidual "$1" && shift ;;
esac
@@ -185,10 +186,3 @@ run() {
"${sbt_commands[@]}" \
"${residual_args[@]}"
}
-
-runAlternateBoot() {
- local bootpropsfile="$1"
- shift
- addJava "-Dsbt.boot.properties=$bootpropsfile"
- run $@
-}
diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template
index 30bcab0c93302..464c14457e53f 100644
--- a/conf/metrics.properties.template
+++ b/conf/metrics.properties.template
@@ -77,8 +77,8 @@
# sample false Whether to show entire set of samples for histograms ('false' or 'true')
#
# * Default path is /metrics/json for all instances except the master. The master has two paths:
-# /metrics/aplications/json # App information
-# /metrics/master/json # Master information
+# /metrics/applications/json # App information
+# /metrics/master/json # Master information
# org.apache.spark.metrics.sink.GraphiteSink
# Name: Default: Description:
@@ -87,6 +87,7 @@
# period 10 Poll period
# unit seconds Units of poll period
# prefix EMPTY STRING Prefix to prepend to metric name
+# protocol tcp Protocol ("tcp" or "udp") to use
## Examples
# Enable JmxSink for all instances by class name
diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template
index f8ffbf64278fb..0886b0276fb90 100755
--- a/conf/spark-env.sh.template
+++ b/conf/spark-env.sh.template
@@ -28,7 +28,7 @@
# - SPARK_YARN_DIST_FILES, Comma separated list of files to be distributed with the job.
# - SPARK_YARN_DIST_ARCHIVES, Comma separated list of archives to be distributed with the job.
-# Options for the daemons used in the standalone deploy mode:
+# Options for the daemons used in the standalone deploy mode
# - SPARK_MASTER_IP, to bind the master to a different IP address or hostname
# - SPARK_MASTER_PORT / SPARK_MASTER_WEBUI_PORT, to use non-default ports for the master
# - SPARK_MASTER_OPTS, to set config properties only for the master (e.g. "-Dx=y")
@@ -41,3 +41,10 @@
# - SPARK_HISTORY_OPTS, to set config properties only for the history server (e.g. "-Dx=y")
# - SPARK_DAEMON_JAVA_OPTS, to set config properties for all daemons (e.g. "-Dx=y")
# - SPARK_PUBLIC_DNS, to set the public dns name of the master or workers
+
+# Generic options for the daemons used in the standalone deploy mode
+# - SPARK_CONF_DIR Alternate conf dir. (Default: ${SPARK_HOME}/conf)
+# - SPARK_LOG_DIR Where log files are stored. (Default: ${SPARK_HOME}/logs)
+# - SPARK_PID_DIR Where the pid file is stored. (Default: /tmp)
+# - SPARK_IDENT_STRING A string representing this instance of spark. (Default: $USER)
+# - SPARK_NICENESS The scheduling priority for daemons. (Default: 0)
diff --git a/core/pom.xml b/core/pom.xml
index a5a178079bc57..2dc5f747f2b71 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -21,7 +21,7 @@
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent</artifactId>
-    <version>1.2.0-SNAPSHOT</version>
+    <version>1.3.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
@@ -34,6 +34,38 @@
   <name>Spark Project Core</name>
   <url>http://spark.apache.org/</url>
   <dependencies>
+    <dependency>
+      <groupId>com.google.guava</groupId>
+      <artifactId>guava</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>com.twitter</groupId>
+      <artifactId>chill_${scala.binary.version}</artifactId>
+      <exclusions>
+        <exclusion>
+          <groupId>org.ow2.asm</groupId>
+          <artifactId>asm</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.ow2.asm</groupId>
+          <artifactId>asm-commons</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>com.twitter</groupId>
+      <artifactId>chill-java</artifactId>
+      <exclusions>
+        <exclusion>
+          <groupId>org.ow2.asm</groupId>
+          <artifactId>asm</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.ow2.asm</groupId>
+          <artifactId>asm-commons</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
     <dependency>
       <groupId>org.apache.hadoop</groupId>
       <artifactId>hadoop-client</artifactId>
@@ -44,6 +76,16 @@
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-network-common_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-network-shuffle_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
     <dependency>
       <groupId>net.java.dev.jets3t</groupId>
       <artifactId>jets3t</artifactId>
@@ -52,32 +94,45 @@
     <dependency>
       <groupId>org.apache.curator</groupId>
       <artifactId>curator-recipes</artifactId>
     </dependency>
+
+
     <dependency>
       <groupId>org.eclipse.jetty</groupId>
       <artifactId>jetty-plus</artifactId>
+      <scope>compile</scope>
     </dependency>
     <dependency>
       <groupId>org.eclipse.jetty</groupId>
       <artifactId>jetty-security</artifactId>
+      <scope>compile</scope>
     </dependency>
     <dependency>
       <groupId>org.eclipse.jetty</groupId>
       <artifactId>jetty-util</artifactId>
+      <scope>compile</scope>
     </dependency>
     <dependency>
       <groupId>org.eclipse.jetty</groupId>
       <artifactId>jetty-server</artifactId>
+      <scope>compile</scope>
     </dependency>
     <dependency>
-      <groupId>com.google.guava</groupId>
-      <artifactId>guava</artifactId>
+      <groupId>org.eclipse.jetty</groupId>
+      <artifactId>jetty-http</artifactId>
+      <scope>compile</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.eclipse.jetty</groupId>
+      <artifactId>jetty-continuation</artifactId>
+      <scope>compile</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.eclipse.jetty</groupId>
+      <artifactId>jetty-servlet</artifactId>
+      <scope>compile</scope>
     </dependency>
     <dependency>
       <groupId>org.apache.commons</groupId>
       <artifactId>commons-lang3</artifactId>
@@ -85,8 +140,6 @@
       <groupId>org.apache.commons</groupId>
       <artifactId>commons-math3</artifactId>
-      <version>3.3</version>
-      <scope>test</scope>
     </dependency>
     <dependency>
       <groupId>com.google.code.findbugs</groupId>
@@ -125,12 +178,8 @@
       <artifactId>lz4</artifactId>
     </dependency>
     <dependency>
-      <groupId>com.twitter</groupId>
-      <artifactId>chill_${scala.binary.version}</artifactId>
-    </dependency>
-    <dependency>
-      <groupId>com.twitter</groupId>
-      <artifactId>chill-java</artifactId>
+      <groupId>org.roaringbitmap</groupId>
+      <artifactId>RoaringBitmap</artifactId>
     </dependency>
     <dependency>
       <groupId>commons-net</groupId>
@@ -158,10 +207,6 @@
       <artifactId>json4s-jackson_${scala.binary.version}</artifactId>
       <version>3.2.10</version>
     </dependency>
-    <dependency>
-      <groupId>colt</groupId>
-      <artifactId>colt</artifactId>
-    </dependency>
     <dependency>
       <groupId>org.apache.mesos</groupId>
       <artifactId>mesos</artifactId>
@@ -176,19 +221,19 @@
       <artifactId>stream</artifactId>
     </dependency>
     <dependency>
-      <groupId>com.codahale.metrics</groupId>
+      <groupId>io.dropwizard.metrics</groupId>
       <artifactId>metrics-core</artifactId>
     </dependency>
     <dependency>
-      <groupId>com.codahale.metrics</groupId>
+      <groupId>io.dropwizard.metrics</groupId>
       <artifactId>metrics-jvm</artifactId>
     </dependency>
     <dependency>
-      <groupId>com.codahale.metrics</groupId>
+      <groupId>io.dropwizard.metrics</groupId>
       <artifactId>metrics-json</artifactId>
     </dependency>
     <dependency>
-      <groupId>com.codahale.metrics</groupId>
+      <groupId>io.dropwizard.metrics</groupId>
       <artifactId>metrics-graphite</artifactId>
     </dependency>
@@ -196,6 +241,17 @@
       <artifactId>derby</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.ivy</groupId>
+      <artifactId>ivy</artifactId>
+      <version>${ivy.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>oro</groupId>
+      <artifactId>oro</artifactId>
+      <version>${oro.version}</version>
+    </dependency>
     <dependency>
       <groupId>org.tachyonproject</groupId>
       <artifactId>tachyon-client</artifactId>
@@ -244,8 +300,8 @@
     <dependency>
-      <groupId>org.scalatest</groupId>
-      <artifactId>scalatest_${scala.binary.version}</artifactId>
+      <groupId>org.seleniumhq.selenium</groupId>
+      <artifactId>selenium-java</artifactId>
       <scope>test</scope>
     </dependency>
@@ -293,17 +349,6 @@
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
     <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
-      <plugin>
-        <groupId>org.scalatest</groupId>
-        <artifactId>scalatest-maven-plugin</artifactId>
-        <configuration>
-          <environmentVariables>
-            <SPARK_HOME>${basedir}/..</SPARK_HOME>
-            <SPARK_TESTING>1</SPARK_TESTING>
-            <SPARK_CLASSPATH>${spark.classpath}</SPARK_CLASSPATH>
-          </environmentVariables>
-        </configuration>
-      </plugin>
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
@@ -317,9 +362,9 @@
-
+
-
+
@@ -333,59 +378,28 @@
true
-      <plugin>
-        <groupId>org.apache.maven.plugins</groupId>
-        <artifactId>maven-shade-plugin</artifactId>
-        <executions>
-          <execution>
-            <phase>package</phase>
-            <goals>
-              <goal>shade</goal>
-            </goals>
-            <configuration>
-              <shadedArtifactAttached>false</shadedArtifactAttached>
-              <artifactSet>
-                <includes>
-                  <include>com.google.guava:guava</include>
-                </includes>
-              </artifactSet>
-              <filters>
-                <filter>
-                  <artifact>com.google.guava:guava</artifact>
-                  <includes>
-                    <include>com/google/common/base/Absent*</include>
-                    <include>com/google/common/base/Optional*</include>
-                    <include>com/google/common/base/Present*</include>
-                  </includes>
-                </filter>
-              </filters>
-            </configuration>
-          </execution>
-        </executions>
-      </plugin>
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-dependency-plugin</artifactId>
+
           <execution>
             <id>copy-dependencies</id>
             <phase>package</phase>
             <goals>
               <goal>copy-dependencies</goal>
             </goals>
             <configuration>
-
+              <outputDirectory>${project.build.directory}</outputDirectory>
               <overWriteReleases>false</overWriteReleases>
               <overWriteSnapshots>false</overWriteSnapshots>
               <overWriteIfNewer>true</overWriteIfNewer>
               <useSubDirectoryPerType>true</useSubDirectoryPerType>
-              <includeArtifactIds>guava</includeArtifactIds>
+              <includeArtifactIds>
+                guava,jetty-io,jetty-servlet,jetty-continuation,jetty-http,jetty-plus,jetty-util,jetty-server
+              </includeArtifactIds>
+              <silent>true</silent>
@@ -411,4 +425,5 @@
+
diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java
new file mode 100644
index 0000000000000..646496f313507
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/JavaSparkListener.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+import org.apache.spark.scheduler.SparkListener;
+import org.apache.spark.scheduler.SparkListenerApplicationEnd;
+import org.apache.spark.scheduler.SparkListenerApplicationStart;
+import org.apache.spark.scheduler.SparkListenerBlockManagerAdded;
+import org.apache.spark.scheduler.SparkListenerBlockManagerRemoved;
+import org.apache.spark.scheduler.SparkListenerEnvironmentUpdate;
+import org.apache.spark.scheduler.SparkListenerExecutorAdded;
+import org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate;
+import org.apache.spark.scheduler.SparkListenerExecutorRemoved;
+import org.apache.spark.scheduler.SparkListenerJobEnd;
+import org.apache.spark.scheduler.SparkListenerJobStart;
+import org.apache.spark.scheduler.SparkListenerStageCompleted;
+import org.apache.spark.scheduler.SparkListenerStageSubmitted;
+import org.apache.spark.scheduler.SparkListenerTaskEnd;
+import org.apache.spark.scheduler.SparkListenerTaskGettingResult;
+import org.apache.spark.scheduler.SparkListenerTaskStart;
+import org.apache.spark.scheduler.SparkListenerUnpersistRDD;
+
+/**
+ * Java clients should extend this class instead of implementing
+ * SparkListener directly. This prevents Java clients
+ * from breaking when new events are added to the SparkListener
+ * trait.
+ *
+ * This is a concrete class rather than an abstract one so that
+ * new events must be added to both SparkListener and this adapter
+ * in lockstep.
+ */
+public class JavaSparkListener implements SparkListener {
+
+ @Override
+ public void onStageCompleted(SparkListenerStageCompleted stageCompleted) { }
+
+ @Override
+ public void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) { }
+
+ @Override
+ public void onTaskStart(SparkListenerTaskStart taskStart) { }
+
+ @Override
+ public void onTaskGettingResult(SparkListenerTaskGettingResult taskGettingResult) { }
+
+ @Override
+ public void onTaskEnd(SparkListenerTaskEnd taskEnd) { }
+
+ @Override
+ public void onJobStart(SparkListenerJobStart jobStart) { }
+
+ @Override
+ public void onJobEnd(SparkListenerJobEnd jobEnd) { }
+
+ @Override
+ public void onEnvironmentUpdate(SparkListenerEnvironmentUpdate environmentUpdate) { }
+
+ @Override
+ public void onBlockManagerAdded(SparkListenerBlockManagerAdded blockManagerAdded) { }
+
+ @Override
+ public void onBlockManagerRemoved(SparkListenerBlockManagerRemoved blockManagerRemoved) { }
+
+ @Override
+ public void onUnpersistRDD(SparkListenerUnpersistRDD unpersistRDD) { }
+
+ @Override
+ public void onApplicationStart(SparkListenerApplicationStart applicationStart) { }
+
+ @Override
+ public void onApplicationEnd(SparkListenerApplicationEnd applicationEnd) { }
+
+ @Override
+ public void onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate executorMetricsUpdate) { }
+
+ @Override
+ public void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { }
+
+ @Override
+ public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { }
+}
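
For orientation, a minimal sketch (not part of this patch) of how a Java client might use the adapter above, overriding only the callback it cares about; the registration call through SparkContext.addSparkListener is an assumption based on the existing listener API.

    import org.apache.spark.JavaSparkListener;
    import org.apache.spark.scheduler.SparkListenerJobEnd;

    // Hypothetical example: every callback not overridden here stays a no-op
    // inherited from JavaSparkListener, so adding new events to SparkListener
    // will not break this class.
    public class JobEndLogger extends JavaSparkListener {
      @Override
      public void onJobEnd(SparkListenerJobEnd jobEnd) {
        System.out.println("Job " + jobEnd.jobId() + " ended with result " + jobEnd.jobResult());
      }
    }

    // Registration (assumed API): javaSparkContext.sc().addSparkListener(new JobEndLogger());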
diff --git a/core/src/main/java/org/apache/spark/JobExecutionStatus.java b/core/src/main/java/org/apache/spark/JobExecutionStatus.java
new file mode 100644
index 0000000000000..6e161313702bb
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/JobExecutionStatus.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+public enum JobExecutionStatus {
+ RUNNING,
+ SUCCEEDED,
+ FAILED,
+ UNKNOWN
+}
diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java
new file mode 100644
index 0000000000000..fbc5666959055
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+import org.apache.spark.scheduler.*;
+
+/**
+ * Class that allows users to receive all SparkListener events.
+ * Users should override the onEvent method.
+ *
+ * This is a concrete Java class in order to ensure that we don't forget to update it when adding
+ * new methods to SparkListener: forgetting to add a method will result in a compilation error (if
+ * this were a concrete Scala class, default implementations of new event handlers would be inherited
+ * from the SparkListener trait).
+ */
+public class SparkFirehoseListener implements SparkListener {
+
+ public void onEvent(SparkListenerEvent event) { }
+
+ @Override
+ public final void onStageCompleted(SparkListenerStageCompleted stageCompleted) {
+ onEvent(stageCompleted);
+ }
+
+ @Override
+ public final void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) {
+ onEvent(stageSubmitted);
+ }
+
+ @Override
+ public final void onTaskStart(SparkListenerTaskStart taskStart) {
+ onEvent(taskStart);
+ }
+
+ @Override
+ public final void onTaskGettingResult(SparkListenerTaskGettingResult taskGettingResult) {
+ onEvent(taskGettingResult);
+ }
+
+ @Override
+ public final void onTaskEnd(SparkListenerTaskEnd taskEnd) {
+ onEvent(taskEnd);
+ }
+
+ @Override
+ public final void onJobStart(SparkListenerJobStart jobStart) {
+ onEvent(jobStart);
+ }
+
+ @Override
+ public final void onJobEnd(SparkListenerJobEnd jobEnd) {
+ onEvent(jobEnd);
+ }
+
+ @Override
+ public final void onEnvironmentUpdate(SparkListenerEnvironmentUpdate environmentUpdate) {
+ onEvent(environmentUpdate);
+ }
+
+ @Override
+ public final void onBlockManagerAdded(SparkListenerBlockManagerAdded blockManagerAdded) {
+ onEvent(blockManagerAdded);
+ }
+
+ @Override
+ public final void onBlockManagerRemoved(SparkListenerBlockManagerRemoved blockManagerRemoved) {
+ onEvent(blockManagerRemoved);
+ }
+
+ @Override
+ public final void onUnpersistRDD(SparkListenerUnpersistRDD unpersistRDD) {
+ onEvent(unpersistRDD);
+ }
+
+ @Override
+ public final void onApplicationStart(SparkListenerApplicationStart applicationStart) {
+ onEvent(applicationStart);
+ }
+
+ @Override
+ public final void onApplicationEnd(SparkListenerApplicationEnd applicationEnd) {
+ onEvent(applicationEnd);
+ }
+
+ @Override
+ public final void onExecutorMetricsUpdate(
+ SparkListenerExecutorMetricsUpdate executorMetricsUpdate) {
+ onEvent(executorMetricsUpdate);
+ }
+
+ @Override
+ public final void onExecutorAdded(SparkListenerExecutorAdded executorAdded) {
+ onEvent(executorAdded);
+ }
+
+ @Override
+ public final void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) {
+ onEvent(executorRemoved);
+ }
+}
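
A minimal sketch (not part of this patch) of the single-hook usage the comment above describes: subclass SparkFirehoseListener and override onEvent to observe every scheduler event.

    import org.apache.spark.SparkFirehoseListener;
    import org.apache.spark.scheduler.SparkListenerEvent;

    // Hypothetical example: all of the final onXxx methods funnel into this one hook.
    public class EventCounter extends SparkFirehoseListener {
      private long eventCount = 0;

      @Override
      public void onEvent(SparkListenerEvent event) {
        eventCount++;
        System.out.println("Event #" + eventCount + ": " + event.getClass().getSimpleName());
      }
    }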
diff --git a/core/src/main/java/org/apache/spark/SparkJobInfo.java b/core/src/main/java/org/apache/spark/SparkJobInfo.java
new file mode 100644
index 0000000000000..e31c4401632a6
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/SparkJobInfo.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+import java.io.Serializable;
+
+/**
+ * Exposes information about Spark Jobs.
+ *
+ * This interface is not designed to be implemented outside of Spark. We may add additional methods
+ * which may break binary compatibility with outside implementations.
+ */
+public interface SparkJobInfo extends Serializable {
+ int jobId();
+ int[] stageIds();
+ JobExecutionStatus status();
+}
diff --git a/core/src/main/java/org/apache/spark/SparkStageInfo.java b/core/src/main/java/org/apache/spark/SparkStageInfo.java
new file mode 100644
index 0000000000000..b7d462abd72d6
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/SparkStageInfo.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+import java.io.Serializable;
+
+/**
+ * Exposes information about Spark Stages.
+ *
+ * This interface is not designed to be implemented outside of Spark. We may add additional methods
+ * which may break binary compatibility with outside implementations.
+ */
+public interface SparkStageInfo extends Serializable {
+ int stageId();
+ int currentAttemptId();
+ long submissionTime();
+ String name();
+ int numTasks();
+ int numActiveTasks();
+ int numCompletedTasks();
+ int numFailedTasks();
+}
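
A minimal polling sketch (not part of this patch) tying SparkJobInfo, SparkStageInfo and JobExecutionStatus together; it assumes a status-tracker accessor on JavaSparkContext (statusTracker().getJobInfo(jobId)), and only that first call changes if the accessor differs.

    import org.apache.spark.JobExecutionStatus;
    import org.apache.spark.SparkJobInfo;
    import org.apache.spark.api.java.JavaSparkContext;

    // Hypothetical example: report the status of a submitted job by id.
    public final class JobStatusPoller {
      public static void printStatus(JavaSparkContext jsc, int jobId) {
        SparkJobInfo info = jsc.statusTracker().getJobInfo(jobId);  // assumed accessor
        if (info == null) {
          System.out.println("Job " + jobId + ": no information available");
          return;
        }
        JobExecutionStatus status = info.status();  // RUNNING, SUCCEEDED, FAILED or UNKNOWN
        System.out.println("Job " + jobId + " is " + status + " over "
            + info.stageIds().length + " stage(s)");
      }
    }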
diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java
deleted file mode 100644
index 2d998d4c7a5d9..0000000000000
--- a/core/src/main/java/org/apache/spark/TaskContext.java
+++ /dev/null
@@ -1,108 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark;
-
-import java.io.Serializable;
-
-import scala.Function0;
-import scala.Function1;
-import scala.Unit;
-
-import org.apache.spark.annotation.DeveloperApi;
-import org.apache.spark.executor.TaskMetrics;
-import org.apache.spark.util.TaskCompletionListener;
-
-/**
- * Contextual information about a task which can be read or mutated during
- * execution. To access the TaskContext for a running task use
- * TaskContext.get().
- */
-public abstract class TaskContext implements Serializable {
- /**
- * Return the currently active TaskContext. This can be called inside of
- * user functions to access contextual information about running tasks.
- */
- public static TaskContext get() {
- return taskContext.get();
- }
-
-  private static ThreadLocal<TaskContext> taskContext =
-    new ThreadLocal<TaskContext>();
-
- static void setTaskContext(TaskContext tc) {
- taskContext.set(tc);
- }
-
- static void unset() {
- taskContext.remove();
- }
-
- /**
- * Whether the task has completed.
- */
- public abstract boolean isCompleted();
-
- /**
- * Whether the task has been killed.
- */
- public abstract boolean isInterrupted();
-
- /** @deprecated: use isRunningLocally() */
- @Deprecated
- public abstract boolean runningLocally();
-
- public abstract boolean isRunningLocally();
-
- /**
- * Add a (Java friendly) listener to be executed on task completion.
- * This will be called in all situation - success, failure, or cancellation.
- *
- * An example use is for HadoopRDD to register a callback to close the input stream.
- */
- public abstract TaskContext addTaskCompletionListener(TaskCompletionListener listener);
-
- /**
- * Add a listener in the form of a Scala closure to be executed on task completion.
- * This will be called in all situations - success, failure, or cancellation.
- *
- * An example use is for HadoopRDD to register a callback to close the input stream.
- */
-  public abstract TaskContext addTaskCompletionListener(final Function1<TaskContext, Unit> f);
-
- /**
- * Add a callback function to be executed on task completion. An example use
- * is for HadoopRDD to register a callback to close the input stream.
- * Will be called in any situation - success, failure, or cancellation.
- *
- * @deprecated: use addTaskCompletionListener
- *
- * @param f Callback function.
- */
- @Deprecated
-  public abstract void addOnCompleteCallback(final Function0<Unit> f);
-
- public abstract int stageId();
-
- public abstract int partitionId();
-
- public abstract long attemptId();
-
- /** ::DeveloperApi:: */
- @DeveloperApi
- public abstract TaskMetrics taskMetrics();
-}
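
The Java abstract class is removed here (TaskContext itself remains part of the API); for reference, a minimal sketch, not part of this patch, of the completion-listener pattern its Javadoc describes, assuming the org.apache.spark.util.TaskCompletionListener interface.

    import org.apache.spark.TaskContext;
    import org.apache.spark.util.TaskCompletionListener;

    // Hypothetical example: register a cleanup callback on the currently running
    // task; the callback fires on success, failure, or cancellation.
    public final class TaskCleanupExample {
      public static void registerCleanup() {
        TaskContext context = TaskContext.get();  // context of the running task
        context.addTaskCompletionListener(new TaskCompletionListener() {
          @Override
          public void onTaskCompletion(TaskContext ctx) {
            System.out.println("Task in stage " + ctx.stageId() + " finished");
          }
        });
      }
    }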
diff --git a/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java
index abd9bcc07ac61..99bf240a17225 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java
@@ -22,7 +22,8 @@
import scala.Tuple2;
/**
- * A function that returns key-value pairs (Tuple2), and can be used to construct PairRDDs.
+ * A function that returns key-value pairs (Tuple2<K, V>), and can be used to
+ * construct PairRDDs.
*/
public interface PairFunction<T, K, V> extends Serializable {
  public Tuple2<K, V> call(T t) throws Exception;
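
A minimal usage sketch (not part of this patch) of the Javadoc above: build a pair RDD with a PairFunction via JavaRDD.mapToPair.

    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.function.PairFunction;
    import scala.Tuple2;

    // Hypothetical example: map each word to a (word, length) pair.
    public final class PairFunctionExample {
      public static JavaPairRDD<String, Integer> wordLengths(JavaRDD<String> words) {
        return words.mapToPair(new PairFunction<String, String, Integer>() {
          @Override
          public Tuple2<String, Integer> call(String word) {
            return new Tuple2<String, Integer>(word, word.length());
          }
        });
      }
    }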
diff --git a/core/src/main/java/org/apache/spark/api/java/function/package.scala b/core/src/main/java/org/apache/spark/api/java/function/package.scala
index 7f91de653a64a..0f9bac7164162 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/package.scala
+++ b/core/src/main/java/org/apache/spark/api/java/function/package.scala
@@ -22,4 +22,4 @@ package org.apache.spark.api.java
* these interfaces to pass functions to various Java API methods for Spark. Please visit Spark's
* Java programming guide for more details.
*/
-package object function
\ No newline at end of file
+package object function
diff --git a/core/src/main/java/org/apache/spark/util/collection/Sorter.java b/core/src/main/java/org/apache/spark/util/collection/Sorter.java
deleted file mode 100644
index 64ad18c0e463a..0000000000000
--- a/core/src/main/java/org/apache/spark/util/collection/Sorter.java
+++ /dev/null
@@ -1,915 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection;
-
-import java.util.Comparator;
-
-/**
- * A port of the Android Timsort class, which utilizes a "stable, adaptive, iterative mergesort."
- * See the method comment on sort() for more details.
- *
- * This has been kept in Java with the original style in order to match very closely with the
- * Anroid source code, and thus be easy to verify correctness.
- *
- * The purpose of the port is to generalize the interface to the sort to accept input data formats
- * besides simple arrays where every element is sorted individually. For instance, the AppendOnlyMap
- * uses this to sort an Array with alternating elements of the form [key, value, key, value].
- * This generalization comes with minimal overhead -- see SortDataFormat for more information.
- */
-class Sorter<K, Buffer> {
-
- /**
- * This is the minimum sized sequence that will be merged. Shorter
- * sequences will be lengthened by calling binarySort. If the entire
- * array is less than this length, no merges will be performed.
- *
- * This constant should be a power of two. It was 64 in Tim Peter's C
- * implementation, but 32 was empirically determined to work better in
- * this implementation. In the unlikely event that you set this constant
- * to be a number that's not a power of two, you'll need to change the
- * minRunLength computation.
- *
- * If you decrease this constant, you must change the stackLen
- * computation in the TimSort constructor, or you risk an
- * ArrayOutOfBounds exception. See listsort.txt for a discussion
- * of the minimum stack length required as a function of the length
- * of the array being sorted and the minimum merge sequence length.
- */
- private static final int MIN_MERGE = 32;
-
-  private final SortDataFormat<K, Buffer> s;
-
-  public Sorter(SortDataFormat<K, Buffer> sortDataFormat) {
- this.s = sortDataFormat;
- }
-
- /**
- * A stable, adaptive, iterative mergesort that requires far fewer than
- * n lg(n) comparisons when running on partially sorted arrays, while
- * offering performance comparable to a traditional mergesort when run
- * on random arrays. Like all proper mergesorts, this sort is stable and
- * runs O(n log n) time (worst case). In the worst case, this sort requires
- * temporary storage space for n/2 object references; in the best case,
- * it requires only a small constant amount of space.
- *
- * This implementation was adapted from Tim Peters's list sort for
- * Python, which is described in detail here:
- *
- * http://svn.python.org/projects/python/trunk/Objects/listsort.txt
- *
- * Tim's C code may be found here:
- *
- * http://svn.python.org/projects/python/trunk/Objects/listobject.c
- *
- * The underlying techniques are described in this paper (and may have
- * even earlier origins):
- *
- * "Optimistic Sorting and Information Theoretic Complexity"
- * Peter McIlroy
- * SODA (Fourth Annual ACM-SIAM Symposium on Discrete Algorithms),
- * pp 467-474, Austin, Texas, 25-27 January 1993.
- *
- * While the API to this class consists solely of static methods, it is
- * (privately) instantiable; a TimSort instance holds the state of an ongoing
- * sort, assuming the input array is large enough to warrant the full-blown
- * TimSort. Small arrays are sorted in place, using a binary insertion sort.
- *
- * @author Josh Bloch
- */
-  void sort(Buffer a, int lo, int hi, Comparator<? super K> c) {
- assert c != null;
-
- int nRemaining = hi - lo;
- if (nRemaining < 2)
- return; // Arrays of size 0 and 1 are always sorted
-
- // If array is small, do a "mini-TimSort" with no merges
- if (nRemaining < MIN_MERGE) {
- int initRunLen = countRunAndMakeAscending(a, lo, hi, c);
- binarySort(a, lo, hi, lo + initRunLen, c);
- return;
- }
-
- /**
- * March over the array once, left to right, finding natural runs,
- * extending short natural runs to minRun elements, and merging runs
- * to maintain stack invariant.
- */
- SortState sortState = new SortState(a, c, hi - lo);
- int minRun = minRunLength(nRemaining);
- do {
- // Identify next run
- int runLen = countRunAndMakeAscending(a, lo, hi, c);
-
- // If run is short, extend to min(minRun, nRemaining)
- if (runLen < minRun) {
- int force = nRemaining <= minRun ? nRemaining : minRun;
- binarySort(a, lo, lo + force, lo + runLen, c);
- runLen = force;
- }
-
- // Push run onto pending-run stack, and maybe merge
- sortState.pushRun(lo, runLen);
- sortState.mergeCollapse();
-
- // Advance to find next run
- lo += runLen;
- nRemaining -= runLen;
- } while (nRemaining != 0);
-
- // Merge all remaining runs to complete sort
- assert lo == hi;
- sortState.mergeForceCollapse();
- assert sortState.stackSize == 1;
- }
-
- /**
- * Sorts the specified portion of the specified array using a binary
- * insertion sort. This is the best method for sorting small numbers
- * of elements. It requires O(n log n) compares, but O(n^2) data
- * movement (worst case).
- *
- * If the initial part of the specified range is already sorted,
- * this method can take advantage of it: the method assumes that the
- * elements from index {@code lo}, inclusive, to {@code start},
- * exclusive are already sorted.
- *
- * @param a the array in which a range is to be sorted
- * @param lo the index of the first element in the range to be sorted
- * @param hi the index after the last element in the range to be sorted
- * @param start the index of the first element in the range that is
- * not already known to be sorted ({@code lo <= start <= hi})
- * @param c comparator to used for the sort
- */
- @SuppressWarnings("fallthrough")
-  private void binarySort(Buffer a, int lo, int hi, int start, Comparator<? super K> c) {
- assert lo <= start && start <= hi;
- if (start == lo)
- start++;
-
- Buffer pivotStore = s.allocate(1);
- for ( ; start < hi; start++) {
- s.copyElement(a, start, pivotStore, 0);
- K pivot = s.getKey(pivotStore, 0);
-
- // Set left (and right) to the index where a[start] (pivot) belongs
- int left = lo;
- int right = start;
- assert left <= right;
- /*
- * Invariants:
- * pivot >= all in [lo, left).
- * pivot < all in [right, start).
- */
- while (left < right) {
- int mid = (left + right) >>> 1;
- if (c.compare(pivot, s.getKey(a, mid)) < 0)
- right = mid;
- else
- left = mid + 1;
- }
- assert left == right;
-
- /*
- * The invariants still hold: pivot >= all in [lo, left) and
- * pivot < all in [left, start), so pivot belongs at left. Note
- * that if there are elements equal to pivot, left points to the
- * first slot after them -- that's why this sort is stable.
- * Slide elements over to make room for pivot.
- */
- int n = start - left; // The number of elements to move
- // Switch is just an optimization for arraycopy in default case
- switch (n) {
- case 2: s.copyElement(a, left + 1, a, left + 2);
- case 1: s.copyElement(a, left, a, left + 1);
- break;
- default: s.copyRange(a, left, a, left + 1, n);
- }
- s.copyElement(pivotStore, 0, a, left);
- }
- }
-
- /**
- * Returns the length of the run beginning at the specified position in
- * the specified array and reverses the run if it is descending (ensuring
- * that the run will always be ascending when the method returns).
- *
- * A run is the longest ascending sequence with:
- *
- * a[lo] <= a[lo + 1] <= a[lo + 2] <= ...
- *
- * or the longest descending sequence with:
- *
- * a[lo] > a[lo + 1] > a[lo + 2] > ...
- *
- * For its intended use in a stable mergesort, the strictness of the
- * definition of "descending" is needed so that the call can safely
- * reverse a descending sequence without violating stability.
- *
- * @param a the array in which a run is to be counted and possibly reversed
- * @param lo index of the first element in the run
- * @param hi index after the last element that may be contained in the run.
- It is required that {@code lo < hi}.
- * @param c the comparator to used for the sort
- * @return the length of the run beginning at the specified position in
- * the specified array
- */
-  private int countRunAndMakeAscending(Buffer a, int lo, int hi, Comparator<? super K> c) {
- assert lo < hi;
- int runHi = lo + 1;
- if (runHi == hi)
- return 1;
-
- // Find end of run, and reverse range if descending
- if (c.compare(s.getKey(a, runHi++), s.getKey(a, lo)) < 0) { // Descending
- while (runHi < hi && c.compare(s.getKey(a, runHi), s.getKey(a, runHi - 1)) < 0)
- runHi++;
- reverseRange(a, lo, runHi);
- } else { // Ascending
- while (runHi < hi && c.compare(s.getKey(a, runHi), s.getKey(a, runHi - 1)) >= 0)
- runHi++;
- }
-
- return runHi - lo;
- }
-
- /**
- * Reverse the specified range of the specified array.
- *
- * @param a the array in which a range is to be reversed
- * @param lo the index of the first element in the range to be reversed
- * @param hi the index after the last element in the range to be reversed
- */
- private void reverseRange(Buffer a, int lo, int hi) {
- hi--;
- while (lo < hi) {
- s.swap(a, lo, hi);
- lo++;
- hi--;
- }
- }
-
- /**
- * Returns the minimum acceptable run length for an array of the specified
- * length. Natural runs shorter than this will be extended with
- * {@link #binarySort}.
- *
- * Roughly speaking, the computation is:
- *
- * If n < MIN_MERGE, return n (it's too small to bother with fancy stuff).
- * Else if n is an exact power of 2, return MIN_MERGE/2.
- * Else return an int k, MIN_MERGE/2 <= k <= MIN_MERGE, such that n/k
- * is close to, but strictly less than, an exact power of 2.
- *
- * For the rationale, see listsort.txt.
- *
- * @param n the length of the array to be sorted
- * @return the length of the minimum run to be merged
- */
- private int minRunLength(int n) {
- assert n >= 0;
- int r = 0; // Becomes 1 if any 1 bits are shifted off
- while (n >= MIN_MERGE) {
- r |= (n & 1);
- n >>= 1;
- }
- return n + r;
- }
-
- private class SortState {
-
- /**
- * The Buffer being sorted.
- */
- private final Buffer a;
-
- /**
- * Length of the sort Buffer.
- */
- private final int aLength;
-
- /**
- * The comparator for this sort.
- */
-    private final Comparator<? super K> c;
-
- /**
- * When we get into galloping mode, we stay there until both runs win less
- * often than MIN_GALLOP consecutive times.
- */
- private static final int MIN_GALLOP = 7;
-
- /**
- * This controls when we get *into* galloping mode. It is initialized
- * to MIN_GALLOP. The mergeLo and mergeHi methods nudge it higher for
- * random data, and lower for highly structured data.
- */
- private int minGallop = MIN_GALLOP;
-
- /**
- * Maximum initial size of tmp array, which is used for merging. The array
- * can grow to accommodate demand.
- *
- * Unlike Tim's original C version, we do not allocate this much storage
- * when sorting smaller arrays. This change was required for performance.
- */
- private static final int INITIAL_TMP_STORAGE_LENGTH = 256;
-
- /**
- * Temp storage for merges.
- */
- private Buffer tmp; // Actual runtime type will be Object[], regardless of T
-
- /**
- * Length of the temp storage.
- */
- private int tmpLength = 0;
-
- /**
- * A stack of pending runs yet to be merged. Run i starts at
- * address base[i] and extends for len[i] elements. It's always
- * true (so long as the indices are in bounds) that:
- *
- * runBase[i] + runLen[i] == runBase[i + 1]
- *
- * so we could cut the storage for this, but it's a minor amount,
- * and keeping all the info explicit simplifies the code.
- */
- private int stackSize = 0; // Number of pending runs on stack
- private final int[] runBase;
- private final int[] runLen;
-
- /**
- * Creates a TimSort instance to maintain the state of an ongoing sort.
- *
- * @param a the array to be sorted
- * @param c the comparator to determine the order of the sort
- */
-    private SortState(Buffer a, Comparator<? super K> c, int len) {
- this.aLength = len;
- this.a = a;
- this.c = c;
-
- // Allocate temp storage (which may be increased later if necessary)
- tmpLength = len < 2 * INITIAL_TMP_STORAGE_LENGTH ? len >>> 1 : INITIAL_TMP_STORAGE_LENGTH;
- tmp = s.allocate(tmpLength);
-
- /*
- * Allocate runs-to-be-merged stack (which cannot be expanded). The
- * stack length requirements are described in listsort.txt. The C
- * version always uses the same stack length (85), but this was
- * measured to be too expensive when sorting "mid-sized" arrays (e.g.,
- * 100 elements) in Java. Therefore, we use smaller (but sufficiently
- * large) stack lengths for smaller arrays. The "magic numbers" in the
- * computation below must be changed if MIN_MERGE is decreased. See
- * the MIN_MERGE declaration above for more information.
- */
- int stackLen = (len < 120 ? 5 :
- len < 1542 ? 10 :
- len < 119151 ? 19 : 40);
- runBase = new int[stackLen];
- runLen = new int[stackLen];
- }
-
- /**
- * Pushes the specified run onto the pending-run stack.
- *
- * @param runBase index of the first element in the run
- * @param runLen the number of elements in the run
- */
- private void pushRun(int runBase, int runLen) {
- this.runBase[stackSize] = runBase;
- this.runLen[stackSize] = runLen;
- stackSize++;
- }
-
- /**
- * Examines the stack of runs waiting to be merged and merges adjacent runs
- * until the stack invariants are reestablished:
- *
- * 1. runLen[i - 3] > runLen[i - 2] + runLen[i - 1]
- * 2. runLen[i - 2] > runLen[i - 1]
- *
- * This method is called each time a new run is pushed onto the stack,
- * so the invariants are guaranteed to hold for i < stackSize upon
- * entry to the method.
- */
- private void mergeCollapse() {
- while (stackSize > 1) {
- int n = stackSize - 2;
- if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) {
- if (runLen[n - 1] < runLen[n + 1])
- n--;
- mergeAt(n);
- } else if (runLen[n] <= runLen[n + 1]) {
- mergeAt(n);
- } else {
- break; // Invariant is established
- }
- }
- }
-
- /**
- * Merges all runs on the stack until only one remains. This method is
- * called once, to complete the sort.
- */
- private void mergeForceCollapse() {
- while (stackSize > 1) {
- int n = stackSize - 2;
- if (n > 0 && runLen[n - 1] < runLen[n + 1])
- n--;
- mergeAt(n);
- }
- }
-
- /**
- * Merges the two runs at stack indices i and i+1. Run i must be
- * the penultimate or antepenultimate run on the stack. In other words,
- * i must be equal to stackSize-2 or stackSize-3.
- *
- * @param i stack index of the first of the two runs to merge
- */
- private void mergeAt(int i) {
- assert stackSize >= 2;
- assert i >= 0;
- assert i == stackSize - 2 || i == stackSize - 3;
-
- int base1 = runBase[i];
- int len1 = runLen[i];
- int base2 = runBase[i + 1];
- int len2 = runLen[i + 1];
- assert len1 > 0 && len2 > 0;
- assert base1 + len1 == base2;
-
- /*
- * Record the length of the combined runs; if i is the 3rd-last
- * run now, also slide over the last run (which isn't involved
- * in this merge). The current run (i+1) goes away in any case.
- */
- runLen[i] = len1 + len2;
- if (i == stackSize - 3) {
- runBase[i + 1] = runBase[i + 2];
- runLen[i + 1] = runLen[i + 2];
- }
- stackSize--;
-
- /*
- * Find where the first element of run2 goes in run1. Prior elements
- * in run1 can be ignored (because they're already in place).
- */
- int k = gallopRight(s.getKey(a, base2), a, base1, len1, 0, c);
- assert k >= 0;
- base1 += k;
- len1 -= k;
- if (len1 == 0)
- return;
-
- /*
- * Find where the last element of run1 goes in run2. Subsequent elements
- * in run2 can be ignored (because they're already in place).
- */
- len2 = gallopLeft(s.getKey(a, base1 + len1 - 1), a, base2, len2, len2 - 1, c);
- assert len2 >= 0;
- if (len2 == 0)
- return;
-
- // Merge remaining runs, using tmp array with min(len1, len2) elements
- if (len1 <= len2)
- mergeLo(base1, len1, base2, len2);
- else
- mergeHi(base1, len1, base2, len2);
- }
-
- /**
- * Locates the position at which to insert the specified key into the
- * specified sorted range; if the range contains an element equal to key,
- * returns the index of the leftmost equal element.
- *
- * @param key the key whose insertion point to search for
- * @param a the array in which to search
- * @param base the index of the first element in the range
- * @param len the length of the range; must be > 0
- * @param hint the index at which to begin the search, 0 <= hint < n.
- * The closer hint is to the result, the faster this method will run.
- * @param c the comparator used to order the range, and to search
- * @return the int k, 0 <= k <= n such that a[b + k - 1] < key <= a[b + k],
- * pretending that a[b - 1] is minus infinity and a[b + n] is infinity.
- * In other words, key belongs at index b + k; or in other words,
- * the first k elements of a should precede key, and the last n - k
- * should follow it.
- */
-    private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator<? super K> c) {
- assert len > 0 && hint >= 0 && hint < len;
- int lastOfs = 0;
- int ofs = 1;
- if (c.compare(key, s.getKey(a, base + hint)) > 0) {
- // Gallop right until a[base+hint+lastOfs] < key <= a[base+hint+ofs]
- int maxOfs = len - hint;
- while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs)) > 0) {
- lastOfs = ofs;
- ofs = (ofs << 1) + 1;
- if (ofs <= 0) // int overflow
- ofs = maxOfs;
- }
- if (ofs > maxOfs)
- ofs = maxOfs;
-
- // Make offsets relative to base
- lastOfs += hint;
- ofs += hint;
- } else { // key <= a[base + hint]
- // Gallop left until a[base+hint-ofs] < key <= a[base+hint-lastOfs]
- final int maxOfs = hint + 1;
- while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs)) <= 0) {
- lastOfs = ofs;
- ofs = (ofs << 1) + 1;
- if (ofs <= 0) // int overflow
- ofs = maxOfs;
- }
- if (ofs > maxOfs)
- ofs = maxOfs;
-
- // Make offsets relative to base
- int tmp = lastOfs;
- lastOfs = hint - ofs;
- ofs = hint - tmp;
- }
- assert -1 <= lastOfs && lastOfs < ofs && ofs <= len;
-
- /*
- * Now a[base+lastOfs] < key <= a[base+ofs], so key belongs somewhere
- * to the right of lastOfs but no farther right than ofs. Do a binary
- * search, with invariant a[base + lastOfs - 1] < key <= a[base + ofs].
- */
- lastOfs++;
- while (lastOfs < ofs) {
- int m = lastOfs + ((ofs - lastOfs) >>> 1);
-
- if (c.compare(key, s.getKey(a, base + m)) > 0)
- lastOfs = m + 1; // a[base + m] < key
- else
- ofs = m; // key <= a[base + m]
- }
- assert lastOfs == ofs; // so a[base + ofs - 1] < key <= a[base + ofs]
- return ofs;
- }
-
- /**
- * Like gallopLeft, except that if the range contains an element equal to
- * key, gallopRight returns the index after the rightmost equal element.
- *
- * @param key the key whose insertion point to search for
- * @param a the array in which to search
- * @param base the index of the first element in the range
- * @param len the length of the range; must be > 0
- * @param hint the index at which to begin the search, 0 <= hint < n.
- * The closer hint is to the result, the faster this method will run.
- * @param c the comparator used to order the range, and to search
- * @return the int k, 0 <= k <= n such that a[b + k - 1] <= key < a[b + k]
- */
-    private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator<? super K> c) {
- assert len > 0 && hint >= 0 && hint < len;
-
- int ofs = 1;
- int lastOfs = 0;
- if (c.compare(key, s.getKey(a, base + hint)) < 0) {
- // Gallop left until a[b+hint - ofs] <= key < a[b+hint - lastOfs]
- int maxOfs = hint + 1;
- while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs)) < 0) {
- lastOfs = ofs;
- ofs = (ofs << 1) + 1;
- if (ofs <= 0) // int overflow
- ofs = maxOfs;
- }
- if (ofs > maxOfs)
- ofs = maxOfs;
-
- // Make offsets relative to b
- int tmp = lastOfs;
- lastOfs = hint - ofs;
- ofs = hint - tmp;
- } else { // a[b + hint] <= key
- // Gallop right until a[b+hint + lastOfs] <= key < a[b+hint + ofs]
- int maxOfs = len - hint;
- while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs)) >= 0) {
- lastOfs = ofs;
- ofs = (ofs << 1) + 1;
- if (ofs <= 0) // int overflow
- ofs = maxOfs;
- }
- if (ofs > maxOfs)
- ofs = maxOfs;
-
- // Make offsets relative to b
- lastOfs += hint;
- ofs += hint;
- }
- assert -1 <= lastOfs && lastOfs < ofs && ofs <= len;
-
- /*
- * Now a[b + lastOfs] <= key < a[b + ofs], so key belongs somewhere to
- * the right of lastOfs but no farther right than ofs. Do a binary
- * search, with invariant a[b + lastOfs - 1] <= key < a[b + ofs].
- */
- lastOfs++;
- while (lastOfs < ofs) {
- int m = lastOfs + ((ofs - lastOfs) >>> 1);
-
- if (c.compare(key, s.getKey(a, base + m)) < 0)
- ofs = m; // key < a[b + m]
- else
- lastOfs = m + 1; // a[b + m] <= key
- }
- assert lastOfs == ofs; // so a[b + ofs - 1] <= key < a[b + ofs]
- return ofs;
- }
-
- /**
- * Merges two adjacent runs in place, in a stable fashion. The first
- * element of the first run must be greater than the first element of the
- * second run (a[base1] > a[base2]), and the last element of the first run
- * (a[base1 + len1-1]) must be greater than all elements of the second run.
- *
- * For performance, this method should be called only when len1 <= len2;
- * its twin, mergeHi should be called if len1 >= len2. (Either method
- * may be called if len1 == len2.)
- *
- * @param base1 index of first element in first run to be merged
- * @param len1 length of first run to be merged (must be > 0)
- * @param base2 index of first element in second run to be merged
- * (must be aBase + aLen)
- * @param len2 length of second run to be merged (must be > 0)
- */
- private void mergeLo(int base1, int len1, int base2, int len2) {
- assert len1 > 0 && len2 > 0 && base1 + len1 == base2;
-
- // Copy first run into temp array
- Buffer a = this.a; // For performance
- Buffer tmp = ensureCapacity(len1);
- s.copyRange(a, base1, tmp, 0, len1);
-
- int cursor1 = 0; // Indexes into tmp array
- int cursor2 = base2; // Indexes int a
- int dest = base1; // Indexes int a
-
- // Move first element of second run and deal with degenerate cases
- s.copyElement(a, cursor2++, a, dest++);
- if (--len2 == 0) {
- s.copyRange(tmp, cursor1, a, dest, len1);
- return;
- }
- if (len1 == 1) {
- s.copyRange(a, cursor2, a, dest, len2);
- s.copyElement(tmp, cursor1, a, dest + len2); // Last elt of run 1 to end of merge
- return;
- }
-
-      Comparator<? super K> c = this.c;  // Use local variable for performance
- int minGallop = this.minGallop; // " " " " "
- outer:
- while (true) {
- int count1 = 0; // Number of times in a row that first run won
- int count2 = 0; // Number of times in a row that second run won
-
- /*
- * Do the straightforward thing until (if ever) one run starts
- * winning consistently.
- */
- do {
- assert len1 > 1 && len2 > 0;
- if (c.compare(s.getKey(a, cursor2), s.getKey(tmp, cursor1)) < 0) {
- s.copyElement(a, cursor2++, a, dest++);
- count2++;
- count1 = 0;
- if (--len2 == 0)
- break outer;
- } else {
- s.copyElement(tmp, cursor1++, a, dest++);
- count1++;
- count2 = 0;
- if (--len1 == 1)
- break outer;
- }
- } while ((count1 | count2) < minGallop);
-
- /*
- * One run is winning so consistently that galloping may be a
- * huge win. So try that, and continue galloping until (if ever)
- * neither run appears to be winning consistently anymore.
- */
- do {
- assert len1 > 1 && len2 > 0;
- count1 = gallopRight(s.getKey(a, cursor2), tmp, cursor1, len1, 0, c);
- if (count1 != 0) {
- s.copyRange(tmp, cursor1, a, dest, count1);
- dest += count1;
- cursor1 += count1;
- len1 -= count1;
- if (len1 <= 1) // len1 == 1 || len1 == 0
- break outer;
- }
- s.copyElement(a, cursor2++, a, dest++);
- if (--len2 == 0)
- break outer;
-
- count2 = gallopLeft(s.getKey(tmp, cursor1), a, cursor2, len2, 0, c);
- if (count2 != 0) {
- s.copyRange(a, cursor2, a, dest, count2);
- dest += count2;
- cursor2 += count2;
- len2 -= count2;
- if (len2 == 0)
- break outer;
- }
- s.copyElement(tmp, cursor1++, a, dest++);
- if (--len1 == 1)
- break outer;
- minGallop--;
- } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP);
- if (minGallop < 0)
- minGallop = 0;
- minGallop += 2; // Penalize for leaving gallop mode
- } // End of "outer" loop
- this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field
-
- if (len1 == 1) {
- assert len2 > 0;
- s.copyRange(a, cursor2, a, dest, len2);
- s.copyElement(tmp, cursor1, a, dest + len2); // Last elt of run 1 to end of merge
- } else if (len1 == 0) {
- throw new IllegalArgumentException(
- "Comparison method violates its general contract!");
- } else {
- assert len2 == 0;
- assert len1 > 1;
- s.copyRange(tmp, cursor1, a, dest, len1);
- }
- }
-
- /**
- * Like mergeLo, except that this method should be called only if
- * len1 >= len2; mergeLo should be called if len1 <= len2. (Either method
- * may be called if len1 == len2.)
- *
- * @param base1 index of first element in first run to be merged
- * @param len1 length of first run to be merged (must be > 0)
- * @param base2 index of first element in second run to be merged
- * (must be aBase + aLen)
- * @param len2 length of second run to be merged (must be > 0)
- */
- private void mergeHi(int base1, int len1, int base2, int len2) {
- assert len1 > 0 && len2 > 0 && base1 + len1 == base2;
-
- // Copy second run into temp array
- Buffer a = this.a; // For performance
- Buffer tmp = ensureCapacity(len2);
- s.copyRange(a, base2, tmp, 0, len2);
-
- int cursor1 = base1 + len1 - 1; // Indexes into a
- int cursor2 = len2 - 1; // Indexes into tmp array
- int dest = base2 + len2 - 1; // Indexes into a
-
- // Move last element of first run and deal with degenerate cases
- s.copyElement(a, cursor1--, a, dest--);
- if (--len1 == 0) {
- s.copyRange(tmp, 0, a, dest - (len2 - 1), len2);
- return;
- }
- if (len2 == 1) {
- dest -= len1;
- cursor1 -= len1;
- s.copyRange(a, cursor1 + 1, a, dest + 1, len1);
- s.copyElement(tmp, cursor2, a, dest);
- return;
- }
-
-      Comparator<? super K> c = this.c;  // Use local variable for performance
- int minGallop = this.minGallop; // " " " " "
- outer:
- while (true) {
- int count1 = 0; // Number of times in a row that first run won
- int count2 = 0; // Number of times in a row that second run won
-
- /*
- * Do the straightforward thing until (if ever) one run
- * appears to win consistently.
- */
- do {
- assert len1 > 0 && len2 > 1;
- if (c.compare(s.getKey(tmp, cursor2), s.getKey(a, cursor1)) < 0) {
- s.copyElement(a, cursor1--, a, dest--);
- count1++;
- count2 = 0;
- if (--len1 == 0)
- break outer;
- } else {
- s.copyElement(tmp, cursor2--, a, dest--);
- count2++;
- count1 = 0;
- if (--len2 == 1)
- break outer;
- }
- } while ((count1 | count2) < minGallop);
-
- /*
- * One run is winning so consistently that galloping may be a
- * huge win. So try that, and continue galloping until (if ever)
- * neither run appears to be winning consistently anymore.
- */
- do {
- assert len1 > 0 && len2 > 1;
- count1 = len1 - gallopRight(s.getKey(tmp, cursor2), a, base1, len1, len1 - 1, c);
- if (count1 != 0) {
- dest -= count1;
- cursor1 -= count1;
- len1 -= count1;
- s.copyRange(a, cursor1 + 1, a, dest + 1, count1);
- if (len1 == 0)
- break outer;
- }
- s.copyElement(tmp, cursor2--, a, dest--);
- if (--len2 == 1)
- break outer;
-
- count2 = len2 - gallopLeft(s.getKey(a, cursor1), tmp, 0, len2, len2 - 1, c);
- if (count2 != 0) {
- dest -= count2;
- cursor2 -= count2;
- len2 -= count2;
- s.copyRange(tmp, cursor2 + 1, a, dest + 1, count2);
- if (len2 <= 1) // len2 == 1 || len2 == 0
- break outer;
- }
- s.copyElement(a, cursor1--, a, dest--);
- if (--len1 == 0)
- break outer;
- minGallop--;
- } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP);
- if (minGallop < 0)
- minGallop = 0;
- minGallop += 2; // Penalize for leaving gallop mode
- } // End of "outer" loop
- this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field
-
- if (len2 == 1) {
- assert len1 > 0;
- dest -= len1;
- cursor1 -= len1;
- s.copyRange(a, cursor1 + 1, a, dest + 1, len1);
- s.copyElement(tmp, cursor2, a, dest); // Move first elt of run2 to front of merge
- } else if (len2 == 0) {
- throw new IllegalArgumentException(
- "Comparison method violates its general contract!");
- } else {
- assert len1 == 0;
- assert len2 > 0;
- s.copyRange(tmp, 0, a, dest - (len2 - 1), len2);
- }
- }
-
- /**
- * Ensures that the external array tmp has at least the specified
- * number of elements, increasing its size if necessary. The size
- * increases exponentially to ensure amortized linear time complexity.
- *
- * @param minCapacity the minimum required capacity of the tmp array
- * @return tmp, whether or not it grew
- */
- private Buffer ensureCapacity(int minCapacity) {
- if (tmpLength < minCapacity) {
- // Compute smallest power of 2 > minCapacity
- int newSize = minCapacity;
- newSize |= newSize >> 1;
- newSize |= newSize >> 2;
- newSize |= newSize >> 4;
- newSize |= newSize >> 8;
- newSize |= newSize >> 16;
- newSize++;
-
- if (newSize < 0) // Not bloody likely!
- newSize = minCapacity;
- else
- newSize = Math.min(newSize, aLength >>> 1);
-
- tmp = s.allocate(newSize);
- tmpLength = newSize;
- }
- return tmp;
- }
- }
-}
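
The minRunLength() computation documented above (and carried over into the new TimSort.java below) is easy to check in isolation; a self-contained sketch, not part of this patch, with two illustrative values:

    // Hypothetical stand-alone copy of the minRunLength() logic described above:
    // shift n right until it drops below MIN_MERGE, remembering whether any
    // 1 bits were shifted off.
    public final class MinRunDemo {
      private static final int MIN_MERGE = 32;

      static int minRunLength(int n) {
        int r = 0;                  // becomes 1 if any 1 bits are shifted off
        while (n >= MIN_MERGE) {
          r |= (n & 1);
          n >>= 1;
        }
        return n + r;
      }

      public static void main(String[] args) {
        System.out.println(minRunLength(64));     // 16: exact power of two gives MIN_MERGE/2
        System.out.println(minRunLength(10000));  // 20: 10000/20 = 500, just below 2^9 = 512
      }
    }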
diff --git a/core/src/main/java/org/apache/spark/util/collection/TimSort.java b/core/src/main/java/org/apache/spark/util/collection/TimSort.java
new file mode 100644
index 0000000000000..409e1a41c5d49
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/TimSort.java
@@ -0,0 +1,940 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection;
+
+import java.util.Comparator;
+
+/**
+ * A port of the Android TimSort class, which utilizes a "stable, adaptive, iterative mergesort."
+ * See the method comment on sort() for more details.
+ *
+ * This has been kept in Java with the original style in order to match very closely with the
+ * Android source code, and thus be easy to verify correctness. The class is package private. We
+ * provide a simple Scala wrapper, {@link org.apache.spark.util.collection.Sorter}, which is available to
+ * package org.apache.spark.
+ *
+ * The purpose of the port is to generalize the interface to the sort to accept input data formats
+ * besides simple arrays where every element is sorted individually. For instance, the AppendOnlyMap
+ * uses this to sort an Array with alternating elements of the form [key, value, key, value].
+ * This generalization comes with minimal overhead -- see SortDataFormat for more information.
+ *
+ * We allow key reuse to prevent creating many key objects -- see SortDataFormat.
+ *
+ * @see org.apache.spark.util.collection.SortDataFormat
+ * @see org.apache.spark.util.collection.Sorter
+ */
+class TimSort<K, Buffer> {
+
+ /**
+ * This is the minimum sized sequence that will be merged. Shorter
+ * sequences will be lengthened by calling binarySort. If the entire
+ * array is less than this length, no merges will be performed.
+ *
+ * This constant should be a power of two. It was 64 in Tim Peter's C
+ * implementation, but 32 was empirically determined to work better in
+ * this implementation. In the unlikely event that you set this constant
+ * to be a number that's not a power of two, you'll need to change the
+ * minRunLength computation.
+ *
+ * If you decrease this constant, you must change the stackLen
+ * computation in the TimSort constructor, or you risk an
+ * ArrayOutOfBounds exception. See listsort.txt for a discussion
+ * of the minimum stack length required as a function of the length
+ * of the array being sorted and the minimum merge sequence length.
+ */
+ private static final int MIN_MERGE = 32;
+
+  private final SortDataFormat<K, Buffer> s;
+
+  public TimSort(SortDataFormat<K, Buffer> sortDataFormat) {
+ this.s = sortDataFormat;
+ }
+
+ /**
+ * A stable, adaptive, iterative mergesort that requires far fewer than
+ * n lg(n) comparisons when running on partially sorted arrays, while
+ * offering performance comparable to a traditional mergesort when run
+ * on random arrays. Like all proper mergesorts, this sort is stable and
+ * runs O(n log n) time (worst case). In the worst case, this sort requires
+ * temporary storage space for n/2 object references; in the best case,
+ * it requires only a small constant amount of space.
+ *
+ * This implementation was adapted from Tim Peters's list sort for
+ * Python, which is described in detail here:
+ *
+ * http://svn.python.org/projects/python/trunk/Objects/listsort.txt
+ *
+ * Tim's C code may be found here:
+ *
+ * http://svn.python.org/projects/python/trunk/Objects/listobject.c
+ *
+ * The underlying techniques are described in this paper (and may have
+ * even earlier origins):
+ *
+ * "Optimistic Sorting and Information Theoretic Complexity"
+ * Peter McIlroy
+ * SODA (Fourth Annual ACM-SIAM Symposium on Discrete Algorithms),
+ * pp 467-474, Austin, Texas, 25-27 January 1993.
+ *
+ * While the API to this class consists solely of static methods, it is
+ * (privately) instantiable; a TimSort instance holds the state of an ongoing
+ * sort, assuming the input array is large enough to warrant the full-blown
+ * TimSort. Small arrays are sorted in place, using a binary insertion sort.
+ *
+ * @author Josh Bloch
+ */
+  public void sort(Buffer a, int lo, int hi, Comparator<? super K> c) {
+ assert c != null;
+
+ int nRemaining = hi - lo;
+ if (nRemaining < 2)
+ return; // Arrays of size 0 and 1 are always sorted
+
+ // If array is small, do a "mini-TimSort" with no merges
+ if (nRemaining < MIN_MERGE) {
+ int initRunLen = countRunAndMakeAscending(a, lo, hi, c);
+ binarySort(a, lo, hi, lo + initRunLen, c);
+ return;
+ }
+
+ /**
+ * March over the array once, left to right, finding natural runs,
+ * extending short natural runs to minRun elements, and merging runs
+ * to maintain stack invariant.
+ */
+ SortState sortState = new SortState(a, c, hi - lo);
+ int minRun = minRunLength(nRemaining);
+ do {
+ // Identify next run
+ int runLen = countRunAndMakeAscending(a, lo, hi, c);
+
+ // If run is short, extend to min(minRun, nRemaining)
+ if (runLen < minRun) {
+ int force = nRemaining <= minRun ? nRemaining : minRun;
+ binarySort(a, lo, lo + force, lo + runLen, c);
+ runLen = force;
+ }
+
+ // Push run onto pending-run stack, and maybe merge
+ sortState.pushRun(lo, runLen);
+ sortState.mergeCollapse();
+
+ // Advance to find next run
+ lo += runLen;
+ nRemaining -= runLen;
+ } while (nRemaining != 0);
+
+ // Merge all remaining runs to complete sort
+ assert lo == hi;
+ sortState.mergeForceCollapse();
+ assert sortState.stackSize == 1;
+ }
+
+ /**
+ * Sorts the specified portion of the specified array using a binary
+ * insertion sort. This is the best method for sorting small numbers
+ * of elements. It requires O(n log n) compares, but O(n^2) data
+ * movement (worst case).
+ *
+ * If the initial part of the specified range is already sorted,
+ * this method can take advantage of it: the method assumes that the
+ * elements from index {@code lo}, inclusive, to {@code start},
+ * exclusive are already sorted.
+ *
+ * @param a the array in which a range is to be sorted
+ * @param lo the index of the first element in the range to be sorted
+ * @param hi the index after the last element in the range to be sorted
+ * @param start the index of the first element in the range that is
+ * not already known to be sorted ({@code lo <= start <= hi})
+ * @param c comparator to used for the sort
+ */
+ @SuppressWarnings("fallthrough")
+  private void binarySort(Buffer a, int lo, int hi, int start, Comparator<? super K> c) {
+ assert lo <= start && start <= hi;
+ if (start == lo)
+ start++;
+
+ K key0 = s.newKey();
+ K key1 = s.newKey();
+
+ Buffer pivotStore = s.allocate(1);
+ for ( ; start < hi; start++) {
+ s.copyElement(a, start, pivotStore, 0);
+ K pivot = s.getKey(pivotStore, 0, key0);
+
+ // Set left (and right) to the index where a[start] (pivot) belongs
+ int left = lo;
+ int right = start;
+ assert left <= right;
+ /*
+ * Invariants:
+ * pivot >= all in [lo, left).
+ * pivot < all in [right, start).
+ */
+ while (left < right) {
+ int mid = (left + right) >>> 1;
+ if (c.compare(pivot, s.getKey(a, mid, key1)) < 0)
+ right = mid;
+ else
+ left = mid + 1;
+ }
+ assert left == right;
+
+ /*
+ * The invariants still hold: pivot >= all in [lo, left) and
+ * pivot < all in [left, start), so pivot belongs at left. Note
+ * that if there are elements equal to pivot, left points to the
+ * first slot after them -- that's why this sort is stable.
+ * Slide elements over to make room for pivot.
+ */
+ int n = start - left; // The number of elements to move
+ // Switch is just an optimization for arraycopy in default case
+ switch (n) {
+ case 2: s.copyElement(a, left + 1, a, left + 2);
+ case 1: s.copyElement(a, left, a, left + 1);
+ break;
+ default: s.copyRange(a, left, a, left + 1, n);
+ }
+ s.copyElement(pivotStore, 0, a, left);
+ }
+ }
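A hypothetical standalone sketch (not part of this patch) of the same binary insertion sort on a plain int[], without the Buffer/SortDataFormat indirection: elements in [lo, start) are assumed sorted, and each pass binary-searches the insertion point and shifts the tail one slot right.

    static void binaryInsertionSort(int[] a, int lo, int hi, int start) {
      if (start == lo) start++;
      for (; start < hi; start++) {
        int pivot = a[start];
        int left = lo;
        int right = start;
        // Invariant: pivot >= all in [lo, left), pivot < all in [right, start).
        while (left < right) {
          int mid = (left + right) >>> 1;
          if (pivot < a[mid]) right = mid; else left = mid + 1;
        }
        // Shift a[left, start) one slot right and drop the pivot into place.
        System.arraycopy(a, left, a, left + 1, start - left);
        a[left] = pivot;
      }
    }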
+
+ /**
+ * Returns the length of the run beginning at the specified position in
+ * the specified array and reverses the run if it is descending (ensuring
+ * that the run will always be ascending when the method returns).
+ *
+ * A run is the longest ascending sequence with:
+ *
+ * a[lo] <= a[lo + 1] <= a[lo + 2] <= ...
+ *
+ * or the longest descending sequence with:
+ *
+ * a[lo] > a[lo + 1] > a[lo + 2] > ...
+ *
+ * For its intended use in a stable mergesort, the strictness of the
+ * definition of "descending" is needed so that the call can safely
+ * reverse a descending sequence without violating stability.
+ *
+ * @param a the array in which a run is to be counted and possibly reversed
+ * @param lo index of the first element in the run
+ * @param hi index after the last element that may be contained in the run.
+ * It is required that {@code lo < hi}.
+ * @param c the comparator to use for the sort
+ * @return the length of the run beginning at the specified position in
+ * the specified array
+ */
+ private int countRunAndMakeAscending(Buffer a, int lo, int hi, Comparator<? super K> c) {
+ assert lo < hi;
+ int runHi = lo + 1;
+ if (runHi == hi)
+ return 1;
+
+ K key0 = s.newKey();
+ K key1 = s.newKey();
+
+ // Find end of run, and reverse range if descending
+ if (c.compare(s.getKey(a, runHi++, key0), s.getKey(a, lo, key1)) < 0) { // Descending
+ while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0)
+ runHi++;
+ reverseRange(a, lo, runHi);
+ } else { // Ascending
+ while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0)
+ runHi++;
+ }
+
+ return runHi - lo;
+ }
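Again as a hypothetical standalone sketch (not part of this patch), the same run rule on an int[]: an ascending run may contain equal neighbours, while a descending run must be strictly decreasing so that reversing it cannot reorder equal elements.

    static int countRunAndMakeAscending(int[] a, int lo, int hi) {
      int runHi = lo + 1;
      if (runHi == hi) return 1;
      if (a[runHi++] < a[lo]) {                            // strictly descending
        while (runHi < hi && a[runHi] < a[runHi - 1]) runHi++;
        for (int i = lo, j = runHi - 1; i < j; i++, j--) { // reverse in place
          int t = a[i]; a[i] = a[j]; a[j] = t;
        }
      } else {                                             // ascending, ties allowed
        while (runHi < hi && a[runHi] >= a[runHi - 1]) runHi++;
      }
      return runHi - lo;
    }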
+
+ /**
+ * Reverse the specified range of the specified array.
+ *
+ * @param a the array in which a range is to be reversed
+ * @param lo the index of the first element in the range to be reversed
+ * @param hi the index after the last element in the range to be reversed
+ */
+ private void reverseRange(Buffer a, int lo, int hi) {
+ hi--;
+ while (lo < hi) {
+ s.swap(a, lo, hi);
+ lo++;
+ hi--;
+ }
+ }
+
+ /**
+ * Returns the minimum acceptable run length for an array of the specified
+ * length. Natural runs shorter than this will be extended with
+ * {@link #binarySort}.
+ *
+ * Roughly speaking, the computation is:
+ *
+ * If n < MIN_MERGE, return n (it's too small to bother with fancy stuff).
+ * Else if n is an exact power of 2, return MIN_MERGE/2.
+ * Else return an int k, MIN_MERGE/2 <= k <= MIN_MERGE, such that n/k
+ * is close to, but strictly less than, an exact power of 2.
+ *
+ * For the rationale, see listsort.txt.
+ *
+ * @param n the length of the array to be sorted
+ * @return the length of the minimum run to be merged
+ */
+ private int minRunLength(int n) {
+ assert n >= 0;
+ int r = 0; // Becomes 1 if any 1 bits are shifted off
+ while (n >= MIN_MERGE) {
+ r |= (n & 1);
+ n >>= 1;
+ }
+ return n + r;
+ }
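A worked example (assuming MIN_MERGE is 32, as in the JDK and Android ports this code derives from): the loop keeps the top few bits of n and sets r if any discarded bit was 1.

    static int minRunLength(int n, int minMerge) {
      int r = 0;
      while (n >= minMerge) {
        r |= (n & 1);
        n >>= 1;
      }
      return n + r;
    }
    // minRunLength(1024, 32) == 16  -> exact power of 2 gives MIN_MERGE/2
    // minRunLength(1000, 32) == 32  -> 1000/32 = 31.25, just under a power of 2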
+
+ private class SortState {
+
+ /**
+ * The Buffer being sorted.
+ */
+ private final Buffer a;
+
+ /**
+ * Length of the sort Buffer.
+ */
+ private final int aLength;
+
+ /**
+ * The comparator for this sort.
+ */
+ private final Comparator<? super K> c;
+
+ /**
+ * When we get into galloping mode, we stay there until both runs win less
+ * often than MIN_GALLOP consecutive times.
+ */
+ private static final int MIN_GALLOP = 7;
+
+ /**
+ * This controls when we get *into* galloping mode. It is initialized
+ * to MIN_GALLOP. The mergeLo and mergeHi methods nudge it higher for
+ * random data, and lower for highly structured data.
+ */
+ private int minGallop = MIN_GALLOP;
+
+ /**
+ * Maximum initial size of tmp array, which is used for merging. The array
+ * can grow to accommodate demand.
+ *
+ * Unlike Tim's original C version, we do not allocate this much storage
+ * when sorting smaller arrays. This change was required for performance.
+ */
+ private static final int INITIAL_TMP_STORAGE_LENGTH = 256;
+
+ /**
+ * Temp storage for merges.
+ */
+ private Buffer tmp; // Actual runtime type will be Object[], regardless of T
+
+ /**
+ * Length of the temp storage.
+ */
+ private int tmpLength = 0;
+
+ /**
+ * A stack of pending runs yet to be merged. Run i starts at
+ * address base[i] and extends for len[i] elements. It's always
+ * true (so long as the indices are in bounds) that:
+ *
+ * runBase[i] + runLen[i] == runBase[i + 1]
+ *
+ * so we could cut the storage for this, but it's a minor amount,
+ * and keeping all the info explicit simplifies the code.
+ */
+ private int stackSize = 0; // Number of pending runs on stack
+ private final int[] runBase;
+ private final int[] runLen;
+
+ /**
+ * Creates a SortState instance to maintain the state of an ongoing sort.
+ *
+ * @param a the array to be sorted
+ * @param c the comparator to determine the order of the sort
+ * @param len the length of the range to be sorted
+ */
+ private SortState(Buffer a, Comparator<? super K> c, int len) {
+ this.aLength = len;
+ this.a = a;
+ this.c = c;
+
+ // Allocate temp storage (which may be increased later if necessary)
+ tmpLength = len < 2 * INITIAL_TMP_STORAGE_LENGTH ? len >>> 1 : INITIAL_TMP_STORAGE_LENGTH;
+ tmp = s.allocate(tmpLength);
+
+ /*
+ * Allocate runs-to-be-merged stack (which cannot be expanded). The
+ * stack length requirements are described in listsort.txt. The C
+ * version always uses the same stack length (85), but this was
+ * measured to be too expensive when sorting "mid-sized" arrays (e.g.,
+ * 100 elements) in Java. Therefore, we use smaller (but sufficiently
+ * large) stack lengths for smaller arrays. The "magic numbers" in the
+ * computation below must be changed if MIN_MERGE is decreased. See
+ * the MIN_MERGE declaration above for more information.
+ */
+ int stackLen = (len < 120 ? 5 :
+ len < 1542 ? 10 :
+ len < 119151 ? 19 : 40);
+ runBase = new int[stackLen];
+ runLen = new int[stackLen];
+ }
+
+ /**
+ * Pushes the specified run onto the pending-run stack.
+ *
+ * @param runBase index of the first element in the run
+ * @param runLen the number of elements in the run
+ */
+ private void pushRun(int runBase, int runLen) {
+ this.runBase[stackSize] = runBase;
+ this.runLen[stackSize] = runLen;
+ stackSize++;
+ }
+
+ /**
+ * Examines the stack of runs waiting to be merged and merges adjacent runs
+ * until the stack invariants are reestablished:
+ *
+ * 1. runLen[i - 3] > runLen[i - 2] + runLen[i - 1]
+ * 2. runLen[i - 2] > runLen[i - 1]
+ *
+ * This method is called each time a new run is pushed onto the stack,
+ * so the invariants are guaranteed to hold for i < stackSize upon
+ * entry to the method.
+ */
+ private void mergeCollapse() {
+ while (stackSize > 1) {
+ int n = stackSize - 2;
+ if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) {
+ if (runLen[n - 1] < runLen[n + 1])
+ n--;
+ mergeAt(n);
+ } else if (runLen[n] <= runLen[n + 1]) {
+ mergeAt(n);
+ } else {
+ break; // Invariant is established
+ }
+ }
+ }
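To see the invariants at work, a hypothetical simulation (not part of this patch) over bare run lengths: pushing 10 then 20 merges immediately (10 <= 20), pushing 25 leaves [30, 25], and pushing 8 triggers two merges down to [63]. Because each surviving length must exceed the sum of the two above it, the lengths grow at least as fast as Fibonacci numbers and the stack stays logarithmic in the array size.

    // runLen holds the pending run lengths; returns the new stack size after
    // applying the same collapse rules as mergeCollapse above.
    static int collapseLengths(int[] runLen, int stackSize) {
      while (stackSize > 1) {
        int n = stackSize - 2;
        if (n > 0 && runLen[n - 1] <= runLen[n] + runLen[n + 1]) {
          if (runLen[n - 1] < runLen[n + 1]) n--;
          runLen[n] += runLen[n + 1];                      // merge runs n and n+1
          if (n == stackSize - 3) runLen[n + 1] = runLen[n + 2];
          stackSize--;
        } else if (runLen[n] <= runLen[n + 1]) {
          runLen[n] += runLen[n + 1];
          stackSize--;
        } else {
          break;                                           // invariants hold
        }
      }
      return stackSize;
    }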
+
+ /**
+ * Merges all runs on the stack until only one remains. This method is
+ * called once, to complete the sort.
+ */
+ private void mergeForceCollapse() {
+ while (stackSize > 1) {
+ int n = stackSize - 2;
+ if (n > 0 && runLen[n - 1] < runLen[n + 1])
+ n--;
+ mergeAt(n);
+ }
+ }
+
+ /**
+ * Merges the two runs at stack indices i and i+1. Run i must be
+ * the penultimate or antepenultimate run on the stack. In other words,
+ * i must be equal to stackSize-2 or stackSize-3.
+ *
+ * @param i stack index of the first of the two runs to merge
+ */
+ private void mergeAt(int i) {
+ assert stackSize >= 2;
+ assert i >= 0;
+ assert i == stackSize - 2 || i == stackSize - 3;
+
+ int base1 = runBase[i];
+ int len1 = runLen[i];
+ int base2 = runBase[i + 1];
+ int len2 = runLen[i + 1];
+ assert len1 > 0 && len2 > 0;
+ assert base1 + len1 == base2;
+
+ /*
+ * Record the length of the combined runs; if i is the 3rd-last
+ * run now, also slide over the last run (which isn't involved
+ * in this merge). The current run (i+1) goes away in any case.
+ */
+ runLen[i] = len1 + len2;
+ if (i == stackSize - 3) {
+ runBase[i + 1] = runBase[i + 2];
+ runLen[i + 1] = runLen[i + 2];
+ }
+ stackSize--;
+
+ K key0 = s.newKey();
+
+ /*
+ * Find where the first element of run2 goes in run1. Prior elements
+ * in run1 can be ignored (because they're already in place).
+ */
+ int k = gallopRight(s.getKey(a, base2, key0), a, base1, len1, 0, c);
+ assert k >= 0;
+ base1 += k;
+ len1 -= k;
+ if (len1 == 0)
+ return;
+
+ /*
+ * Find where the last element of run1 goes in run2. Subsequent elements
+ * in run2 can be ignored (because they're already in place).
+ */
+ len2 = gallopLeft(s.getKey(a, base1 + len1 - 1, key0), a, base2, len2, len2 - 1, c);
+ assert len2 >= 0;
+ if (len2 == 0)
+ return;
+
+ // Merge remaining runs, using tmp array with min(len1, len2) elements
+ if (len1 <= len2)
+ mergeLo(base1, len1, base2, len2);
+ else
+ mergeHi(base1, len1, base2, len2);
+ }
+
+ /**
+ * Locates the position at which to insert the specified key into the
+ * specified sorted range; if the range contains an element equal to key,
+ * returns the index of the leftmost equal element.
+ *
+ * @param key the key whose insertion point to search for
+ * @param a the array in which to search
+ * @param base the index of the first element in the range
+ * @param len the length of the range; must be > 0
+ * @param hint the index at which to begin the search, 0 <= hint < n.
+ * The closer hint is to the result, the faster this method will run.
+ * @param c the comparator used to order the range, and to search
+ * @return the int k, 0 <= k <= n such that a[b + k - 1] < key <= a[b + k],
+ * pretending that a[b - 1] is minus infinity and a[b + n] is infinity.
+ * In other words, key belongs at index b + k; or in other words,
+ * the first k elements of a should precede key, and the last n - k
+ * should follow it.
+ */
+ private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator<? super K> c) {
+ assert len > 0 && hint >= 0 && hint < len;
+ int lastOfs = 0;
+ int ofs = 1;
+ K key0 = s.newKey();
+
+ if (c.compare(key, s.getKey(a, base + hint, key0)) > 0) {
+ // Gallop right until a[base+hint+lastOfs] < key <= a[base+hint+ofs]
+ int maxOfs = len - hint;
+ while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key0)) > 0) {
+ lastOfs = ofs;
+ ofs = (ofs << 1) + 1;
+ if (ofs <= 0) // int overflow
+ ofs = maxOfs;
+ }
+ if (ofs > maxOfs)
+ ofs = maxOfs;
+
+ // Make offsets relative to base
+ lastOfs += hint;
+ ofs += hint;
+ } else { // key <= a[base + hint]
+ // Gallop left until a[base+hint-ofs] < key <= a[base+hint-lastOfs]
+ final int maxOfs = hint + 1;
+ while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key0)) <= 0) {
+ lastOfs = ofs;
+ ofs = (ofs << 1) + 1;
+ if (ofs <= 0) // int overflow
+ ofs = maxOfs;
+ }
+ if (ofs > maxOfs)
+ ofs = maxOfs;
+
+ // Make offsets relative to base
+ int tmp = lastOfs;
+ lastOfs = hint - ofs;
+ ofs = hint - tmp;
+ }
+ assert -1 <= lastOfs && lastOfs < ofs && ofs <= len;
+
+ /*
+ * Now a[base+lastOfs] < key <= a[base+ofs], so key belongs somewhere
+ * to the right of lastOfs but no farther right than ofs. Do a binary
+ * search, with invariant a[base + lastOfs - 1] < key <= a[base + ofs].
+ */
+ lastOfs++;
+ while (lastOfs < ofs) {
+ int m = lastOfs + ((ofs - lastOfs) >>> 1);
+
+ if (c.compare(key, s.getKey(a, base + m, key0)) > 0)
+ lastOfs = m + 1; // a[base + m] < key
+ else
+ ofs = m; // key <= a[base + m]
+ }
+ assert lastOfs == ofs; // so a[base + ofs - 1] < key <= a[base + ofs]
+ return ofs;
+ }
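The galloping idea, stripped of the hint parameter, as a hypothetical standalone sketch (not part of this patch): probe at offsets 1, 3, 7, 15, ... until the key is bracketed, then binary-search the final gap, so the cost is logarithmic in the distance to the answer rather than in the run length.

    // Leftmost insertion point of key in the sorted range a[base, base + len).
    static int gallopLeftmost(int key, int[] a, int base, int len) {
      if (len == 0 || key <= a[base]) return 0;
      int lastOfs = 0;
      int ofs = 1;
      // Gallop right until a[base + lastOfs] < key <= a[base + ofs].
      while (ofs < len && key > a[base + ofs]) {
        lastOfs = ofs;
        ofs = (ofs << 1) + 1;
        if (ofs <= 0) ofs = len;   // int overflow
      }
      if (ofs > len) ofs = len;
      // Binary search in (lastOfs, ofs] for the first element >= key.
      lastOfs++;
      while (lastOfs < ofs) {
        int m = lastOfs + ((ofs - lastOfs) >>> 1);
        if (key > a[base + m]) lastOfs = m + 1; else ofs = m;
      }
      return ofs;
    }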
+
+ /**
+ * Like gallopLeft, except that if the range contains an element equal to
+ * key, gallopRight returns the index after the rightmost equal element.
+ *
+ * @param key the key whose insertion point to search for
+ * @param a the array in which to search
+ * @param base the index of the first element in the range
+ * @param len the length of the range; must be > 0
+ * @param hint the index at which to begin the search, 0 <= hint < n.
+ * The closer hint is to the result, the faster this method will run.
+ * @param c the comparator used to order the range, and to search
+ * @return the int k, 0 <= k <= n such that a[b + k - 1] <= key < a[b + k]
+ */
+ private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator<? super K> c) {
+ assert len > 0 && hint >= 0 && hint < len;
+
+ int ofs = 1;
+ int lastOfs = 0;
+ K key1 = s.newKey();
+
+ if (c.compare(key, s.getKey(a, base + hint, key1)) < 0) {
+ // Gallop left until a[b+hint - ofs] <= key < a[b+hint - lastOfs]
+ int maxOfs = hint + 1;
+ while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key1)) < 0) {
+ lastOfs = ofs;
+ ofs = (ofs << 1) + 1;
+ if (ofs <= 0) // int overflow
+ ofs = maxOfs;
+ }
+ if (ofs > maxOfs)
+ ofs = maxOfs;
+
+ // Make offsets relative to b
+ int tmp = lastOfs;
+ lastOfs = hint - ofs;
+ ofs = hint - tmp;
+ } else { // a[b + hint] <= key
+ // Gallop right until a[b+hint + lastOfs] <= key < a[b+hint + ofs]
+ int maxOfs = len - hint;
+ while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key1)) >= 0) {
+ lastOfs = ofs;
+ ofs = (ofs << 1) + 1;
+ if (ofs <= 0) // int overflow
+ ofs = maxOfs;
+ }
+ if (ofs > maxOfs)
+ ofs = maxOfs;
+
+ // Make offsets relative to b
+ lastOfs += hint;
+ ofs += hint;
+ }
+ assert -1 <= lastOfs && lastOfs < ofs && ofs <= len;
+
+ /*
+ * Now a[b + lastOfs] <= key < a[b + ofs], so key belongs somewhere to
+ * the right of lastOfs but no farther right than ofs. Do a binary
+ * search, with invariant a[b + lastOfs - 1] <= key < a[b + ofs].
+ */
+ lastOfs++;
+ while (lastOfs < ofs) {
+ int m = lastOfs + ((ofs - lastOfs) >>> 1);
+
+ if (c.compare(key, s.getKey(a, base + m, key1)) < 0)
+ ofs = m; // key < a[b + m]
+ else
+ lastOfs = m + 1; // a[b + m] <= key
+ }
+ assert lastOfs == ofs; // so a[b + ofs - 1] <= key < a[b + ofs]
+ return ofs;
+ }
+
+ /**
+ * Merges two adjacent runs in place, in a stable fashion. The first
+ * element of the first run must be greater than the first element of the
+ * second run (a[base1] > a[base2]), and the last element of the first run
+ * (a[base1 + len1-1]) must be greater than all elements of the second run.
+ *
+ * For performance, this method should be called only when len1 <= len2;
+ * its twin, mergeHi should be called if len1 >= len2. (Either method
+ * may be called if len1 == len2.)
+ *
+ * @param base1 index of first element in first run to be merged
+ * @param len1 length of first run to be merged (must be > 0)
+ * @param base2 index of first element in second run to be merged
+ * (must be aBase + aLen)
+ * @param len2 length of second run to be merged (must be > 0)
+ */
+ private void mergeLo(int base1, int len1, int base2, int len2) {
+ assert len1 > 0 && len2 > 0 && base1 + len1 == base2;
+
+ // Copy first run into temp array
+ Buffer a = this.a; // For performance
+ Buffer tmp = ensureCapacity(len1);
+ s.copyRange(a, base1, tmp, 0, len1);
+
+ int cursor1 = 0; // Indexes into tmp array
+ int cursor2 = base2; // Indexes into a
+ int dest = base1; // Indexes into a
+
+ // Move first element of second run and deal with degenerate cases
+ s.copyElement(a, cursor2++, a, dest++);
+ if (--len2 == 0) {
+ s.copyRange(tmp, cursor1, a, dest, len1);
+ return;
+ }
+ if (len1 == 1) {
+ s.copyRange(a, cursor2, a, dest, len2);
+ s.copyElement(tmp, cursor1, a, dest + len2); // Last elt of run 1 to end of merge
+ return;
+ }
+
+ K key0 = s.newKey();
+ K key1 = s.newKey();
+
+ Comparator<? super K> c = this.c; // Use local variable for performance
+ int minGallop = this.minGallop; // " " " " "
+ outer:
+ while (true) {
+ int count1 = 0; // Number of times in a row that first run won
+ int count2 = 0; // Number of times in a row that second run won
+
+ /*
+ * Do the straightforward thing until (if ever) one run starts
+ * winning consistently.
+ */
+ do {
+ assert len1 > 1 && len2 > 0;
+ if (c.compare(s.getKey(a, cursor2, key0), s.getKey(tmp, cursor1, key1)) < 0) {
+ s.copyElement(a, cursor2++, a, dest++);
+ count2++;
+ count1 = 0;
+ if (--len2 == 0)
+ break outer;
+ } else {
+ s.copyElement(tmp, cursor1++, a, dest++);
+ count1++;
+ count2 = 0;
+ if (--len1 == 1)
+ break outer;
+ }
+ } while ((count1 | count2) < minGallop);
+
+ /*
+ * One run is winning so consistently that galloping may be a
+ * huge win. So try that, and continue galloping until (if ever)
+ * neither run appears to be winning consistently anymore.
+ */
+ do {
+ assert len1 > 1 && len2 > 0;
+ count1 = gallopRight(s.getKey(a, cursor2, key0), tmp, cursor1, len1, 0, c);
+ if (count1 != 0) {
+ s.copyRange(tmp, cursor1, a, dest, count1);
+ dest += count1;
+ cursor1 += count1;
+ len1 -= count1;
+ if (len1 <= 1) // len1 == 1 || len1 == 0
+ break outer;
+ }
+ s.copyElement(a, cursor2++, a, dest++);
+ if (--len2 == 0)
+ break outer;
+
+ count2 = gallopLeft(s.getKey(tmp, cursor1, key0), a, cursor2, len2, 0, c);
+ if (count2 != 0) {
+ s.copyRange(a, cursor2, a, dest, count2);
+ dest += count2;
+ cursor2 += count2;
+ len2 -= count2;
+ if (len2 == 0)
+ break outer;
+ }
+ s.copyElement(tmp, cursor1++, a, dest++);
+ if (--len1 == 1)
+ break outer;
+ minGallop--;
+ } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP);
+ if (minGallop < 0)
+ minGallop = 0;
+ minGallop += 2; // Penalize for leaving gallop mode
+ } // End of "outer" loop
+ this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field
+
+ if (len1 == 1) {
+ assert len2 > 0;
+ s.copyRange(a, cursor2, a, dest, len2);
+ s.copyElement(tmp, cursor1, a, dest + len2); // Last elt of run 1 to end of merge
+ } else if (len1 == 0) {
+ throw new IllegalArgumentException(
+ "Comparison method violates its general contract!");
+ } else {
+ assert len2 == 0;
+ assert len1 > 1;
+ s.copyRange(tmp, cursor1, a, dest, len1);
+ }
+ }
+
+ /**
+ * Like mergeLo, except that this method should be called only if
+ * len1 >= len2; mergeLo should be called if len1 <= len2. (Either method
+ * may be called if len1 == len2.)
+ *
+ * @param base1 index of first element in first run to be merged
+ * @param len1 length of first run to be merged (must be > 0)
+ * @param base2 index of first element in second run to be merged
+ * (must be aBase + aLen)
+ * @param len2 length of second run to be merged (must be > 0)
+ */
+ private void mergeHi(int base1, int len1, int base2, int len2) {
+ assert len1 > 0 && len2 > 0 && base1 + len1 == base2;
+
+ // Copy second run into temp array
+ Buffer a = this.a; // For performance
+ Buffer tmp = ensureCapacity(len2);
+ s.copyRange(a, base2, tmp, 0, len2);
+
+ int cursor1 = base1 + len1 - 1; // Indexes into a
+ int cursor2 = len2 - 1; // Indexes into tmp array
+ int dest = base2 + len2 - 1; // Indexes into a
+
+ K key0 = s.newKey();
+ K key1 = s.newKey();
+
+ // Move last element of first run and deal with degenerate cases
+ s.copyElement(a, cursor1--, a, dest--);
+ if (--len1 == 0) {
+ s.copyRange(tmp, 0, a, dest - (len2 - 1), len2);
+ return;
+ }
+ if (len2 == 1) {
+ dest -= len1;
+ cursor1 -= len1;
+ s.copyRange(a, cursor1 + 1, a, dest + 1, len1);
+ s.copyElement(tmp, cursor2, a, dest);
+ return;
+ }
+
+ Comparator<? super K> c = this.c; // Use local variable for performance
+ int minGallop = this.minGallop; // " " " " "
+ outer:
+ while (true) {
+ int count1 = 0; // Number of times in a row that first run won
+ int count2 = 0; // Number of times in a row that second run won
+
+ /*
+ * Do the straightforward thing until (if ever) one run
+ * appears to win consistently.
+ */
+ do {
+ assert len1 > 0 && len2 > 1;
+ if (c.compare(s.getKey(tmp, cursor2, key0), s.getKey(a, cursor1, key1)) < 0) {
+ s.copyElement(a, cursor1--, a, dest--);
+ count1++;
+ count2 = 0;
+ if (--len1 == 0)
+ break outer;
+ } else {
+ s.copyElement(tmp, cursor2--, a, dest--);
+ count2++;
+ count1 = 0;
+ if (--len2 == 1)
+ break outer;
+ }
+ } while ((count1 | count2) < minGallop);
+
+ /*
+ * One run is winning so consistently that galloping may be a
+ * huge win. So try that, and continue galloping until (if ever)
+ * neither run appears to be winning consistently anymore.
+ */
+ do {
+ assert len1 > 0 && len2 > 1;
+ count1 = len1 - gallopRight(s.getKey(tmp, cursor2, key0), a, base1, len1, len1 - 1, c);
+ if (count1 != 0) {
+ dest -= count1;
+ cursor1 -= count1;
+ len1 -= count1;
+ s.copyRange(a, cursor1 + 1, a, dest + 1, count1);
+ if (len1 == 0)
+ break outer;
+ }
+ s.copyElement(tmp, cursor2--, a, dest--);
+ if (--len2 == 1)
+ break outer;
+
+ count2 = len2 - gallopLeft(s.getKey(a, cursor1, key0), tmp, 0, len2, len2 - 1, c);
+ if (count2 != 0) {
+ dest -= count2;
+ cursor2 -= count2;
+ len2 -= count2;
+ s.copyRange(tmp, cursor2 + 1, a, dest + 1, count2);
+ if (len2 <= 1) // len2 == 1 || len2 == 0
+ break outer;
+ }
+ s.copyElement(a, cursor1--, a, dest--);
+ if (--len1 == 0)
+ break outer;
+ minGallop--;
+ } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP);
+ if (minGallop < 0)
+ minGallop = 0;
+ minGallop += 2; // Penalize for leaving gallop mode
+ } // End of "outer" loop
+ this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field
+
+ if (len2 == 1) {
+ assert len1 > 0;
+ dest -= len1;
+ cursor1 -= len1;
+ s.copyRange(a, cursor1 + 1, a, dest + 1, len1);
+ s.copyElement(tmp, cursor2, a, dest); // Move first elt of run2 to front of merge
+ } else if (len2 == 0) {
+ throw new IllegalArgumentException(
+ "Comparison method violates its general contract!");
+ } else {
+ assert len1 == 0;
+ assert len2 > 0;
+ s.copyRange(tmp, 0, a, dest - (len2 - 1), len2);
+ }
+ }
+
+ /**
+ * Ensures that the external array tmp has at least the specified
+ * number of elements, increasing its size if necessary. The size
+ * increases exponentially to ensure amortized linear time complexity.
+ *
+ * @param minCapacity the minimum required capacity of the tmp array
+ * @return tmp, whether or not it grew
+ */
+ private Buffer ensureCapacity(int minCapacity) {
+ if (tmpLength < minCapacity) {
+ // Compute smallest power of 2 > minCapacity
+ int newSize = minCapacity;
+ newSize |= newSize >> 1;
+ newSize |= newSize >> 2;
+ newSize |= newSize >> 4;
+ newSize |= newSize >> 8;
+ newSize |= newSize >> 16;
+ newSize++;
+
+ if (newSize < 0) // Not bloody likely!
+ newSize = minCapacity;
+ else
+ newSize = Math.min(newSize, aLength >>> 1);
+
+ tmp = s.allocate(newSize);
+ tmpLength = newSize;
+ }
+ return tmp;
+ }
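The bit-smearing above is the usual next-power-of-two trick; as a standalone illustration (not part of this patch), propagating the highest set bit into every lower position yields all ones below it, and adding one gives the smallest power of two strictly greater than the input.

    static int nextPowerOfTwo(int minCapacity) {
      int n = minCapacity;
      n |= n >> 1;
      n |= n >> 2;
      n |= n >> 4;
      n |= n >> 8;
      n |= n >> 16;
      return n + 1;   // may overflow to a negative value for very large inputs,
                      // which the caller above guards against
    }
    // nextPowerOfTwo(1000) == 1024, nextPowerOfTwo(1024) == 2048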
+ }
+}
diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
new file mode 100644
index 0000000000000..14ba37d7c9bd9
--- /dev/null
+++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Register functions to show/hide columns based on checkboxes. These need
+ * to be registered after the page loads. */
+$(function() {
+ $("span.expand-additional-metrics").click(function(){
+ // Expand the list of additional metrics.
+ var additionalMetricsDiv = $(this).parent().find('.additional-metrics');
+ $(additionalMetricsDiv).toggleClass('collapsed');
+
+ // Switch the class of the arrow from open to closed.
+ $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-open');
+ $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-closed');
+ });
+
+ stripeSummaryTable();
+
+ $("input:checkbox").click(function() {
+ var column = "table ." + $(this).attr("name");
+ $(column).toggle();
+ stripeSummaryTable();
+ });
+
+ $("#select-all-metrics").click(function() {
+ if (this.checked) {
+ // Toggle all un-checked options.
+ $('input:checkbox:not(:checked)').trigger('click');
+ } else {
+ // Toggle all checked options.
+ $('input:checkbox:checked').trigger('click');
+ }
+ });
+
+ // Trigger a click on the checkbox if a user clicks the label next to it.
+ $("span.additional-metric-title").click(function() {
+ $(this).parent().find('input:checkbox').trigger('click');
+ });
+});
diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js
new file mode 100644
index 0000000000000..656147e40d13e
--- /dev/null
+++ b/core/src/main/resources/org/apache/spark/ui/static/table.js
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Adds background colors to stripe table rows in the summary table (on the stage page). This is
+ * necessary (instead of using css or the table striping provided by bootstrap) because the summary
+ * table has hidden rows.
+ *
+ * An ID selector (rather than a class selector) is used to ensure this runs quickly even on pages
+ * with thousands of task rows (ID selectors are much faster than class selectors). */
+function stripeSummaryTable() {
+ $("#task-summary-table").find("tr:not(:hidden)").each(function (index) {
+ if (index % 2 == 1) {
+ $(this).css("background-color", "#f9f9f9");
+ } else {
+ $(this).css("background-color", "#ffffff");
+ }
+ });
+}
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index 152bde5f6994f..f23ba9dba167f 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -19,6 +19,7 @@
height: 50px;
font-size: 15px;
margin-bottom: 15px;
+ min-width: 1200px;
}
.navbar .navbar-inner {
@@ -39,12 +40,12 @@
.navbar .nav > li a {
height: 30px;
- line-height: 30px;
+ line-height: 2;
}
.navbar-text {
height: 50px;
- line-height: 50px;
+ line-height: 3.3;
}
table.sortable thead {
@@ -120,7 +121,76 @@ pre {
border: none;
}
+.description-input {
+ overflow: hidden;
+ text-overflow: ellipsis;
+ width: 100%;
+ white-space: nowrap;
+ display: block;
+}
+
+.stacktrace-details {
+ max-height: 300px;
+ overflow-y: auto;
+ margin: 0;
+ transition: max-height 0.5s ease-out, padding 0.5s ease-out;
+}
+
+.stacktrace-details.collapsed {
+ max-height: 0;
+ padding-top: 0;
+ padding-bottom: 0;
+ border: none;
+}
+
+span.expand-additional-metrics {
+ cursor: pointer;
+}
+
+span.additional-metric-title {
+ cursor: pointer;
+}
+
+.additional-metrics.collapsed {
+ display: none;
+}
+
.tooltip {
font-weight: normal;
}
+.arrow-open {
+ width: 0;
+ height: 0;
+ border-left: 5px solid transparent;
+ border-right: 5px solid transparent;
+ border-top: 5px solid black;
+ float: left;
+ margin-top: 6px;
+}
+
+.arrow-closed {
+ width: 0;
+ height: 0;
+ border-top: 5px solid transparent;
+ border-bottom: 5px solid transparent;
+ border-left: 5px solid black;
+ display: inline-block;
+}
+
+.version {
+ line-height: 2.5;
+ vertical-align: bottom;
+ font-size: 12px;
+ padding: 0;
+ margin: 0;
+ font-weight: bold;
+ color: #777;
+}
+
+/* Hide all additional metrics by default. This is done here rather than using JavaScript to
+ * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */
+.scheduler_delay, .deserialization_time, .fetch_wait_time, .serialization_time,
+.getting_result_time {
+ display: none;
+}
diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index 12f2fe031cb1d..5f31bfba3f8d6 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -18,12 +18,15 @@
package org.apache.spark
import java.io.{ObjectInputStream, Serializable}
+import java.util.concurrent.atomic.AtomicLong
+import java.lang.ThreadLocal
import scala.collection.generic.Growable
import scala.collection.mutable.Map
import scala.reflect.ClassTag
import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.util.Utils
/**
* A data type that can be accumulated, ie has an commutative and associative "add" operation,
@@ -126,7 +129,7 @@ class Accumulable[R, T] (
}
// Called by Java when deserializing an object
- private def readObject(in: ObjectInputStream) {
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
in.defaultReadObject()
value_ = zero
deserialized = true
@@ -227,6 +230,7 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa
*/
class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String])
extends Accumulable[T,T](initialValue, param, name) {
+
def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None)
}
@@ -243,15 +247,47 @@ trait AccumulatorParam[T] extends AccumulableParam[T, T] {
}
}
+object AccumulatorParam {
+
+ // The following implicit objects were in SparkContext before 1.2 and users had to
+ // `import SparkContext._` to enable them. Now we move them here to make the compiler find
+ // them automatically. However, as there is duplicate code in SparkContext for backward
+ // compatibility, please update it accordingly if you modify the following implicit objects.
+
+ implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
+ def addInPlace(t1: Double, t2: Double): Double = t1 + t2
+ def zero(initialValue: Double) = 0.0
+ }
+
+ implicit object IntAccumulatorParam extends AccumulatorParam[Int] {
+ def addInPlace(t1: Int, t2: Int): Int = t1 + t2
+ def zero(initialValue: Int) = 0
+ }
+
+ implicit object LongAccumulatorParam extends AccumulatorParam[Long] {
+ def addInPlace(t1: Long, t2: Long) = t1 + t2
+ def zero(initialValue: Long) = 0L
+ }
+
+ implicit object FloatAccumulatorParam extends AccumulatorParam[Float] {
+ def addInPlace(t1: Float, t2: Float) = t1 + t2
+ def zero(initialValue: Float) = 0f
+ }
+
+ // TODO: Add AccumulatorParams for other types, e.g. lists and strings
+}
+
// TODO: The multi-thread support in accumulators is kind of lame; check
// if there's a more intuitive way of doing it right
-private object Accumulators {
+private[spark] object Accumulators {
// TODO: Use soft references? => need to make readObject work properly then
val originals = Map[Long, Accumulable[_, _]]()
- val localAccums = Map[Thread, Map[Long, Accumulable[_, _]]]()
+ val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() {
+ override protected def initialValue() = Map[Long, Accumulable[_, _]]()
+ }
var lastId: Long = 0
- def newId: Long = synchronized {
+ def newId(): Long = synchronized {
lastId += 1
lastId
}
@@ -260,22 +296,21 @@ private object Accumulators {
if (original) {
originals(a.id) = a
} else {
- val accums = localAccums.getOrElseUpdate(Thread.currentThread, Map())
- accums(a.id) = a
+ localAccums.get()(a.id) = a
}
}
// Clear the local (non-original) accumulators for the current thread
def clear() {
synchronized {
- localAccums.remove(Thread.currentThread)
+ localAccums.get.clear
}
}
// Get the values of the local accumulators for the current thread (by ID)
def values: Map[Long, Any] = synchronized {
val ret = Map[Long, Any]()
- for ((id, accum) <- localAccums.getOrElse(Thread.currentThread, Map())) {
+ for ((id, accum) <- localAccums.get) {
ret(id) = accum.localValue
}
return ret
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index 79c9c451d273d..3b684bbeceaf2 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -34,7 +34,9 @@ case class Aggregator[K, V, C] (
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C) {
- private val externalSorting = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true)
+ // When spilling is enabled sorting will happen externally, but not necessarily with an
+ // ExternalSorter.
+ private val isSpillEnabled = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true)
@deprecated("use combineValuesByKey with TaskContext argument", "0.9.0")
def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]): Iterator[(K, C)] =
@@ -42,7 +44,7 @@ case class Aggregator[K, V, C] (
def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]],
context: TaskContext): Iterator[(K, C)] = {
- if (!externalSorting) {
+ if (!isSpillEnabled) {
val combiners = new AppendOnlyMap[K,C]
var kv: Product2[K, V] = null
val update = (hadValue: Boolean, oldValue: C) => {
@@ -59,8 +61,8 @@ case class Aggregator[K, V, C] (
// Update task metrics if context is not null
// TODO: Make context non optional in a future release
Option(context).foreach { c =>
- c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled
- c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled
+ c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled)
+ c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled)
}
combiners.iterator
}
@@ -71,9 +73,9 @@ case class Aggregator[K, V, C] (
combineCombinersByKey(iter, null)
def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]], context: TaskContext)
- : Iterator[(K, C)] =
+ : Iterator[(K, C)] =
{
- if (!externalSorting) {
+ if (!isSpillEnabled) {
val combiners = new AppendOnlyMap[K,C]
var kc: Product2[K, C] = null
val update = (hadValue: Boolean, oldValue: C) => {
@@ -93,8 +95,8 @@ case class Aggregator[K, V, C] (
// Update task metrics if context is not null
// TODO: Make context non-optional in a future release
Option(context).foreach { c =>
- c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled
- c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled
+ c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled)
+ c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled)
}
combiners.iterator
}
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index 80da62c44edc5..a0c0372b7f0ef 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -44,7 +44,11 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
blockManager.get(key) match {
case Some(blockResult) =>
// Partition is already materialized, so just return its values
- context.taskMetrics.inputMetrics = Some(blockResult.inputMetrics)
+ val inputMetrics = blockResult.inputMetrics
+ val existingMetrics = context.taskMetrics
+ .getInputMetricsForReadMethod(inputMetrics.readMethod)
+ existingMetrics.addBytesRead(inputMetrics.bytesRead)
+
new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]])
case None =>
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index ab2594cfc02eb..9a7cd4523e5ab 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -60,6 +60,9 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
* @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None,
* the default serializer, as specified by `spark.serializer` config option, will
* be used.
+ * @param keyOrdering key ordering for RDD's shuffles
+ * @param aggregator map/reduce-side aggregator for RDD's shuffle
+ * @param mapSideCombine whether to perform partial aggregation (also known as map-side combine)
*/
@DeveloperApi
class ShuffleDependency[K, V, C](
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
new file mode 100644
index 0000000000000..a46a81eabd965
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * A client that communicates with the cluster manager to request or kill executors.
+ */
+private[spark] trait ExecutorAllocationClient {
+
+ /**
+ * Request an additional number of executors from the cluster manager.
+ * Return whether the request is acknowledged by the cluster manager.
+ */
+ def requestExecutors(numAdditionalExecutors: Int): Boolean
+
+ /**
+ * Request that the cluster manager kill the specified executors.
+ * Return whether the request is acknowledged by the cluster manager.
+ */
+ def killExecutors(executorIds: Seq[String]): Boolean
+
+ /**
+ * Request that the cluster manager kill the specified executor.
+ * Return whether the request is acknowledged by the cluster manager.
+ */
+ def killExecutor(executorId: String): Boolean = killExecutors(Seq(executorId))
+}
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
new file mode 100644
index 0000000000000..5d5288bb6e60d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -0,0 +1,557 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.collection.mutable
+
+import org.apache.spark.scheduler._
+
+/**
+ * An agent that dynamically allocates and removes executors based on the workload.
+ *
+ * The add policy depends on whether there are backlogged tasks waiting to be scheduled. If
+ * the scheduler queue is not drained in N seconds, then new executors are added. If the queue
+ * persists for another M seconds, then more executors are added and so on. The number added
+ * in each round increases exponentially from the previous round until an upper bound on the
+ * number of executors has been reached. The upper bound is based both on a configured property
+ * and on the number of tasks pending: the policy will never increase the number of executor
+ * requests past the number needed to handle all pending tasks.
+ *
+ * The rationale for the exponential increase is twofold: (1) Executors should be added slowly
+ * in the beginning in case the number of extra executors needed turns out to be small. Otherwise,
+ * we may add more executors than we need just to remove them later. (2) Executors should be added
+ * quickly over time in case the maximum number of executors is very high. Otherwise, it will take
+ * a long time to ramp up under heavy workloads.
+ *
+ * The remove policy is simpler: If an executor has been idle for K seconds, meaning it has not
+ * been scheduled to run any tasks, then it is removed.
+ *
+ * There is no retry logic in either case because we make the assumption that the cluster manager
+ * will eventually fulfill all requests it receives asynchronously.
+ *
+ * The relevant Spark properties include the following:
+ *
+ * spark.dynamicAllocation.enabled - Whether this feature is enabled
+ * spark.dynamicAllocation.minExecutors - Lower bound on the number of executors
+ * spark.dynamicAllocation.maxExecutors - Upper bound on the number of executors
+ * spark.dynamicAllocation.initialExecutors - Number of executors to start with
+ *
+ * spark.dynamicAllocation.schedulerBacklogTimeout (N) -
+ * If there are backlogged tasks for this duration, add new executors
+ *
+ * spark.dynamicAllocation.sustainedSchedulerBacklogTimeout (M) -
+ * If the backlog is sustained for this duration, add more executors
+ * This is used only after the initial backlog timeout is exceeded
+ *
+ * spark.dynamicAllocation.executorIdleTimeout (K) -
+ * If an executor has been idle for this duration, remove it
+ */
+private[spark] class ExecutorAllocationManager(
+ client: ExecutorAllocationClient,
+ listenerBus: LiveListenerBus,
+ conf: SparkConf)
+ extends Logging {
+
+ allocationManager =>
+
+ import ExecutorAllocationManager._
+
+ // Lower and upper bounds on the number of executors.
+ private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", 0)
+ private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors",
+ Integer.MAX_VALUE)
+
+ // How long there must be backlogged tasks for before an addition is triggered
+ private val schedulerBacklogTimeout = conf.getLong(
+ "spark.dynamicAllocation.schedulerBacklogTimeout", 60)
+
+ // Same as above, but used only after `schedulerBacklogTimeout` is exceeded
+ private val sustainedSchedulerBacklogTimeout = conf.getLong(
+ "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", schedulerBacklogTimeout)
+
+ // How long an executor must be idle for before it is removed
+ private val executorIdleTimeout = conf.getLong(
+ "spark.dynamicAllocation.executorIdleTimeout", 600)
+
+ // During testing, the methods to actually kill and add executors are mocked out
+ private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false)
+
+ // TODO: The default value of 1 for spark.executor.cores works right now because dynamic
+ // allocation is only supported for YARN and the default number of cores per executor in YARN is
+ // 1, but it might need to be obtained differently for different cluster managers
+ private val tasksPerExecutor =
+ conf.getInt("spark.executor.cores", 1) / conf.getInt("spark.task.cpus", 1)
+
+ validateSettings()
+
+ // Number of executors to add in the next round
+ private var numExecutorsToAdd = 1
+
+ // Number of executors that have been requested but have not registered yet
+ private var numExecutorsPending = 0
+
+ // Executors that have been requested to be removed but have not been killed yet
+ private val executorsPendingToRemove = new mutable.HashSet[String]
+
+ // All known executors
+ private val executorIds = new mutable.HashSet[String]
+
+ // A timestamp of when an addition should be triggered, or NOT_SET if it is not set
+ // This is set when pending tasks are added but not scheduled yet
+ private var addTime: Long = NOT_SET
+
+ // A timestamp for each executor of when the executor should be removed, indexed by the ID
+ // This is set when an executor is no longer running a task, or when it first registers
+ private val removeTimes = new mutable.HashMap[String, Long]
+
+ // Polling loop interval (ms)
+ private val intervalMillis: Long = 100
+
+ // Clock used to schedule when executors should be added and removed
+ private var clock: Clock = new RealClock
+
+ // Listener for Spark events that impact the allocation policy
+ private val listener = new ExecutorAllocationListener
+
+ /**
+ * Verify that the settings specified through the config are valid.
+ * If not, throw an appropriate exception.
+ */
+ private def validateSettings(): Unit = {
+ if (minNumExecutors < 0 || maxNumExecutors < 0) {
+ throw new SparkException("spark.dynamicAllocation.{min/max}Executors must be positive!")
+ }
+ if (maxNumExecutors == 0) {
+ throw new SparkException("spark.dynamicAllocation.maxExecutors cannot be 0!")
+ }
+ if (minNumExecutors > maxNumExecutors) {
+ throw new SparkException(s"spark.dynamicAllocation.minExecutors ($minNumExecutors) must " +
+ s"be less than or equal to spark.dynamicAllocation.maxExecutors ($maxNumExecutors)!")
+ }
+ if (schedulerBacklogTimeout <= 0) {
+ throw new SparkException("spark.dynamicAllocation.schedulerBacklogTimeout must be > 0!")
+ }
+ if (sustainedSchedulerBacklogTimeout <= 0) {
+ throw new SparkException(
+ "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout must be > 0!")
+ }
+ if (executorIdleTimeout <= 0) {
+ throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be > 0!")
+ }
+ // Require external shuffle service for dynamic allocation
+ // Otherwise, we may lose shuffle files when killing executors
+ if (!conf.getBoolean("spark.shuffle.service.enabled", false) && !testing) {
+ throw new SparkException("Dynamic allocation of executors requires the external " +
+ "shuffle service. You may enable this through spark.shuffle.service.enabled.")
+ }
+ if (tasksPerExecutor == 0) {
+ throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.")
+ }
+ }
+
+ /**
+ * Use a different clock for this allocation manager. This is mainly used for testing.
+ */
+ def setClock(newClock: Clock): Unit = {
+ clock = newClock
+ }
+
+ /**
+ * Register for scheduler callbacks to decide when to add and remove executors.
+ */
+ def start(): Unit = {
+ listenerBus.addListener(listener)
+ startPolling()
+ }
+
+ /**
+ * Start the main polling thread that keeps track of when to add and remove executors.
+ */
+ private def startPolling(): Unit = {
+ val t = new Thread {
+ override def run(): Unit = {
+ while (true) {
+ try {
+ schedule()
+ } catch {
+ case e: Exception => logError("Exception in dynamic executor allocation thread!", e)
+ }
+ Thread.sleep(intervalMillis)
+ }
+ }
+ }
+ t.setName("spark-dynamic-executor-allocation")
+ t.setDaemon(true)
+ t.start()
+ }
+
+ /**
+ * If the add time has expired, request new executors and refresh the add time.
+ * If the remove time for an existing executor has expired, kill the executor.
+ * This is factored out into its own method for testing.
+ */
+ private def schedule(): Unit = synchronized {
+ val now = clock.getTimeMillis
+ if (addTime != NOT_SET && now >= addTime) {
+ addExecutors()
+ logDebug(s"Starting timer to add more executors (to " +
+ s"expire in $sustainedSchedulerBacklogTimeout seconds)")
+ addTime += sustainedSchedulerBacklogTimeout * 1000
+ }
+
+ removeTimes.retain { case (executorId, expireTime) =>
+ val expired = now >= expireTime
+ if (expired) {
+ removeExecutor(executorId)
+ }
+ !expired
+ }
+ }
+
+ /**
+ * Request a number of executors from the cluster manager.
+ * If the cap on the number of executors is reached, give up and reset the
+ * number of executors to add next round instead of continuing to double it.
+ * Return the number actually requested.
+ */
+ private def addExecutors(): Int = synchronized {
+ // Do not request more executors if we have already reached the upper bound
+ val numExistingExecutors = executorIds.size + numExecutorsPending
+ if (numExistingExecutors >= maxNumExecutors) {
+ logDebug(s"Not adding executors because there are already ${executorIds.size} " +
+ s"registered and $numExecutorsPending pending executor(s) (limit $maxNumExecutors)")
+ numExecutorsToAdd = 1
+ return 0
+ }
+
+ // The number of executors needed to satisfy all pending tasks is the number of tasks pending
+ // divided by the number of tasks each executor can fit, rounded up.
+ val maxNumExecutorsPending =
+ (listener.totalPendingTasks() + tasksPerExecutor - 1) / tasksPerExecutor
+ if (numExecutorsPending >= maxNumExecutorsPending) {
+ logDebug(s"Not adding executors because there are already $numExecutorsPending " +
+ s"pending and pending tasks could only fill $maxNumExecutorsPending")
+ numExecutorsToAdd = 1
+ return 0
+ }
+
+ // It's never useful to request more executors than could satisfy all the pending tasks, so
+ // cap request at that amount.
+ // Also cap request with respect to the configured upper bound.
+ val maxNumExecutorsToAdd = math.min(
+ maxNumExecutorsPending - numExecutorsPending,
+ maxNumExecutors - numExistingExecutors)
+ assert(maxNumExecutorsToAdd > 0)
+
+ val actualNumExecutorsToAdd = math.min(numExecutorsToAdd, maxNumExecutorsToAdd)
+
+ val newTotalExecutors = numExistingExecutors + actualNumExecutorsToAdd
+ val addRequestAcknowledged = testing || client.requestExecutors(actualNumExecutorsToAdd)
+ if (addRequestAcknowledged) {
+ logInfo(s"Requesting $actualNumExecutorsToAdd new executor(s) because " +
+ s"tasks are backlogged (new desired total will be $newTotalExecutors)")
+ numExecutorsToAdd =
+ if (actualNumExecutorsToAdd == numExecutorsToAdd) numExecutorsToAdd * 2 else 1
+ numExecutorsPending += actualNumExecutorsToAdd
+ actualNumExecutorsToAdd
+ } else {
+ logWarning(s"Unable to reach the cluster manager " +
+ s"to request $actualNumExecutorsToAdd executors!")
+ 0
+ }
+ }
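To make the ramp-up policy from the class comment concrete, a hypothetical simulation (not part of this patch): each round doubles the previous request, capped both by the executors still needed for the pending tasks and by the configured maximum. The numbers below (1000 pending tasks, 4 tasks per executor, a cap of 100 executors) are illustrative only.

    public class RampUpDemo {
      public static void main(String[] args) {
        int pendingTasks = 1000, tasksPerExecutor = 4, maxExecutors = 100;
        int neededForPending = (pendingTasks + tasksPerExecutor - 1) / tasksPerExecutor;  // 250
        int target = Math.min(maxExecutors, neededForPending);                            // 100
        int existing = 0, toAdd = 1;
        while (existing < target) {
          int cap = Math.min(neededForPending - existing, maxExecutors - existing);
          int actual = Math.min(toAdd, cap);
          existing += actual;
          // Double only when the full request went through; otherwise reset to 1.
          toAdd = (actual == toAdd) ? toAdd * 2 : 1;
          System.out.println("requested " + actual + ", total now " + existing);
        }
        // Requests of 1, 2, 4, 8, 16, 32 and finally 37 reach the cap of 100.
      }
    }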
+
+ /**
+ * Request the cluster manager to remove the given executor.
+ * Return whether the request is received.
+ */
+ private def removeExecutor(executorId: String): Boolean = synchronized {
+ // Do not kill the executor if we are not aware of it (should never happen)
+ if (!executorIds.contains(executorId)) {
+ logWarning(s"Attempted to remove unknown executor $executorId!")
+ return false
+ }
+
+ // Do not kill the executor again if it is already pending to be killed (should never happen)
+ if (executorsPendingToRemove.contains(executorId)) {
+ logWarning(s"Attempted to remove executor $executorId " +
+ s"when it is already pending to be removed!")
+ return false
+ }
+
+ // Do not kill the executor if we have already reached the lower bound
+ val numExistingExecutors = executorIds.size - executorsPendingToRemove.size
+ if (numExistingExecutors - 1 < minNumExecutors) {
+ logDebug(s"Not removing idle executor $executorId because there are only " +
+ s"$numExistingExecutors executor(s) left (limit $minNumExecutors)")
+ return false
+ }
+
+ // Send a request to the backend to kill this executor
+ val removeRequestAcknowledged = testing || client.killExecutor(executorId)
+ if (removeRequestAcknowledged) {
+ logInfo(s"Removing executor $executorId because it has been idle for " +
+ s"$executorIdleTimeout seconds (new desired total will be ${numExistingExecutors - 1})")
+ executorsPendingToRemove.add(executorId)
+ true
+ } else {
+ logWarning(s"Unable to reach the cluster manager to kill executor $executorId!")
+ false
+ }
+ }
+
+ /**
+ * Callback invoked when the specified executor has been added.
+ */
+ private def onExecutorAdded(executorId: String): Unit = synchronized {
+ if (!executorIds.contains(executorId)) {
+ executorIds.add(executorId)
+ // If an executor (call this executor X) is not removed because the lower bound
+ // has been reached, it will no longer be marked as idle. When new executors join,
+ // however, we are no longer at the lower bound, and so we must mark executor X
+ // as idle again so as not to forget that it is a candidate for removal. (see SPARK-4951)
+ executorIds.filter(listener.isExecutorIdle).foreach(onExecutorIdle)
+ logInfo(s"New executor $executorId has registered (new total is ${executorIds.size})")
+ if (numExecutorsPending > 0) {
+ numExecutorsPending -= 1
+ logDebug(s"Decremented number of pending executors ($numExecutorsPending left)")
+ }
+ } else {
+ logWarning(s"Duplicate executor $executorId has registered")
+ }
+ }
+
+ /**
+ * Callback invoked when the specified executor has been removed.
+ */
+ private def onExecutorRemoved(executorId: String): Unit = synchronized {
+ if (executorIds.contains(executorId)) {
+ executorIds.remove(executorId)
+ removeTimes.remove(executorId)
+ logInfo(s"Existing executor $executorId has been removed (new total is ${executorIds.size})")
+ if (executorsPendingToRemove.contains(executorId)) {
+ executorsPendingToRemove.remove(executorId)
+ logDebug(s"Executor $executorId is no longer pending to " +
+ s"be removed (${executorsPendingToRemove.size} left)")
+ }
+ } else {
+ logWarning(s"Unknown executor $executorId has been removed!")
+ }
+ }
+
+ /**
+ * Callback invoked when the scheduler receives new pending tasks.
+ * This sets a time in the future that decides when executors should be added
+ * if it is not already set.
+ */
+ private def onSchedulerBacklogged(): Unit = synchronized {
+ if (addTime == NOT_SET) {
+ logDebug(s"Starting timer to add executors because pending tasks " +
+ s"are building up (to expire in $schedulerBacklogTimeout seconds)")
+ addTime = clock.getTimeMillis + schedulerBacklogTimeout * 1000
+ }
+ }
+
+ /**
+ * Callback invoked when the scheduler queue is drained.
+ * This resets all variables used for adding executors.
+ */
+ private def onSchedulerQueueEmpty(): Unit = synchronized {
+ logDebug(s"Clearing timer to add executors because there are no more pending tasks")
+ addTime = NOT_SET
+ numExecutorsToAdd = 1
+ }
+
+ /**
+ * Callback invoked when the specified executor is no longer running any tasks.
+ * This sets a time in the future that decides when this executor should be removed if
+ * the executor is not already marked as idle.
+ */
+ private def onExecutorIdle(executorId: String): Unit = synchronized {
+ if (executorIds.contains(executorId)) {
+ if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) {
+ logDebug(s"Starting idle timer for $executorId because there are no more tasks " +
+ s"scheduled to run on the executor (to expire in $executorIdleTimeout seconds)")
+ removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeout * 1000
+ }
+ } else {
+ logWarning(s"Attempted to mark unknown executor $executorId idle")
+ }
+ }
+
+ /**
+ * Callback invoked when the specified executor is now running a task.
+ * This resets all variables used for removing this executor.
+ */
+ private def onExecutorBusy(executorId: String): Unit = synchronized {
+ logDebug(s"Clearing idle timer for $executorId because it is now running a task")
+ removeTimes.remove(executorId)
+ }
+
+ /**
+ * A listener that notifies the given allocation manager of when to add and remove executors.
+ *
+ * This class is intentionally conservative in its assumptions about the relative ordering
+ * and consistency of events returned by the listener. For simplicity, it does not account
+ * for speculated tasks.
+ */
+ private class ExecutorAllocationListener extends SparkListener {
+
+ private val stageIdToNumTasks = new mutable.HashMap[Int, Int]
+ private val stageIdToTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]]
+ private val executorIdToTaskIds = new mutable.HashMap[String, mutable.HashSet[Long]]
+
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
+ val stageId = stageSubmitted.stageInfo.stageId
+ val numTasks = stageSubmitted.stageInfo.numTasks
+ allocationManager.synchronized {
+ stageIdToNumTasks(stageId) = numTasks
+ allocationManager.onSchedulerBacklogged()
+ }
+ }
+
+ override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
+ val stageId = stageCompleted.stageInfo.stageId
+ allocationManager.synchronized {
+ stageIdToNumTasks -= stageId
+ stageIdToTaskIndices -= stageId
+
+ // If this is the last stage with pending tasks, mark the scheduler queue as empty
+ // This is needed in case the stage is aborted for any reason
+ if (stageIdToNumTasks.isEmpty) {
+ allocationManager.onSchedulerQueueEmpty()
+ }
+ }
+ }
+
+ override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+ val stageId = taskStart.stageId
+ val taskId = taskStart.taskInfo.taskId
+ val taskIndex = taskStart.taskInfo.index
+ val executorId = taskStart.taskInfo.executorId
+
+ allocationManager.synchronized {
+ // This guards against the race condition in which the `SparkListenerTaskStart`
+ // event is posted before the `SparkListenerBlockManagerAdded` event, which is
+ // possible because these events are posted in different threads. (see SPARK-4951)
+ if (!allocationManager.executorIds.contains(executorId)) {
+ allocationManager.onExecutorAdded(executorId)
+ }
+
+ // If this is the last pending task, mark the scheduler queue as empty
+ stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex
+ val numTasksScheduled = stageIdToTaskIndices(stageId).size
+ val numTasksTotal = stageIdToNumTasks.getOrElse(stageId, -1)
+ if (numTasksScheduled == numTasksTotal) {
+ // No more pending tasks for this stage
+ stageIdToNumTasks -= stageId
+ if (stageIdToNumTasks.isEmpty) {
+ allocationManager.onSchedulerQueueEmpty()
+ }
+ }
+
+ // Mark the executor on which this task is scheduled as busy
+ executorIdToTaskIds.getOrElseUpdate(executorId, new mutable.HashSet[Long]) += taskId
+ allocationManager.onExecutorBusy(executorId)
+ }
+ }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
+ val executorId = taskEnd.taskInfo.executorId
+ val taskId = taskEnd.taskInfo.taskId
+ allocationManager.synchronized {
+        // If the executor is no longer running any scheduled tasks, mark it as idle
+ if (executorIdToTaskIds.contains(executorId)) {
+ executorIdToTaskIds(executorId) -= taskId
+ if (executorIdToTaskIds(executorId).isEmpty) {
+ executorIdToTaskIds -= executorId
+ allocationManager.onExecutorIdle(executorId)
+ }
+ }
+ }
+ }
+
+ override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = {
+ val executorId = blockManagerAdded.blockManagerId.executorId
+ if (executorId != SparkContext.DRIVER_IDENTIFIER) {
+ // This guards against the race condition in which the `SparkListenerTaskStart`
+ // event is posted before the `SparkListenerBlockManagerAdded` event, which is
+ // possible because these events are posted in different threads. (see SPARK-4951)
+ if (!allocationManager.executorIds.contains(executorId)) {
+ allocationManager.onExecutorAdded(executorId)
+ }
+ }
+ }
+
+ override def onBlockManagerRemoved(
+ blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = {
+ allocationManager.onExecutorRemoved(blockManagerRemoved.blockManagerId.executorId)
+ }
+
+ /**
+ * An estimate of the total number of pending tasks remaining for currently running stages. Does
+ * not account for tasks which may have failed and been resubmitted.
+ *
+ * Note: This is not thread-safe without the caller owning the `allocationManager` lock.
+ */
+ def totalPendingTasks(): Int = {
+ stageIdToNumTasks.map { case (stageId, numTasks) =>
+ numTasks - stageIdToTaskIndices.get(stageId).map(_.size).getOrElse(0)
+ }.sum
+ }
+
+ /**
+ * Return true if an executor is not currently running a task, and false otherwise.
+ *
+ * Note: This is not thread-safe without the caller owning the `allocationManager` lock.
+ */
+ def isExecutorIdle(executorId: String): Boolean = {
+ !executorIdToTaskIds.contains(executorId)
+ }
+ }
+
+}
+
+private object ExecutorAllocationManager {
+ val NOT_SET = Long.MaxValue
+}
+
+/**
+ * An abstract clock for measuring elapsed time.
+ */
+private trait Clock {
+ def getTimeMillis: Long
+}
+
+/**
+ * A clock backed by a monotonically increasing time source.
+ * The time returned by this clock does not correspond to any notion of wall-clock time.
+ */
+private class RealClock extends Clock {
+ override def getTimeMillis: Long = System.nanoTime / (1000 * 1000)
+}
+
+/**
+ * A clock that allows the caller to customize the time.
+ * This is used mainly for testing.
+ */
+private class TestClock(startTimeMillis: Long) extends Clock {
+ private var time: Long = startTimeMillis
+ override def getTimeMillis: Long = time
+ def tick(ms: Long): Unit = { time += ms }
+}
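
The Clock trait and TestClock above exist so that the timer-driven logic (addTime, removeTimes) can be exercised deterministically in tests. As a rough, hypothetical sketch of how the idle-timer bookkeeping interacts with such a clock (an invented IdleTracker, not the manager's actual API):

import scala.collection.mutable

// Hypothetical, simplified bookkeeping in the spirit of onExecutorIdle / onExecutorBusy,
// assuming the Clock and TestClock definitions above are visible.
class IdleTracker(clock: Clock, idleTimeoutSeconds: Long) {
  private val removeTimes = new mutable.HashMap[String, Long]

  // Start an idle timer for the executor unless one is already running.
  def onIdle(executorId: String): Unit =
    removeTimes.getOrElseUpdate(executorId, clock.getTimeMillis + idleTimeoutSeconds * 1000)

  // A newly scheduled task cancels the executor's idle timer.
  def onBusy(executorId: String): Unit = removeTimes.remove(executorId)

  // Executors whose idle timers have expired are candidates for removal.
  def expired(): Seq[String] =
    removeTimes.collect { case (id, t) if t <= clock.getTimeMillis => id }.toSeq
}

// Example with the TestClock defined above:
//   val clock = new TestClock(10000L)
//   val tracker = new IdleTracker(clock, idleTimeoutSeconds = 60)
//   tracker.onIdle("1"); clock.tick(61 * 1000); tracker.expired()   // Seq("1")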
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index d5c8f9d76c476..e97a7375a267b 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -210,7 +210,11 @@ class ComplexFutureAction[T] extends FutureAction[T] {
} catch {
case e: Exception => p.failure(e)
} finally {
- thread = null
+ // This lock guarantees when calling `thread.interrupt()` in `cancel`,
+ // thread won't be set to null.
+ ComplexFutureAction.this.synchronized {
+ thread = null
+ }
}
}
this
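
The comment added above describes a classic check-then-act race: cancel() reads `thread` and calls interrupt() on it while the worker thread clears the field as it finishes. A self-contained sketch of the same locking discipline, using a hypothetical runner class rather than the ComplexFutureAction API:

// Minimal sketch: the same lock protects both the interrupt and the field reset,
// so cancel() can never observe a half-cleared `thread` or interrupt a null reference.
class CancellableRunner {
  @volatile private var cancelled = false
  private var thread: Thread = null   // guarded by `this`

  def run(body: => Unit): Unit = {
    synchronized {
      if (cancelled) throw new InterruptedException("already cancelled")
      thread = Thread.currentThread()
    }
    try {
      body
    } finally {
      synchronized { thread = null }   // mirrors the change above
    }
  }

  def cancel(): Unit = synchronized {
    cancelled = true
    if (thread != null) thread.interrupt()
  }
}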
diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
index edc3889c9ae51..3f33332a81eaf 100644
--- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
@@ -24,6 +24,7 @@ import com.google.common.io.Files
import org.apache.spark.util.Utils
private[spark] class HttpFileServer(
+ conf: SparkConf,
securityManager: SecurityManager,
requestedPort: Int = 0)
extends Logging {
@@ -35,13 +36,13 @@ private[spark] class HttpFileServer(
var serverUri : String = null
def initialize() {
- baseDir = Utils.createTempDir()
+ baseDir = Utils.createTempDir(Utils.getLocalDir(conf), "httpd")
fileDir = new File(baseDir, "files")
jarDir = new File(baseDir, "jars")
fileDir.mkdir()
jarDir.mkdir()
logInfo("HTTP File server directory is " + baseDir)
- httpServer = new HttpServer(baseDir, securityManager, requestedPort, "HTTP file server")
+ httpServer = new HttpServer(conf, baseDir, securityManager, requestedPort, "HTTP file server")
httpServer.start()
serverUri = httpServer.uri
logDebug("HTTP file server started at: " + serverUri)
diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala
index 912558d0cab7d..09a9ccc226721 100644
--- a/core/src/main/scala/org/apache/spark/HttpServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpServer.scala
@@ -19,6 +19,7 @@ package org.apache.spark
import java.io.File
+import org.eclipse.jetty.server.ssl.SslSocketConnector
import org.eclipse.jetty.util.security.{Constraint, Password}
import org.eclipse.jetty.security.authentication.DigestAuthenticator
import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService}
@@ -42,6 +43,7 @@ private[spark] class ServerStateException(message: String) extends Exception(mes
* around a Jetty server.
*/
private[spark] class HttpServer(
+ conf: SparkConf,
resourceBase: File,
securityManager: SecurityManager,
requestedPort: Int = 0,
@@ -57,7 +59,7 @@ private[spark] class HttpServer(
} else {
logInfo("Starting HTTP Server")
val (actualServer, actualPort) =
- Utils.startServiceOnPort[Server](requestedPort, doStart, serverName)
+ Utils.startServiceOnPort[Server](requestedPort, doStart, conf, serverName)
server = actualServer
port = actualPort
}
@@ -71,7 +73,10 @@ private[spark] class HttpServer(
*/
private def doStart(startPort: Int): (Server, Int) = {
val server = new Server()
- val connector = new SocketConnector
+
+ val connector = securityManager.fileServerSSLOptions.createJettySslContextFactory()
+ .map(new SslSocketConnector(_)).getOrElse(new SocketConnector)
+
connector.setMaxIdleTime(60 * 1000)
connector.setSoLingerTime(-1)
connector.setPort(startPort)
@@ -148,13 +153,14 @@ private[spark] class HttpServer(
}
/**
- * Get the URI of this HTTP server (http://host:port)
+ * Get the URI of this HTTP server (http://host:port or https://host:port)
*/
def uri: String = {
if (server == null) {
throw new ServerStateException("Server is not started")
} else {
- "http://" + Utils.localIpAddress + ":" + port
+ val scheme = if (securityManager.fileServerSSLOptions.enabled) "https" else "http"
+ s"$scheme://${Utils.localIpAddress}:$port"
}
}
}
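
The connector selection above is an Option.map(...).getOrElse(...) fallback: if the security manager's file-server SSL options yield a Jetty SslContextFactory, it is wrapped in an SslSocketConnector, otherwise a plain SocketConnector is used. The same pattern in isolation, with the Jetty types replaced by placeholder classes so the sketch stays self-contained:

// Placeholder types standing in for Jetty's SocketConnector / SslSocketConnector.
trait Connector
class PlainConnector extends Connector
class SslConnector(factory: AnyRef) extends Connector

object ConnectorChoice {
  // None => plain HTTP connector; Some(factory) => SSL connector wrapping the factory.
  def choose(sslContextFactory: Option[AnyRef]): Connector =
    sslContextFactory.map(new SslConnector(_)).getOrElse(new PlainConnector)
}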
diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala
index d4f2624061e35..419d093d55643 100644
--- a/core/src/main/scala/org/apache/spark/Logging.scala
+++ b/core/src/main/scala/org/apache/spark/Logging.scala
@@ -118,15 +118,17 @@ trait Logging {
// org.slf4j.impl.Log4jLoggerFactory, from the log4j 2.0 binding, currently
// org.apache.logging.slf4j.Log4jLoggerFactory
val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass)
- val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
- if (!log4j12Initialized && usingLog4j12) {
- val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
- Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match {
- case Some(url) =>
- PropertyConfigurator.configure(url)
- System.err.println(s"Using Spark's default log4j profile: $defaultLogProps")
- case None =>
- System.err.println(s"Spark was unable to load $defaultLogProps")
+ if (usingLog4j12) {
+ val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
+ if (!log4j12Initialized) {
+ val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
+ Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match {
+ case Some(url) =>
+ PropertyConfigurator.configure(url)
+ System.err.println(s"Using Spark's default log4j profile: $defaultLogProps")
+ case None =>
+ System.err.println(s"Spark was unable to load $defaultLogProps")
+ }
}
}
Logging.initialized = true
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 4cb0bd4142435..6e4edc7c80d7a 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -72,20 +72,22 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster
/**
* Class that keeps track of the location of the map output of
* a stage. This is abstract because different versions of MapOutputTracker
- * (driver and worker) use different HashMap to store its metadata.
+ * (driver and executor) use a different HashMap to store their metadata.
*/
private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging {
private val timeout = AkkaUtils.askTimeout(conf)
+ private val retryAttempts = AkkaUtils.numRetries(conf)
+ private val retryIntervalMs = AkkaUtils.retryWaitMs(conf)
/** Set to the MapOutputTrackerActor living on the driver. */
var trackerActor: ActorRef = _
/**
- * This HashMap has different behavior for the master and the workers.
+ * This HashMap has different behavior for the driver and the executors.
*
- * On the master, it serves as the source of map outputs recorded from ShuffleMapTasks.
- * On the workers, it simply serves as a cache, in which a miss triggers a fetch from the
- * master's corresponding HashMap.
+ * On the driver, it serves as the source of map outputs recorded from ShuffleMapTasks.
+ * On the executors, it simply serves as a cache, in which a miss triggers a fetch from the
+ * driver's corresponding HashMap.
*
* Note: because mapStatuses is accessed concurrently, subclasses should make sure it's a
* thread-safe map.
@@ -99,7 +101,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
protected var epoch: Long = 0
protected val epochLock = new AnyRef
- /** Remembers which map output locations are currently being fetched on a worker. */
+ /** Remembers which map output locations are currently being fetched on an executor. */
private val fetching = new HashSet[Int]
/**
@@ -108,8 +110,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
*/
protected def askTracker(message: Any): Any = {
try {
- val future = trackerActor.ask(message)(timeout)
- Await.result(future, timeout)
+ AkkaUtils.askWithReply(message, trackerActor, retryAttempts, retryIntervalMs, timeout)
} catch {
case e: Exception =>
logError("Error communicating with MapOutputTracker", e)
@@ -136,14 +137,12 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
var fetchedStatuses: Array[MapStatus] = null
fetching.synchronized {
- if (fetching.contains(shuffleId)) {
- // Someone else is fetching it; wait for them to be done
- while (fetching.contains(shuffleId)) {
- try {
- fetching.wait()
- } catch {
- case e: InterruptedException =>
- }
+ // Someone else is fetching it; wait for them to be done
+ while (fetching.contains(shuffleId)) {
+ try {
+ fetching.wait()
+ } catch {
+ case e: InterruptedException =>
}
}
@@ -178,6 +177,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
}
} else {
+ logError("Missing all output locations for shuffle " + shuffleId)
throw new MetadataFetchFailedException(
shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId)
}
@@ -197,8 +197,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
/**
* Called from executors to update the epoch number, potentially clearing old outputs
- * because of a fetch failure. Each worker task calls this with the latest epoch
- * number on the master at the time it was created.
+ * because of a fetch failure. Each executor task calls this with the latest epoch
+ * number on the driver at the time it was created.
*/
def updateEpoch(newEpoch: Long) {
epochLock.synchronized {
@@ -230,7 +230,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
private var cacheEpoch = epoch
/**
- * Timestamp based HashMap for storing mapStatuses and cached serialized statuses in the master,
+ * Timestamp based HashMap for storing mapStatuses and cached serialized statuses in the driver,
* so that statuses are dropped only by explicit de-registering or by TTL-based cleaning (if set).
* Other than these two scenarios, nothing should be dropped from this HashMap.
*/
@@ -340,7 +340,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
}
/**
- * MapOutputTracker for the workers, which fetches map output information from the driver's
+ * MapOutputTracker for the executors, which fetches map output information from the driver's
* MapOutputTrackerMaster.
*/
private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
@@ -348,7 +348,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
new ConcurrentHashMap[Int, Array[MapStatus]]
}
-private[spark] object MapOutputTracker {
+private[spark] object MapOutputTracker extends Logging {
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
@@ -381,6 +381,7 @@ private[spark] object MapOutputTracker {
statuses.map {
status =>
if (status == null) {
+ logError("Missing an output location for shuffle " + shuffleId)
throw new MetadataFetchFailedException(
shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId)
} else {
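
The reworked block above now always waits while another thread is fetching the same shuffle's statuses, instead of only when the initial check happens to succeed. A stand-alone sketch of that wait/notify protocol; fetchFromDriver is a hypothetical stand-in for the actual request to the driver:

import scala.collection.mutable.{HashMap, HashSet}

object FetchCoordinator {
  private val fetching = new HashSet[Int]
  private val cache = new HashMap[Int, Array[String]]

  // Hypothetical: in the real tracker this would ask the driver for map statuses.
  private def fetchFromDriver(shuffleId: Int): Array[String] = Array(s"status-$shuffleId")

  def get(shuffleId: Int): Array[String] = {
    val cached = fetching.synchronized {
      // Someone else may already be fetching this shuffle; wait for them to finish.
      while (fetching.contains(shuffleId)) {
        try fetching.wait() catch { case _: InterruptedException => }
      }
      val hit = cache.get(shuffleId)
      if (hit.isEmpty) fetching += shuffleId   // this thread will do the fetch
      hit
    }
    cached.getOrElse {
      try {
        val statuses = fetchFromDriver(shuffleId)
        fetching.synchronized { cache(shuffleId) = statuses }
        statuses
      } finally {
        fetching.synchronized {
          fetching -= shuffleId
          fetching.notifyAll()   // wake any threads waiting on this shuffle
        }
      }
    }
  }
}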
diff --git a/core/src/main/scala/org/apache/spark/Partition.scala b/core/src/main/scala/org/apache/spark/Partition.scala
index 27892dbd2a0bc..dd3f28e4197e3 100644
--- a/core/src/main/scala/org/apache/spark/Partition.scala
+++ b/core/src/main/scala/org/apache/spark/Partition.scala
@@ -18,11 +18,11 @@
package org.apache.spark
/**
- * A partition of an RDD.
+ * An identifier for a partition in an RDD.
*/
trait Partition extends Serializable {
/**
- * Get the split's index within its parent RDD
+ * Get the partition's index within its parent RDD
*/
def index: Int
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index 37053bb6f37ad..e53a78ead2c0e 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -204,7 +204,7 @@ class RangePartitioner[K : Ordering : ClassTag, V](
}
@throws(classOf[IOException])
- private def writeObject(out: ObjectOutputStream) {
+ private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
val sfactory = SparkEnv.get.serializer
sfactory match {
case js: JavaSerializer => out.defaultWriteObject()
@@ -222,7 +222,7 @@ class RangePartitioner[K : Ordering : ClassTag, V](
}
@throws(classOf[IOException])
- private def readObject(in: ObjectInputStream) {
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
val sfactory = SparkEnv.get.serializer
sfactory match {
case js: JavaSerializer => in.defaultReadObject()
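
writeObject and readObject above are now routed through Utils.tryOrIOException, so that a failure during custom serialization surfaces as an IOException (which the Java serialization machinery handles) rather than escaping as an arbitrary exception. A hedged sketch of what such a wrapper can look like; the real Utils.tryOrIOException may differ in detail:

import java.io.IOException

object SerializationSketch {
  // Run the block; let IOExceptions through unchanged and wrap anything else in an
  // IOException so ObjectOutputStream/ObjectInputStream report the failure cleanly.
  def tryOrIOException(block: => Unit): Unit = {
    try block catch {
      case e: IOException => throw e
      case e: Exception => throw new IOException(e)
    }
  }
}

// Usage inside a Serializable class (hypothetical example):
//   private def writeObject(out: java.io.ObjectOutputStream): Unit =
//     SerializationSketch.tryOrIOException { out.defaultWriteObject() }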
diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala
new file mode 100644
index 0000000000000..2cdc167f85af0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io.File
+
+import com.typesafe.config.{Config, ConfigFactory, ConfigValueFactory}
+import org.eclipse.jetty.util.ssl.SslContextFactory
+
+/**
+ * SSLOptions class is a common container for SSL configuration options. It offers methods to
+ * generate specific objects to configure SSL for different communication protocols.
+ *
+ * SSLOptions is intended to provide the maximum common set of SSL settings that are supported
+ * by the protocols for which it can generate a configuration. Since Akka doesn't support client
+ * authentication with SSL, SSLOptions cannot support it either.
+ *
+ * @param enabled enables or disables SSL; if it is set to false, the rest of the
+ * settings are disregarded
+ * @param keyStore a path to the key-store file
+ * @param keyStorePassword a password to access the key-store file
+ * @param keyPassword a password to access the private key in the key-store
+ * @param trustStore a path to the trust-store file
+ * @param trustStorePassword a password to access the trust-store file
+ * @param protocol SSL protocol supported by Java (remember that SSLv3 was compromised)
+ * @param enabledAlgorithms a set of encryption algorithms to use
+ */
+private[spark] case class SSLOptions(
+ enabled: Boolean = false,
+ keyStore: Option[File] = None,
+ keyStorePassword: Option[String] = None,
+ keyPassword: Option[String] = None,
+ trustStore: Option[File] = None,
+ trustStorePassword: Option[String] = None,
+ protocol: Option[String] = None,
+ enabledAlgorithms: Set[String] = Set.empty) {
+
+ /**
+ * Creates a Jetty SSL context factory according to the SSL settings represented by this object.
+ */
+ def createJettySslContextFactory(): Option[SslContextFactory] = {
+ if (enabled) {
+ val sslContextFactory = new SslContextFactory()
+
+ keyStore.foreach(file => sslContextFactory.setKeyStorePath(file.getAbsolutePath))
+ trustStore.foreach(file => sslContextFactory.setTrustStore(file.getAbsolutePath))
+ keyStorePassword.foreach(sslContextFactory.setKeyStorePassword)
+ trustStorePassword.foreach(sslContextFactory.setTrustStorePassword)
+ keyPassword.foreach(sslContextFactory.setKeyManagerPassword)
+ protocol.foreach(sslContextFactory.setProtocol)
+ sslContextFactory.setIncludeCipherSuites(enabledAlgorithms.toSeq: _*)
+
+ Some(sslContextFactory)
+ } else {
+ None
+ }
+ }
+
+ /**
+ * Creates an Akka configuration object which contains all the SSL settings represented by this
+ * object. It can be used then to compose the ultimate Akka configuration.
+ */
+ def createAkkaConfig: Option[Config] = {
+ import scala.collection.JavaConversions._
+ if (enabled) {
+ Some(ConfigFactory.empty()
+ .withValue("akka.remote.netty.tcp.security.key-store",
+ ConfigValueFactory.fromAnyRef(keyStore.map(_.getAbsolutePath).getOrElse("")))
+ .withValue("akka.remote.netty.tcp.security.key-store-password",
+ ConfigValueFactory.fromAnyRef(keyStorePassword.getOrElse("")))
+ .withValue("akka.remote.netty.tcp.security.trust-store",
+ ConfigValueFactory.fromAnyRef(trustStore.map(_.getAbsolutePath).getOrElse("")))
+ .withValue("akka.remote.netty.tcp.security.trust-store-password",
+ ConfigValueFactory.fromAnyRef(trustStorePassword.getOrElse("")))
+ .withValue("akka.remote.netty.tcp.security.key-password",
+ ConfigValueFactory.fromAnyRef(keyPassword.getOrElse("")))
+ .withValue("akka.remote.netty.tcp.security.random-number-generator",
+ ConfigValueFactory.fromAnyRef(""))
+ .withValue("akka.remote.netty.tcp.security.protocol",
+ ConfigValueFactory.fromAnyRef(protocol.getOrElse("")))
+ .withValue("akka.remote.netty.tcp.security.enabled-algorithms",
+ ConfigValueFactory.fromIterable(enabledAlgorithms.toSeq))
+ .withValue("akka.remote.netty.tcp.enable-ssl",
+ ConfigValueFactory.fromAnyRef(true)))
+ } else {
+ None
+ }
+ }
+
+ /** Returns a string representation of this SSLOptions with all the passwords masked. */
+ override def toString: String = s"SSLOptions{enabled=$enabled, " +
+ s"keyStore=$keyStore, keyStorePassword=${keyStorePassword.map(_ => "xxx")}, " +
+ s"trustStore=$trustStore, trustStorePassword=${trustStorePassword.map(_ => "xxx")}, " +
+ s"protocol=$protocol, enabledAlgorithms=$enabledAlgorithms}"
+
+}
+
+private[spark] object SSLOptions extends Logging {
+
+ /** Resolves SSLOptions settings from a given Spark configuration object at a given namespace.
+ *
+ * The following settings are allowed:
+ * $ - `[ns].enabled` - `true` or `false`, to enable or disable SSL respectively
+ * $ - `[ns].keyStore` - a path to the key-store file; can be relative to the current directory
+ * $ - `[ns].keyStorePassword` - a password to the key-store file
+ * $ - `[ns].keyPassword` - a password to the private key
+ * $ - `[ns].trustStore` - a path to the trust-store file; can be relative to the current
+ * directory
+ * $ - `[ns].trustStorePassword` - a password to the trust-store file
+ * $ - `[ns].protocol` - a protocol name supported by a particular Java version
+ * $ - `[ns].enabledAlgorithms` - a comma separated list of ciphers
+ *
+ * For a list of protocols and ciphers supported by particular Java versions, you may go to
+ * [[https://blogs.oracle.com/java-platform-group/entry/diagnosing_tls_ssl_and_https Oracle
+ * blog page]].
+ *
+ * You can optionally specify the default configuration. If you do, for each setting which is
+ * missing in SparkConf, the corresponding setting is used from the default configuration.
+ *
+ * @param conf Spark configuration object where the settings are collected from
+ * @param ns the namespace name
+ * @param defaults the default configuration
+ * @return [[org.apache.spark.SSLOptions]] object
+ */
+ def parse(conf: SparkConf, ns: String, defaults: Option[SSLOptions] = None): SSLOptions = {
+ val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = defaults.exists(_.enabled))
+
+ val keyStore = conf.getOption(s"$ns.keyStore").map(new File(_))
+ .orElse(defaults.flatMap(_.keyStore))
+
+ val keyStorePassword = conf.getOption(s"$ns.keyStorePassword")
+ .orElse(defaults.flatMap(_.keyStorePassword))
+
+ val keyPassword = conf.getOption(s"$ns.keyPassword")
+ .orElse(defaults.flatMap(_.keyPassword))
+
+ val trustStore = conf.getOption(s"$ns.trustStore").map(new File(_))
+ .orElse(defaults.flatMap(_.trustStore))
+
+ val trustStorePassword = conf.getOption(s"$ns.trustStorePassword")
+ .orElse(defaults.flatMap(_.trustStorePassword))
+
+ val protocol = conf.getOption(s"$ns.protocol")
+ .orElse(defaults.flatMap(_.protocol))
+
+ val enabledAlgorithms = conf.getOption(s"$ns.enabledAlgorithms")
+ .map(_.split(",").map(_.trim).filter(_.nonEmpty).toSet)
+ .orElse(defaults.map(_.enabledAlgorithms))
+ .getOrElse(Set.empty)
+
+ new SSLOptions(
+ enabled,
+ keyStore,
+ keyStorePassword,
+ keyPassword,
+ trustStore,
+ trustStorePassword,
+ protocol,
+ enabledAlgorithms)
+ }
+
+}
+
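
A quick illustration of the `[ns].xxx` keys documented in SSLOptions.parse above. SSLOptions is private[spark], so a snippet like this would have to live in the org.apache.spark package; the path, password and cipher are purely illustrative values:

package org.apache.spark

object SslOptionsParseSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(loadDefaults = false)
      .set("spark.ssl.enabled", "true")
      .set("spark.ssl.keyStore", "/path/to/keystore")                 // illustrative path
      .set("spark.ssl.keyStorePassword", "changeit")                  // illustrative password
      .set("spark.ssl.protocol", "TLSv1.2")
      .set("spark.ssl.enabledAlgorithms", "TLS_RSA_WITH_AES_128_CBC_SHA")

    val options = SSLOptions.parse(conf, "spark.ssl", defaults = None)
    println(options)   // toString masks the passwords
  }
}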
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 0e0f1a7b2377e..88d35a4bacc6e 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -18,10 +18,15 @@
package org.apache.spark
import java.net.{Authenticator, PasswordAuthentication}
+import java.security.KeyStore
+import java.security.cert.X509Certificate
+import javax.net.ssl._
+import com.google.common.io.Files
import org.apache.hadoop.io.Text
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.network.sasl.SecretKeyHolder
/**
* Spark class responsible for security.
@@ -54,7 +59,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
* Spark also has a set of admin acls (`spark.admin.acls`) which is a set of users/administrators
* who always have permission to view or modify the Spark application.
*
- * Spark does not currently support encryption after authentication.
+ * Starting from version 1.3, Spark has partial support for encrypted connections with SSL.
*
* At this point spark has multiple communication protocols that need to be secured and
* different underlying mechanisms are used depending on the protocol:
@@ -66,8 +71,9 @@ import org.apache.spark.deploy.SparkHadoopUtil
* to connect to the server. There is no control of the underlying
* authentication mechanism so its not clear if the password is passed in
* plaintext or uses DIGEST-MD5 or some other mechanism.
- * Akka also has an option to turn on SSL, this option is not currently supported
- * but we could add a configuration option in the future.
+ *
+ * Akka also has an option to turn on SSL; this option is currently supported (see
+ * the details below).
*
* - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty
* for the HttpServer. Jetty supports multiple authentication mechanisms -
@@ -76,15 +82,16 @@ import org.apache.spark.deploy.SparkHadoopUtil
* to authenticate using DIGEST-MD5 via a single user and the shared secret.
* Since we are using DIGEST-MD5, the shared secret is not passed on the wire
* in plaintext.
- * We currently do not support SSL (https), but Jetty can be configured to use it
- * so we could add a configuration option for this in the future.
+ *
+ * We currently support SSL (https) for this communication protocol (see the details
+ * below).
*
* The Spark HttpServer installs the HashLoginServer and configures it to DIGEST-MD5.
* Any clients must specify the user and password. There is a default
* Authenticator installed in the SecurityManager to how it does the authentication
* and in this case gets the user name and password from the request.
*
- * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously
+ * - BlockTransferService -> The Spark BlockTransferService uses java nio to asynchronously
* exchange messages. For this we use the Java SASL
* (Simple Authentication and Security Layer) API and again use DIGEST-MD5
* as the authentication mechanism. This means the shared secret is not passed
@@ -92,31 +99,35 @@ import org.apache.spark.deploy.SparkHadoopUtil
* Note that SASL is pluggable as to what mechanism it uses. We currently use
* DIGEST-MD5 but this could be changed to use Kerberos or other in the future.
* Spark currently supports "auth" for the quality of protection, which means
- * the connection is not supporting integrity or privacy protection (encryption)
+ * the connection does not support integrity or privacy protection (encryption)
* after authentication. SASL also supports "auth-int" and "auth-conf" which
- * SPARK could be support in the future to allow the user to specify the quality
+ * SPARK could support in the future to allow the user to specify the quality
* of protection they want. If we support those, the messages will also have to
* be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's.
*
- * Since the connectionManager does asynchronous messages passing, the SASL
+ * Since the NioBlockTransferService does asynchronous messages passing, the SASL
* authentication is a bit more complex. A ConnectionManager can be both a client
- * and a Server, so for a particular connection is has to determine what to do.
+ * and a Server, so for a particular connection it has to determine what to do.
* A ConnectionId was added to be able to track connections and is used to
* match up incoming messages with connections waiting for authentication.
- * The ConnectionManager tracks all the sendingConnections using the ConnectionId
- * and waits for the response from the server and does the handshake before sending
+ * The ConnectionManager tracks all the sendingConnections using the ConnectionId,
+ * waits for the response from the server, and does the handshake before sending
* the real message.
*
+ * The NettyBlockTransferService ensures that SASL authentication is performed
+ * synchronously prior to any other communication on a connection. This is done in
+ * SaslClientBootstrap on the client side and SaslRpcHandler on the server side.
+ *
* - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters
* can be used. Yarn requires a specific AmIpFilter be installed for security to work
- * properly. For non-Yarn deployments, users can write a filter to go through a
- * companies normal login service. If an authentication filter is in place then the
+ * properly. For non-Yarn deployments, users can write a filter to go through their
+ * organization's normal login service. If an authentication filter is in place then the
* SparkUI can be configured to check the logged in user against the list of users who
* have view acls to see if that user is authorized.
* The filters can also be used for many different purposes. For instance filters
* could be used for logging, encryption, or compression.
*
- * The exact mechanisms used to generate/distributed the shared secret is deployment specific.
+ * The exact mechanisms used to generate/distribute the shared secret are deployment-specific.
*
* For Yarn deployments, the secret is automatically generated using the Akka remote
* Crypt.generateSecureCookie() API. The secret is placed in the Hadoop UGI which gets passed
@@ -133,21 +144,52 @@ import org.apache.spark.deploy.SparkHadoopUtil
* All the nodes (Master and Workers) and the applications need to have the same shared secret.
* This again is not ideal as one user could potentially affect another users application.
* This should be enhanced in the future to provide better protection.
- * If the UI needs to be secured the user needs to install a javax servlet filter to do the
+ * If the UI needs to be secure, the user needs to install a javax servlet filter to do the
* authentication. Spark will then use that user to compare against the view acls to do
* authorization. If not filter is in place the user is generally null and no authorization
* can take place.
+ *
+ * Connection encryption (SSL) configuration is organized hierarchically. The user can configure
+ * the default SSL settings which will be used for all the supported communication protocols unless
+ * they are overwritten by protocol specific settings. This way the user can easily provide the
+ * common settings for all the protocols without disabling the ability to configure each one
+ * individually.
+ *
+ * All the SSL settings like `spark.ssl.xxx`, where `xxx` is a particular configuration property,
+ * denote the global configuration for all the supported protocols. In order to override the
+ * global configuration for a particular protocol, the properties must be overwritten in the
+ * protocol-specific namespace. Use `spark.ssl.yyy.xxx` settings to overwrite the global
+ * configuration for the particular protocol denoted by `yyy`. Currently `yyy` can be either
+ * `akka` for Akka-based connections or `fs` for broadcast and the file server.
+ *
+ * Refer to [[org.apache.spark.SSLOptions]] documentation for the list of
+ * options that can be specified.
+ *
+ * SecurityManager initializes SSLOptions objects for different protocols separately. SSLOptions
+ * object parses Spark configuration at a given namespace and builds the common representation
+ * of SSL settings. SSLOptions is then used to provide protocol-specific configuration like
+ * TypeSafe configuration for Akka or SSLContextFactory for Jetty.
+ *
+ * SSL must be configured on each node, for each component involved in communication using a
+ * particular protocol. In YARN clusters, the key-store can be prepared on the client side and
+ * then distributed and used by the executors as part of the application (YARN allows the user
+ * to deploy files before the application is started).
+ * In standalone deployment, the user needs to provide key-stores and configuration
+ * options for the master and workers. In this mode, the user may allow the executors to use the
+ * SSL settings inherited from the worker that spawned that executor. This can be accomplished by
+ * setting `spark.ssl.useNodeLocalConf` to `true`.
*/
-private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
+private[spark] class SecurityManager(sparkConf: SparkConf)
+ extends Logging with SecretKeyHolder {
// key used to store the spark secret in the Hadoop UGI
private val sparkSecretLookupKey = "sparkCookie"
private val authOn = sparkConf.getBoolean("spark.authenticate", false)
// keep spark.ui.acls.enable for backwards compatibility with 1.0
- private var aclsOn = sparkConf.getOption("spark.acls.enable").getOrElse(
- sparkConf.get("spark.ui.acls.enable", "false")).toBoolean
+ private var aclsOn =
+ sparkConf.getBoolean("spark.acls.enable", sparkConf.getBoolean("spark.ui.acls.enable", false))
// admin acls should be set before view or modify acls
private var adminAcls: Set[String] =
@@ -191,6 +233,57 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
)
}
+ // the default SSL configuration - it will be used by all communication layers unless overwritten
+ private val defaultSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl", defaults = None)
+
+ // SSL configuration for different communication layers - they can override the default
+ // configuration at a specified namespace. The namespace *must* start with spark.ssl.
+ val fileServerSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl.fs", Some(defaultSSLOptions))
+ val akkaSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl.akka", Some(defaultSSLOptions))
+
+ logDebug(s"SSLConfiguration for file server: $fileServerSSLOptions")
+ logDebug(s"SSLConfiguration for Akka: $akkaSSLOptions")
+
+ val (sslSocketFactory, hostnameVerifier) = if (fileServerSSLOptions.enabled) {
+ val trustStoreManagers =
+ for (trustStore <- fileServerSSLOptions.trustStore) yield {
+ val input = Files.asByteSource(fileServerSSLOptions.trustStore.get).openStream()
+
+ try {
+ val ks = KeyStore.getInstance(KeyStore.getDefaultType)
+ ks.load(input, fileServerSSLOptions.trustStorePassword.get.toCharArray)
+
+ val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
+ tmf.init(ks)
+ tmf.getTrustManagers
+ } finally {
+ input.close()
+ }
+ }
+
+ lazy val credulousTrustStoreManagers = Array({
+ logWarning("Using 'accept-all' trust manager for SSL connections.")
+ new X509TrustManager {
+ override def getAcceptedIssuers: Array[X509Certificate] = null
+
+ override def checkClientTrusted(x509Certificates: Array[X509Certificate], s: String) {}
+
+ override def checkServerTrusted(x509Certificates: Array[X509Certificate], s: String) {}
+ }: TrustManager
+ })
+
+ val sslContext = SSLContext.getInstance(fileServerSSLOptions.protocol.getOrElse("Default"))
+ sslContext.init(null, trustStoreManagers.getOrElse(credulousTrustStoreManagers), null)
+
+ val hostVerifier = new HostnameVerifier {
+ override def verify(s: String, sslSession: SSLSession): Boolean = true
+ }
+
+ (Some(sslContext.getSocketFactory), Some(hostVerifier))
+ } else {
+ (None, None)
+ }
+
/**
* Split a comma separated String, filter out any empty items, and return a Set of strings
*/
@@ -337,4 +430,8 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
* @return the secret key as a String if authentication is enabled, otherwise returns null
*/
def getSecretKey(): String = secretKey
+
+ // Default SecurityManager only has a single secret key, so ignore appId.
+ override def getSaslUser(appId: String): String = getSaslUser()
+ override def getSecretKey(appId: String): String = getSecretKey()
}
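
The defaultSSLOptions / fileServerSSLOptions / akkaSSLOptions wiring above implements the hierarchy described in the class comment: `spark.ssl.*` supplies the global values, and `spark.ssl.fs.*` or `spark.ssl.akka.*` overrides them per protocol. A small sketch of that effect, again assuming org.apache.spark package access and illustrative values:

package org.apache.spark

object SslNamespaceOverrideSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(loadDefaults = false)
      .set("spark.ssl.enabled", "true")
      .set("spark.ssl.protocol", "TLSv1.2")       // global setting
      .set("spark.ssl.fs.protocol", "TLSv1.1")    // file-server-specific override

    val global = SSLOptions.parse(conf, "spark.ssl", defaults = None)
    val fs = SSLOptions.parse(conf, "spark.ssl.fs", defaults = Some(global))

    // The file server inherits `enabled` from the global options but uses its own protocol.
    assert(fs.enabled)
    assert(fs.protocol == Some("TLSv1.1"))
  }
}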
diff --git a/core/src/main/scala/org/apache/spark/SerializableWritable.scala b/core/src/main/scala/org/apache/spark/SerializableWritable.scala
index e50b9ac2291f9..55cb25946c2ad 100644
--- a/core/src/main/scala/org/apache/spark/SerializableWritable.scala
+++ b/core/src/main/scala/org/apache/spark/SerializableWritable.scala
@@ -24,18 +24,19 @@ import org.apache.hadoop.io.ObjectWritable
import org.apache.hadoop.io.Writable
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.Utils
@DeveloperApi
class SerializableWritable[T <: Writable](@transient var t: T) extends Serializable {
def value = t
override def toString = t.toString
- private def writeObject(out: ObjectOutputStream) {
+ private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
out.defaultWriteObject()
new ObjectWritable(t).write(out)
}
- private def readObject(in: ObjectInputStream) {
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
in.defaultReadObject()
val ow = new ObjectWritable()
ow.setConf(new Configuration())
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index dbbcc23305c50..13aa9960ac33a 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -17,9 +17,13 @@
package org.apache.spark
+import java.util.concurrent.ConcurrentHashMap
+
import scala.collection.JavaConverters._
-import scala.collection.mutable.{HashMap, LinkedHashSet}
+import scala.collection.mutable.LinkedHashSet
+
import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.util.Utils
/**
* Configuration for a Spark application. Used to set various Spark parameters as key-value pairs.
@@ -46,12 +50,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Create a SparkConf that loads defaults from system properties and the classpath */
def this() = this(true)
- private[spark] val settings = new HashMap[String, String]()
+ private val settings = new ConcurrentHashMap[String, String]()
if (loadDefaults) {
// Load any spark.* system properties
- for ((k, v) <- System.getProperties.asScala if k.startsWith("spark.")) {
- settings(k) = v
+ for ((key, value) <- Utils.getSystemProperties if key.startsWith("spark.")) {
+ set(key, value)
}
}
@@ -61,9 +65,9 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
throw new NullPointerException("null key")
}
if (value == null) {
- throw new NullPointerException("null value")
+ throw new NullPointerException("null value for " + key)
}
- settings(key) = value
+ settings.put(key, value)
this
}
@@ -129,15 +133,13 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Set multiple parameters together */
def setAll(settings: Traversable[(String, String)]) = {
- this.settings ++= settings
+ this.settings.putAll(settings.toMap.asJava)
this
}
/** Set a parameter if it isn't already configured */
def setIfMissing(key: String, value: String): SparkConf = {
- if (!settings.contains(key)) {
- settings(key) = value
- }
+ settings.putIfAbsent(key, value)
this
}
@@ -163,21 +165,23 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Get a parameter; throws a NoSuchElementException if it's not set */
def get(key: String): String = {
- settings.getOrElse(key, throw new NoSuchElementException(key))
+ getOption(key).getOrElse(throw new NoSuchElementException(key))
}
/** Get a parameter, falling back to a default if not set */
def get(key: String, defaultValue: String): String = {
- settings.getOrElse(key, defaultValue)
+ getOption(key).getOrElse(defaultValue)
}
/** Get a parameter as an Option */
def getOption(key: String): Option[String] = {
- settings.get(key)
+ Option(settings.get(key))
}
/** Get all parameters as a list of pairs */
- def getAll: Array[(String, String)] = settings.clone().toArray
+ def getAll: Array[(String, String)] = {
+ settings.entrySet().asScala.map(x => (x.getKey, x.getValue)).toArray
+ }
/** Get a parameter as an integer, falling back to a default if not set */
def getInt(key: String, defaultValue: Int): Int = {
@@ -217,12 +221,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
*/
getAll.filter { case (k, _) => isAkkaConf(k) }
+ /**
+ * Returns the Spark application id, valid in the Driver after TaskScheduler registration and
+ * from the start in the Executor.
+ */
+ def getAppId: String = get("spark.app.id")
+
/** Does the configuration contain a given parameter? */
- def contains(key: String): Boolean = settings.contains(key)
+ def contains(key: String): Boolean = settings.containsKey(key)
/** Copy this object */
override def clone: SparkConf = {
- new SparkConf(false).setAll(settings)
+ new SparkConf(false).setAll(getAll)
}
/**
@@ -234,7 +244,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Checks for illegal or deprecated config settings. Throws an exception for the former. Not
* idempotent - may mutate this conf object to convert deprecated settings to supported ones. */
private[spark] def validateSettings() {
- if (settings.contains("spark.local.dir")) {
+ if (contains("spark.local.dir")) {
val msg = "In Spark 1.0 and later spark.local.dir will be overridden by the value set by " +
"the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone and LOCAL_DIRS in YARN)."
logWarning(msg)
@@ -244,9 +254,22 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
val executorClasspathKey = "spark.executor.extraClassPath"
val driverOptsKey = "spark.driver.extraJavaOptions"
val driverClassPathKey = "spark.driver.extraClassPath"
+ val driverLibraryPathKey = "spark.driver.extraLibraryPath"
+
+ // Used by Yarn in 1.1 and before
+ sys.props.get("spark.driver.libraryPath").foreach { value =>
+ val warning =
+ s"""
+ |spark.driver.libraryPath was detected (set to '$value').
+ |This is deprecated in Spark 1.2+.
+ |
+ |Please instead use: $driverLibraryPathKey
+ """.stripMargin
+ logWarning(warning)
+ }
// Validate spark.executor.extraJavaOptions
- settings.get(executorOptsKey).map { javaOpts =>
+ getOption(executorOptsKey).map { javaOpts =>
if (javaOpts.contains("-Dspark")) {
val msg = s"$executorOptsKey is not allowed to set Spark options (was '$javaOpts'). " +
"Set them directly on a SparkConf or in a properties file when using ./bin/spark-submit."
@@ -326,7 +349,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
* configuration out for debugging.
*/
def toDebugString: String = {
- settings.toArray.sorted.map{case (k, v) => k + "=" + v}.mkString("\n")
+ getAll.sorted.map{case (k, v) => k + "=" + v}.mkString("\n")
}
}
@@ -347,11 +370,14 @@ private[spark] object SparkConf {
isAkkaConf(name) ||
name.startsWith("spark.akka") ||
name.startsWith("spark.auth") ||
+ name.startsWith("spark.ssl") ||
isSparkPortConf(name)
}
/**
- * Return whether the given config is a Spark port config.
+ * Return true if the given config matches either `spark.*.port` or `spark.port.*`.
*/
- def isSparkPortConf(name: String): Boolean = name.startsWith("spark.") && name.endsWith(".port")
+ def isSparkPortConf(name: String): Boolean = {
+ (name.startsWith("spark.") && name.endsWith(".port")) || name.startsWith("spark.port.")
+ }
}
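
isSparkPortConf now also treats `spark.port.*` keys as port configuration, in addition to `spark.*.port`. A few sample inputs for the predicate (spark.port.maxRetries is an existing key of that shape):

object PortConfSketch {
  // Same predicate as in SparkConf above, shown with sample inputs.
  def isSparkPortConf(name: String): Boolean =
    (name.startsWith("spark.") && name.endsWith(".port")) || name.startsWith("spark.port.")

  // isSparkPortConf("spark.driver.port")     => true   (matches spark.*.port)
  // isSparkPortConf("spark.port.maxRetries") => true   (matches spark.port.*)
  // isSparkPortConf("spark.executor.memory") => false
}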
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index ac7935b8c231e..a7adddb6c83ec 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -20,10 +20,10 @@ package org.apache.spark
import scala.language.implicitConversions
import java.io._
+import java.lang.reflect.Constructor
import java.net.URI
-import java.util.Arrays
+import java.util.{Arrays, Properties, UUID}
import java.util.concurrent.atomic.AtomicInteger
-import java.util.{Properties, UUID}
import java.util.UUID.randomUUID
import scala.collection.{Map, Set}
import scala.collection.JavaConversions._
@@ -42,7 +42,8 @@ import akka.actor.Props
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
-import org.apache.spark.input.WholeTextFileInputFormat
+import org.apache.spark.executor.TriggerThreadDump
+import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat}
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
@@ -50,24 +51,49 @@ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkD
import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
import org.apache.spark.scheduler.local.LocalBackend
import org.apache.spark.storage._
-import org.apache.spark.ui.SparkUI
-import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils}
+import org.apache.spark.ui.{SparkUI, ConsoleProgressBar}
+import org.apache.spark.ui.jobs.JobProgressListener
+import org.apache.spark.util._
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
* cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster.
*
+ * Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before
+ * creating a new one. This limitation may eventually be removed; see SPARK-2243 for more details.
+ *
* @param config a Spark Config object describing the application configuration. Any settings in
* this config overrides the default configs as well as system properties.
*/
+class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationClient {
+
+ // The call site where this SparkContext was constructed.
+ private val creationSite: CallSite = Utils.getCallSite()
-class SparkContext(config: SparkConf) extends Logging {
+ // If true, log warnings instead of throwing exceptions when multiple SparkContexts are active
+ private val allowMultipleContexts: Boolean =
+ config.getBoolean("spark.driver.allowMultipleContexts", false)
+
+ // In order to prevent multiple SparkContexts from being active at the same time, mark this
+ // context as having started construction.
+ // NOTE: this must be placed at the beginning of the SparkContext constructor.
+ SparkContext.markPartiallyConstructed(this, allowMultipleContexts)
// This is used only by YARN for now, but should be relevant to other cluster types (Mesos,
// etc) too. This is typically generated from InputFormatInfo.computePreferredLocations. It
// contains a map from hostname to a list of input format splits on the host.
private[spark] var preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map()
+ val startTime = System.currentTimeMillis()
+
+ @volatile private var stopped: Boolean = false
+
+ private def assertNotStopped(): Unit = {
+ if (stopped) {
+ throw new IllegalStateException("Cannot call methods on a stopped SparkContext")
+ }
+ }
+
/**
* Create a SparkContext that loads settings from system properties (for instance, when
* launching with ./bin/spark-submit).
@@ -155,6 +181,9 @@ class SparkContext(config: SparkConf) extends Logging {
private[spark] def this(master: String, appName: String, sparkHome: String, jars: Seq[String]) =
this(master, appName, sparkHome, jars, Map(), Map())
+ // log out Spark Version in Spark driver log
+ logInfo(s"Running Spark version $SPARK_VERSION")
+
private[spark] val conf = config.clone()
conf.validateSettings()
@@ -209,16 +238,10 @@ class SparkContext(config: SparkConf) extends Logging {
// An asynchronous listener bus for Spark events
private[spark] val listenerBus = new LiveListenerBus
+ conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER)
+
// Create the Spark execution environment (cache, map output tracker, etc)
- conf.set("spark.executor.id", "driver")
- private[spark] val env = SparkEnv.create(
- conf,
- "",
- conf.get("spark.driver.host"),
- conf.get("spark.driver.port").toInt,
- isDriver = true,
- isLocal = isLocal,
- listenerBus = listenerBus)
+ private[spark] val env = SparkEnv.createDriverEnv(conf, isLocal, listenerBus)
SparkEnv.set(env)
// Used to store a URL for each static file/jar together with the file's local timestamp
@@ -230,10 +253,24 @@ class SparkContext(config: SparkConf) extends Logging {
private[spark] val metadataCleaner =
new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf)
- // Initialize the Spark UI, registering all associated listeners
+
+ private[spark] val jobProgressListener = new JobProgressListener(conf)
+ listenerBus.addListener(jobProgressListener)
+
+ val statusTracker = new SparkStatusTracker(this)
+
+ private[spark] val progressBar: Option[ConsoleProgressBar] =
+ if (conf.getBoolean("spark.ui.showConsoleProgress", true) && !log.isInfoEnabled) {
+ Some(new ConsoleProgressBar(this))
+ } else {
+ None
+ }
+
+ // Initialize the Spark UI
private[spark] val ui: Option[SparkUI] =
if (conf.getBoolean("spark.ui.enabled", true)) {
- Some(new SparkUI(this))
+ Some(SparkUI.createLiveUI(this, conf, listenerBus, jobProgressListener,
+        env.securityManager, appName))
} else {
// For tests, do not enable the UI
None
@@ -246,8 +283,6 @@ class SparkContext(config: SparkConf) extends Logging {
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf)
- val startTime = System.currentTimeMillis()
-
// Add each JAR given through the constructor
if (jars != null) {
jars.foreach(addJar)
@@ -295,15 +330,21 @@ class SparkContext(config: SparkConf) extends Logging {
executorEnvs("SPARK_USER") = sparkUser
// Create and start the scheduler
- private[spark] var taskScheduler = SparkContext.createTaskScheduler(this, master)
+ private[spark] var (schedulerBackend, taskScheduler) =
+ SparkContext.createTaskScheduler(this, master)
private val heartbeatReceiver = env.actorSystem.actorOf(
Props(new HeartbeatReceiver(taskScheduler)), "HeartbeatReceiver")
@volatile private[spark] var dagScheduler: DAGScheduler = _
try {
dagScheduler = new DAGScheduler(this)
} catch {
- case e: Exception => throw
- new SparkException("DAGScheduler cannot be initialized due to %s".format(e.getMessage))
+ case e: Exception => {
+ try {
+ stop()
+ } finally {
+ throw new SparkException("Error while constructing DAGScheduler", e)
+ }
+ }
}
// start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's
@@ -313,11 +354,15 @@ class SparkContext(config: SparkConf) extends Logging {
val applicationId: String = taskScheduler.applicationId()
conf.set("spark.app.id", applicationId)
+ env.blockManager.initialize(applicationId)
+
val metricsSystem = env.metricsSystem
// The metrics system for Driver need to be set spark.app.id to app ID.
// So it should start after we get app ID from the task scheduler and set spark.app.id.
metricsSystem.start()
+ // Attach the driver metrics servlet handler to the web ui after the metrics system is started.
+ metricsSystem.getServletHandlers.foreach(handler => ui.foreach(_.attachHandler(handler)))
// Optionally log Spark events
private[spark] val eventLogger: Option[EventLoggingListener] = {
@@ -330,8 +375,18 @@ class SparkContext(config: SparkConf) extends Logging {
} else None
}
- // At this point, all relevant SparkListeners have been registered, so begin releasing events
- listenerBus.start()
+ // Optionally scale number of executors dynamically based on workload. Exposed for testing.
+ private val dynamicAllocationEnabled = conf.getBoolean("spark.dynamicAllocation.enabled", false)
+ private val dynamicAllocationTesting = conf.getBoolean("spark.dynamicAllocation.testing", false)
+ private[spark] val executorAllocationManager: Option[ExecutorAllocationManager] =
+ if (dynamicAllocationEnabled) {
+ assert(master.contains("yarn") || dynamicAllocationTesting,
+ "Dynamic allocation of executors is currently only supported in YARN mode")
+ Some(new ExecutorAllocationManager(this, listenerBus, conf))
+ } else {
+ None
+ }
+ executorAllocationManager.foreach(_.start())
private[spark] val cleaner: Option[ContextCleaner] = {
if (conf.getBoolean("spark.cleaner.referenceTracking", true)) {
@@ -342,6 +397,7 @@ class SparkContext(config: SparkConf) extends Logging {
}
cleaner.foreach(_.start())
+ setupAndStartListenerBus()
postEnvironmentUpdate()
postApplicationStart()
@@ -352,6 +408,29 @@ class SparkContext(config: SparkConf) extends Logging {
override protected def childValue(parent: Properties): Properties = new Properties(parent)
}
+ /**
+ * Called by the web UI to obtain executor thread dumps. This method may be expensive.
+ * Logs an error and returns None if we failed to obtain a thread dump, which could occur due
+ * to an executor being dead or unresponsive or due to network issues while sending the thread
+ * dump message back to the driver.
+ */
+ private[spark] def getExecutorThreadDump(executorId: String): Option[Array[ThreadStackTrace]] = {
+ try {
+ if (executorId == SparkContext.DRIVER_IDENTIFIER) {
+ Some(Utils.getThreadDump())
+ } else {
+ val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get
+ val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem)
+ Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef,
+ AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf)))
+ }
+ } catch {
+ case e: Exception =>
+ logError(s"Exception getting thread dump from executor $executorId", e)
+ None
+ }
+ }
+
private[spark] def getLocalProperties: Properties = localProperties.get()
private[spark] def setLocalProperties(props: Properties) {
@@ -386,7 +465,6 @@ class SparkContext(config: SparkConf) extends Logging {
Option(localProperties.get).map(_.getProperty(key)).getOrElse(null)
/** Set a human readable description of the current job. */
- @deprecated("use setJobGroup", "0.8.1")
def setJobDescription(value: String) {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
}
@@ -449,12 +527,12 @@ class SparkContext(config: SparkConf) extends Logging {
/** Distribute a local Scala collection to form an RDD.
*
- * @note Parallelize acts lazily. If `seq` is a mutable collection and is
- * altered after the call to parallelize and before the first action on the
- * RDD, the resultant RDD will reflect the modified collection. Pass a copy of
- * the argument to avoid this.
+ * @note Parallelize acts lazily. If `seq` is a mutable collection and is altered after the call
+ * to parallelize and before the first action on the RDD, the resultant RDD will reflect the
+ * modified collection. Pass a copy of the argument to avoid this.
*/
def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
+ assertNotStopped()
new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
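A small sketch of the "pass a copy" advice from the note above (values are illustrative):

    val buffer = scala.collection.mutable.ArrayBuffer(1, 2, 3)
    val rdd = sc.parallelize(buffer.toList)   // snapshot the buffer; later mutations won't leak into the RDD
    buffer += 4                               // does not affect rdd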
@@ -470,6 +548,7 @@ class SparkContext(config: SparkConf) extends Logging {
* location preferences (hostnames of Spark nodes) for each object.
* Create a new partition for each collection item. */
def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = {
+ assertNotStopped()
val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap
new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs)
}
@@ -479,6 +558,7 @@ class SparkContext(config: SparkConf) extends Logging {
* Hadoop-supported file system URI, and return it as an RDD of Strings.
*/
def textFile(path: String, minPartitions: Int = defaultMinPartitions): RDD[String] = {
+ assertNotStopped()
hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text],
minPartitions).map(pair => pair._2.toString).setName(path)
}
@@ -512,6 +592,7 @@ class SparkContext(config: SparkConf) extends Logging {
*/
def wholeTextFiles(path: String, minPartitions: Int = defaultMinPartitions):
RDD[(String, String)] = {
+ assertNotStopped()
val job = new NewHadoopJob(hadoopConfiguration)
NewFileInputFormat.addInputPath(job, new Path(path))
val updateConf = job.getConfiguration
@@ -524,6 +605,82 @@ class SparkContext(config: SparkConf) extends Logging {
minPartitions).setName(path)
}
+
+ /**
+ * :: Experimental ::
+ *
+ * Get an RDD for a Hadoop-readable dataset as PortableDataStream for each file
+ * (useful for binary data)
+ *
+ * For example, if you have the following files:
+ * {{{
+ * hdfs://a-hdfs-path/part-00000
+ * hdfs://a-hdfs-path/part-00001
+ * ...
+ * hdfs://a-hdfs-path/part-nnnnn
+ * }}}
+ *
+ * Do
+   * `val rdd = sparkContext.binaryFiles("hdfs://a-hdfs-path")`,
+ *
+ * then `rdd` contains
+ * {{{
+ * (a-hdfs-path/part-00000, its content)
+ * (a-hdfs-path/part-00001, its content)
+ * ...
+ * (a-hdfs-path/part-nnnnn, its content)
+ * }}}
+ *
+   * @param minPartitions Suggested minimum number of partitions for the input data.
+ *
+ * @note Small files are preferred; very large files may cause bad performance.
+ */
+ @Experimental
+ def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
+ RDD[(String, PortableDataStream)] = {
+ assertNotStopped()
+ val job = new NewHadoopJob(hadoopConfiguration)
+ NewFileInputFormat.addInputPath(job, new Path(path))
+ val updateConf = job.getConfiguration
+ new BinaryFileRDD(
+ this,
+ classOf[StreamInputFormat],
+ classOf[String],
+ classOf[PortableDataStream],
+ updateConf,
+ minPartitions).setName(path)
+ }
+
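A usage sketch for the API above (the HDFS path is hypothetical); PortableDataStream.toArray reads each file's bytes on demand:

    val streams = sc.binaryFiles("hdfs://a-hdfs-path")
    val sizes = streams.map { case (file, stream) => (file, stream.toArray().length) }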
+ /**
+ * :: Experimental ::
+ *
+ * Load data from a flat binary file, assuming the length of each record is constant.
+ *
+ * '''Note:''' We ensure that the byte array for each record in the resulting RDD
+ * has the provided record length.
+ *
+   * @param path Directory of the input data files
+ * @param recordLength The length at which to split the records
+   * @return An RDD of records, each represented as a byte array
+ */
+ @Experimental
+ def binaryRecords(path: String, recordLength: Int, conf: Configuration = hadoopConfiguration)
+ : RDD[Array[Byte]] = {
+ assertNotStopped()
+ conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength)
+ val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path,
+ classOf[FixedLengthBinaryInputFormat],
+ classOf[LongWritable],
+ classOf[BytesWritable],
+ conf=conf)
+ val data = br.map { case (k, v) =>
+ val bytes = v.getBytes
+ assert(bytes.length == recordLength, "Byte array does not have correct length")
+ bytes
+ }
+ data
+ }
+
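And a sketch for the fixed-length case, assuming 8-byte big-endian longs (path and record layout are hypothetical):

    val records = sc.binaryRecords("hdfs://a-hdfs-path/records.bin", recordLength = 8)
    val longs = records.map(bytes => java.nio.ByteBuffer.wrap(bytes).getLong)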
/**
* Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and other
* necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable),
@@ -536,9 +693,10 @@ class SparkContext(config: SparkConf) extends Logging {
* @param minPartitions Minimum number of Hadoop Splits to generate.
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD will create many references to the same object.
- * If you plan to directly cache Hadoop writable objects, you should first copy them using
- * a `map` function.
+ * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
+ * operation will create many references to the same object.
+ * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
+ * copy them using a `map` function.
*/
def hadoopRDD[K, V](
conf: JobConf,
@@ -547,18 +705,20 @@ class SparkContext(config: SparkConf) extends Logging {
valueClass: Class[V],
minPartitions: Int = defaultMinPartitions
): RDD[(K, V)] = {
+ assertNotStopped()
// Add necessary security credentials to the JobConf before broadcasting it.
SparkHadoopUtil.get.addCredentials(conf)
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minPartitions)
}
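A sketch of the copy-before-cache pattern the note recommends, using text-input types (path is hypothetical):

    import org.apache.hadoop.io.{LongWritable, Text}
    import org.apache.hadoop.mapred.TextInputFormat

    val raw = sc.hadoopFile[LongWritable, Text, TextInputFormat]("hdfs://a-hdfs-path")
    // Copy out of the reused Writable objects before caching or shuffling.
    val cached = raw.map { case (offset, line) => (offset.get, line.toString) }.cache()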
/** Get an RDD for a Hadoop file with an arbitrary InputFormat
- *
- * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD will create many references to the same object.
- * If you plan to directly cache Hadoop writable objects, you should first copy them using
- * a `map` function.
- * */
+ *
+ * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
+ * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
+ * operation will create many references to the same object.
+ * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
+ * copy them using a `map` function.
+ */
def hadoopFile[K, V](
path: String,
inputFormatClass: Class[_ <: InputFormat[K, V]],
@@ -566,6 +726,7 @@ class SparkContext(config: SparkConf) extends Logging {
valueClass: Class[V],
minPartitions: Int = defaultMinPartitions
): RDD[(K, V)] = {
+ assertNotStopped()
// A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it.
val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration))
val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path)
@@ -588,9 +749,10 @@ class SparkContext(config: SparkConf) extends Logging {
* }}}
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD will create many references to the same object.
- * If you plan to directly cache Hadoop writable objects, you should first copy them using
- * a `map` function.
+ * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
+ * operation will create many references to the same object.
+ * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
+ * copy them using a `map` function.
*/
def hadoopFile[K, V, F <: InputFormat[K, V]]
(path: String, minPartitions: Int)
@@ -611,9 +773,10 @@ class SparkContext(config: SparkConf) extends Logging {
* }}}
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD will create many references to the same object.
- * If you plan to directly cache Hadoop writable objects, you should first copy them using
- * a `map` function.
+ * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
+ * operation will create many references to the same object.
+ * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
+ * copy them using a `map` function.
*/
def hadoopFile[K, V, F <: InputFormat[K, V]](path: String)
(implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] =
@@ -635,9 +798,10 @@ class SparkContext(config: SparkConf) extends Logging {
* and extra configuration options to pass to the input format.
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD will create many references to the same object.
- * If you plan to directly cache Hadoop writable objects, you should first copy them using
- * a `map` function.
+ * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
+ * operation will create many references to the same object.
+ * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
+ * copy them using a `map` function.
*/
def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](
path: String,
@@ -645,6 +809,9 @@ class SparkContext(config: SparkConf) extends Logging {
kClass: Class[K],
vClass: Class[V],
conf: Configuration = hadoopConfiguration): RDD[(K, V)] = {
+ assertNotStopped()
+ // The call to new NewHadoopJob automatically adds security credentials to conf,
+ // so we don't need to explicitly add them ourselves
val job = new NewHadoopJob(conf)
NewFileInputFormat.addInputPath(job, new Path(path))
val updatedConf = job.getConfiguration
@@ -656,30 +823,37 @@ class SparkContext(config: SparkConf) extends Logging {
* and extra configuration options to pass to the input format.
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD will create many references to the same object.
- * If you plan to directly cache Hadoop writable objects, you should first copy them using
- * a `map` function.
+ * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
+ * operation will create many references to the same object.
+ * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
+ * copy them using a `map` function.
*/
def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]](
conf: Configuration = hadoopConfiguration,
fClass: Class[F],
kClass: Class[K],
vClass: Class[V]): RDD[(K, V)] = {
- new NewHadoopRDD(this, fClass, kClass, vClass, conf)
+ assertNotStopped()
+ // Add necessary security credentials to the JobConf. Required to access secure HDFS.
+ val jconf = new JobConf(conf)
+ SparkHadoopUtil.get.addCredentials(jconf)
+ new NewHadoopRDD(this, fClass, kClass, vClass, jconf)
}
/** Get an RDD for a Hadoop SequenceFile with given key and value types.
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD will create many references to the same object.
- * If you plan to directly cache Hadoop writable objects, you should first copy them using
- * a `map` function.
+ * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
+ * operation will create many references to the same object.
+ * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
+ * copy them using a `map` function.
*/
def sequenceFile[K, V](path: String,
keyClass: Class[K],
valueClass: Class[V],
minPartitions: Int
): RDD[(K, V)] = {
+ assertNotStopped()
val inputFormatClass = classOf[SequenceFileInputFormat[K, V]]
hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions)
}
@@ -687,13 +861,15 @@ class SparkContext(config: SparkConf) extends Logging {
/** Get an RDD for a Hadoop SequenceFile with given key and value types.
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD will create many references to the same object.
- * If you plan to directly cache Hadoop writable objects, you should first copy them using
- * a `map` function.
+ * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
+ * operation will create many references to the same object.
+ * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
+ * copy them using a `map` function.
* */
- def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]
- ): RDD[(K, V)] =
+ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] = {
+ assertNotStopped()
sequenceFile(path, keyClass, valueClass, defaultMinPartitions)
+ }
/**
* Version of sequenceFile() for types implicitly convertible to Writables through a
@@ -712,15 +888,17 @@ class SparkContext(config: SparkConf) extends Logging {
* allow it to figure out the Writable class to use in the subclass case.
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD will create many references to the same object.
- * If you plan to directly cache Hadoop writable objects, you should first copy them using
- * a `map` function.
+ * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
+ * operation will create many references to the same object.
+ * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
+ * copy them using a `map` function.
*/
def sequenceFile[K, V]
(path: String, minPartitions: Int = defaultMinPartitions)
(implicit km: ClassTag[K], vm: ClassTag[V],
kcf: () => WritableConverter[K], vcf: () => WritableConverter[V])
: RDD[(K, V)] = {
+ assertNotStopped()
val kc = kcf()
val vc = vcf()
val format = classOf[SequenceFileInputFormat[Writable, Writable]]
@@ -742,6 +920,7 @@ class SparkContext(config: SparkConf) extends Logging {
path: String,
minPartitions: Int = defaultMinPartitions
): RDD[T] = {
+ assertNotStopped()
sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minPartitions)
.flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes, Utils.getContextOrSparkClassLoader))
}
@@ -817,6 +996,13 @@ class SparkContext(config: SparkConf) extends Logging {
* The variable will be sent to each cluster only once.
*/
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
+ assertNotStopped()
+ if (classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass)) {
+ // This is a warning instead of an exception in order to avoid breaking user programs that
+ // might have created RDD broadcast variables but not used them:
+      logWarning("Cannot directly broadcast RDDs; instead, call collect() and "
+        + "broadcast the result (see SPARK-5063)")
+ }
val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
val callSite = getCallSite
logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
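A sketch of the collect-then-broadcast pattern that the SPARK-5063 warning points users to:

    val lookup = sc.parallelize(Seq(("a", 1), ("b", 2))).collect().toMap   // materialize on the driver
    val lookupBc = sc.broadcast(lookup)                                    // broadcast the plain Map, not the RDD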
@@ -837,11 +1023,12 @@ class SparkContext(config: SparkConf) extends Logging {
case "local" => "file:" + uri.getPath
case _ => path
}
- addedFiles(key) = System.currentTimeMillis
+ val timestamp = System.currentTimeMillis
+ addedFiles(key) = timestamp
// Fetch the file locally in case a job is executed using DAGScheduler.runLocally().
Utils.fetchFile(path, new File(SparkFiles.getRootDirectory()), conf, env.securityManager,
- hadoopConfiguration)
+ hadoopConfiguration, timestamp, useCache = false)
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
postEnvironmentUpdate()
@@ -856,6 +1043,50 @@ class SparkContext(config: SparkConf) extends Logging {
listenerBus.addListener(listener)
}
+ /**
+ * :: DeveloperApi ::
+ * Request an additional number of executors from the cluster manager.
+   * This is currently only supported in YARN mode. Return whether the request is received.
+ */
+ @DeveloperApi
+ override def requestExecutors(numAdditionalExecutors: Int): Boolean = {
+ assert(master.contains("yarn") || dynamicAllocationTesting,
+ "Requesting executors is currently only supported in YARN mode")
+ schedulerBackend match {
+ case b: CoarseGrainedSchedulerBackend =>
+ b.requestExecutors(numAdditionalExecutors)
+ case _ =>
+ logWarning("Requesting executors is only supported in coarse-grained mode")
+ false
+ }
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Request that the cluster manager kill the specified executors.
+   * This is currently only supported in YARN mode. Return whether the request is received.
+ */
+ @DeveloperApi
+ override def killExecutors(executorIds: Seq[String]): Boolean = {
+ assert(master.contains("yarn") || dynamicAllocationTesting,
+ "Killing executors is currently only supported in YARN mode")
+ schedulerBackend match {
+ case b: CoarseGrainedSchedulerBackend =>
+ b.killExecutors(executorIds)
+ case _ =>
+ logWarning("Killing executors is only supported in coarse-grained mode")
+ false
+ }
+ }
+
+ /**
+ * :: DeveloperApi ::
+   * Request that the cluster manager kill the specified executor.
+   * This is currently only supported in YARN mode. Return whether the request is received.
+ */
+ @DeveloperApi
+ override def killExecutor(executorId: String): Boolean = super.killExecutor(executorId)
+
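A hypothetical sketch of calling these developer APIs from an application running on YARN with a coarse-grained backend:

    if (sc.requestExecutors(numAdditionalExecutors = 2)) {
      println("Cluster manager acknowledged the request for two more executors")
    }
    sc.killExecutors(Seq("1", "2"))   // later, release executors by ID once they are idle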
/** The version of Spark on which this application is running. */
def version = SPARK_VERSION
@@ -864,6 +1095,7 @@ class SparkContext(config: SparkConf) extends Logging {
* memory available for caching.
*/
def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
+ assertNotStopped()
env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
(blockManagerId.host + ":" + blockManagerId.port, mem)
}
@@ -876,6 +1108,7 @@ class SparkContext(config: SparkConf) extends Logging {
*/
@DeveloperApi
def getRDDStorageInfo: Array[RDDInfo] = {
+ assertNotStopped()
val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray
StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus)
rddInfos.filter(_.isCached)
@@ -893,6 +1126,7 @@ class SparkContext(config: SparkConf) extends Logging {
*/
@DeveloperApi
def getExecutorStorageStatus: Array[StorageStatus] = {
+ assertNotStopped()
env.blockManager.master.getStorageStatus
}
@@ -902,6 +1136,7 @@ class SparkContext(config: SparkConf) extends Logging {
*/
@DeveloperApi
def getAllPools: Seq[Schedulable] = {
+ assertNotStopped()
// TODO(xiajunluan): We should take nested pools into account
taskScheduler.rootPool.schedulableQueue.toSeq
}
@@ -912,6 +1147,7 @@ class SparkContext(config: SparkConf) extends Logging {
*/
@DeveloperApi
def getPoolForName(pool: String): Option[Schedulable] = {
+ assertNotStopped()
Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool))
}
@@ -919,6 +1155,7 @@ class SparkContext(config: SparkConf) extends Logging {
* Return current scheduling mode
*/
def getSchedulingMode: SchedulingMode.SchedulingMode = {
+ assertNotStopped()
taskScheduler.schedulingMode
}
@@ -993,7 +1230,19 @@ class SparkContext(config: SparkConf) extends Logging {
null
}
} else {
- env.httpFileServer.addJar(new File(uri.getPath))
+ try {
+ env.httpFileServer.addJar(new File(uri.getPath))
+ } catch {
+ case exc: FileNotFoundException =>
+ logError(s"Jar not found at $path")
+ null
+ case e: Exception =>
+          // For now just log an error but let it go through so the Spark examples work.
+          // The Spark examples don't really need the jar distributed since it's also
+          // the app jar.
+ logError("Error adding jar (" + e + "), was the --addJars option used?")
+ null
+ }
}
// A JAR file which exists locally on every worker node
case "local" =>
@@ -1021,27 +1270,28 @@ class SparkContext(config: SparkConf) extends Logging {
/** Shut down the SparkContext. */
def stop() {
- postApplicationEnd()
- ui.foreach(_.stop())
- // Do this only if not stopped already - best case effort.
- // prevent NPE if stopped more than once.
- val dagSchedulerCopy = dagScheduler
- dagScheduler = null
- if (dagSchedulerCopy != null) {
- env.metricsSystem.report()
- metadataCleaner.cancel()
- env.actorSystem.stop(heartbeatReceiver)
- cleaner.foreach(_.stop())
- dagSchedulerCopy.stop()
- taskScheduler = null
- // TODO: Cache.stop()?
- env.stop()
- SparkEnv.set(null)
- listenerBus.stop()
- eventLogger.foreach(_.stop())
- logInfo("Successfully stopped SparkContext")
- } else {
- logInfo("SparkContext already stopped")
+ SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+ postApplicationEnd()
+ ui.foreach(_.stop())
+ if (!stopped) {
+ stopped = true
+ env.metricsSystem.report()
+ metadataCleaner.cancel()
+ env.actorSystem.stop(heartbeatReceiver)
+ cleaner.foreach(_.stop())
+ dagScheduler.stop()
+ dagScheduler = null
+ taskScheduler = null
+ // TODO: Cache.stop()?
+ env.stop()
+ SparkEnv.set(null)
+ listenerBus.stop()
+ eventLogger.foreach(_.stop())
+ logInfo("Successfully stopped SparkContext")
+ SparkContext.clearActiveContext()
+ } else {
+ logInfo("SparkContext already stopped")
+ }
}
}
@@ -1104,14 +1354,15 @@ class SparkContext(config: SparkConf) extends Logging {
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
- if (dagScheduler == null) {
- throw new SparkException("SparkContext has been shutdown")
+ if (stopped) {
+ throw new IllegalStateException("SparkContext has been shutdown")
}
val callSite = getCallSite
val cleanedFunc = clean(func)
logInfo("Starting job: " + callSite.shortForm)
dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
resultHandler, localProperties.get)
+ progressBar.foreach(_.finishAll())
rdd.doCheckpoint()
}
@@ -1191,6 +1442,7 @@ class SparkContext(config: SparkConf) extends Logging {
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
timeout: Long): PartialResult[R] = {
+ assertNotStopped()
val callSite = getCallSite
logInfo("Starting job: " + callSite.shortForm)
val start = System.nanoTime
@@ -1213,6 +1465,7 @@ class SparkContext(config: SparkConf) extends Logging {
resultHandler: (Int, U) => Unit,
resultFunc: => R): SimpleFutureAction[R] =
{
+ assertNotStopped()
val cleanF = clean(processPartition)
val callSite = getCallSite
val waiter = dagScheduler.submitJob(
@@ -1231,11 +1484,13 @@ class SparkContext(config: SparkConf) extends Logging {
* for more information.
*/
def cancelJobGroup(groupId: String) {
+ assertNotStopped()
dagScheduler.cancelJobGroup(groupId)
}
/** Cancel all jobs that have been scheduled or are running. */
def cancelAllJobs() {
+ assertNotStopped()
dagScheduler.cancelAllJobs()
}
@@ -1282,13 +1537,20 @@ class SparkContext(config: SparkConf) extends Logging {
def getCheckpointDir = checkpointDir
/** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */
- def defaultParallelism: Int = taskScheduler.defaultParallelism
+ def defaultParallelism: Int = {
+ assertNotStopped()
+ taskScheduler.defaultParallelism
+ }
/** Default min number of partitions for Hadoop RDDs when not given by user */
@deprecated("use defaultMinPartitions", "1.0.0")
def defaultMinSplits: Int = math.min(defaultParallelism, 2)
- /** Default min number of partitions for Hadoop RDDs when not given by user */
+ /**
+ * Default min number of partitions for Hadoop RDDs when not given by user
+ * Notice that we use math.min so the "defaultMinPartitions" cannot be higher than 2.
+ * The reasons for this are discussed in https://github.com/mesos/spark/pull/718
+ */
def defaultMinPartitions: Int = math.min(defaultParallelism, 2)
private val nextShuffleId = new AtomicInteger(0)
@@ -1300,6 +1562,58 @@ class SparkContext(config: SparkConf) extends Logging {
/** Register a new RDD, returning its RDD ID */
private[spark] def newRddId(): Int = nextRddId.getAndIncrement()
+ /**
+ * Registers listeners specified in spark.extraListeners, then starts the listener bus.
+ * This should be called after all internal listeners have been registered with the listener bus
+ * (e.g. after the web UI and event logging listeners have been registered).
+ */
+ private def setupAndStartListenerBus(): Unit = {
+ // Use reflection to instantiate listeners specified via `spark.extraListeners`
+ try {
+ val listenerClassNames: Seq[String] =
+ conf.get("spark.extraListeners", "").split(',').map(_.trim).filter(_ != "")
+ for (className <- listenerClassNames) {
+ // Use reflection to find the right constructor
+ val constructors = {
+ val listenerClass = Class.forName(className)
+ listenerClass.getConstructors.asInstanceOf[Array[Constructor[_ <: SparkListener]]]
+ }
+ val constructorTakingSparkConf = constructors.find { c =>
+ c.getParameterTypes.sameElements(Array(classOf[SparkConf]))
+ }
+ lazy val zeroArgumentConstructor = constructors.find { c =>
+ c.getParameterTypes.isEmpty
+ }
+ val listener: SparkListener = {
+ if (constructorTakingSparkConf.isDefined) {
+ constructorTakingSparkConf.get.newInstance(conf)
+ } else if (zeroArgumentConstructor.isDefined) {
+ zeroArgumentConstructor.get.newInstance()
+ } else {
+ throw new SparkException(
+ s"$className did not have a zero-argument constructor or a" +
+ " single-argument constructor that accepts SparkConf. Note: if the class is" +
+ " defined inside of another Scala class, then its constructors may accept an" +
+ " implicit parameter that references the enclosing class; in this case, you must" +
+ " define the listener as a top-level class in order to prevent this extra" +
+ " parameter from breaking Spark's ability to find a valid constructor.")
+ }
+ }
+ listenerBus.addListener(listener)
+ logInfo(s"Registered listener $className")
+ }
+ } catch {
+ case e: Exception =>
+ try {
+ stop()
+ } finally {
+ throw new SparkException(s"Exception when registering SparkListener", e)
+ }
+ }
+
+ listenerBus.start()
+ }
+
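A sketch of a listener that satisfies the zero-argument-constructor requirement described above (class and package name are hypothetical):

    import org.apache.spark.scheduler.{SparkListener, SparkListenerJobEnd}

    class JobEndLogger extends SparkListener {
      override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit =
        println(s"Job ${jobEnd.jobId} finished with result ${jobEnd.jobResult}")
    }
    // Registered at launch time with: --conf spark.extraListeners=com.example.JobEndLogger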
/** Post the application start event */
private def postApplicationStart() {
// Note: this code assumes that the task scheduler has been initialized and has contacted
@@ -1330,6 +1644,11 @@ class SparkContext(config: SparkConf) extends Logging {
private[spark] def cleanup(cleanupTime: Long) {
persistentRdds.clearOldValues(cleanupTime)
}
+
+ // In order to prevent multiple SparkContexts from being active at the same time, mark this
+ // context as having finished construction.
+ // NOTE: this must be placed at the end of the SparkContext constructor.
+ SparkContext.setActiveContext(this, allowMultipleContexts)
}
/**
@@ -1338,6 +1657,107 @@ class SparkContext(config: SparkConf) extends Logging {
*/
object SparkContext extends Logging {
+ /**
+ * Lock that guards access to global variables that track SparkContext construction.
+ */
+ private val SPARK_CONTEXT_CONSTRUCTOR_LOCK = new Object()
+
+ /**
+ * The active, fully-constructed SparkContext. If no SparkContext is active, then this is `None`.
+ *
+ * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK
+ */
+ private var activeContext: Option[SparkContext] = None
+
+ /**
+ * Points to a partially-constructed SparkContext if some thread is in the SparkContext
+ * constructor, or `None` if no SparkContext is being constructed.
+ *
+ * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK
+ */
+ private var contextBeingConstructed: Option[SparkContext] = None
+
+ /**
+ * Called to ensure that no other SparkContext is running in this JVM.
+ *
+ * Throws an exception if a running context is detected and logs a warning if another thread is
+ * constructing a SparkContext. This warning is necessary because the current locking scheme
+ * prevents us from reliably distinguishing between cases where another context is being
+ * constructed and cases where another constructor threw an exception.
+ */
+ private def assertNoOtherContextIsRunning(
+ sc: SparkContext,
+ allowMultipleContexts: Boolean): Unit = {
+ SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+ contextBeingConstructed.foreach { otherContext =>
+ if (otherContext ne sc) { // checks for reference equality
+ // Since otherContext might point to a partially-constructed context, guard against
+ // its creationSite field being null:
+ val otherContextCreationSite =
+ Option(otherContext.creationSite).map(_.longForm).getOrElse("unknown location")
+ val warnMsg = "Another SparkContext is being constructed (or threw an exception in its" +
+ " constructor). This may indicate an error, since only one SparkContext may be" +
+ " running in this JVM (see SPARK-2243)." +
+ s" The other SparkContext was created at:\n$otherContextCreationSite"
+ logWarning(warnMsg)
+ }
+
+ activeContext.foreach { ctx =>
+ val errMsg = "Only one SparkContext may be running in this JVM (see SPARK-2243)." +
+ " To ignore this error, set spark.driver.allowMultipleContexts = true. " +
+ s"The currently running SparkContext was created at:\n${ctx.creationSite.longForm}"
+ val exception = new SparkException(errMsg)
+ if (allowMultipleContexts) {
+ logWarning("Multiple running SparkContexts detected in the same JVM!", exception)
+ } else {
+ throw exception
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Called at the beginning of the SparkContext constructor to ensure that no SparkContext is
+ * running. Throws an exception if a running context is detected and logs a warning if another
+ * thread is constructing a SparkContext. This warning is necessary because the current locking
+ * scheme prevents us from reliably distinguishing between cases where another context is being
+ * constructed and cases where another constructor threw an exception.
+ */
+ private[spark] def markPartiallyConstructed(
+ sc: SparkContext,
+ allowMultipleContexts: Boolean): Unit = {
+ SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+ assertNoOtherContextIsRunning(sc, allowMultipleContexts)
+ contextBeingConstructed = Some(sc)
+ }
+ }
+
+ /**
+ * Called at the end of the SparkContext constructor to ensure that no other SparkContext has
+ * raced with this constructor and started.
+ */
+ private[spark] def setActiveContext(
+ sc: SparkContext,
+ allowMultipleContexts: Boolean): Unit = {
+ SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+ assertNoOtherContextIsRunning(sc, allowMultipleContexts)
+ contextBeingConstructed = None
+ activeContext = Some(sc)
+ }
+ }
+
+ /**
+ * Clears the active SparkContext metadata. This is called by `SparkContext#stop()`. It's
+ * also called in unit tests to prevent a flood of warnings from test suites that don't / can't
+ * properly clean up their SparkContexts.
+ */
+ private[spark] def clearActiveContext(): Unit = {
+ SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+ activeContext = None
+ }
+ }
+
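A sketch of the escape hatch mentioned in the error message, e.g. for test suites that intentionally build more than one context:

    val conf = new SparkConf()
      .setAppName("multi-context-test")
      .set("spark.driver.allowMultipleContexts", "true")   // downgrades the SPARK-2243 error to a warning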
private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description"
private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"
@@ -1346,63 +1766,113 @@ object SparkContext extends Logging {
private[spark] val SPARK_UNKNOWN_USER = ""
- implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
+ private[spark] val DRIVER_IDENTIFIER = ""
+
+  // The following deprecated objects have already been copied to `object AccumulatorParam` to
+  // make the compiler find them automatically. They are duplicated here only for backward
+  // compatibility; please update `object AccumulatorParam` accordingly if you plan to modify the
+  // following ones.
+
+ @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " +
+ "backward compatibility.", "1.3.0")
+ object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
def zero(initialValue: Double) = 0.0
}
- implicit object IntAccumulatorParam extends AccumulatorParam[Int] {
+ @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " +
+ "backward compatibility.", "1.3.0")
+ object IntAccumulatorParam extends AccumulatorParam[Int] {
def addInPlace(t1: Int, t2: Int): Int = t1 + t2
def zero(initialValue: Int) = 0
}
- implicit object LongAccumulatorParam extends AccumulatorParam[Long] {
+ @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " +
+ "backward compatibility.", "1.3.0")
+ object LongAccumulatorParam extends AccumulatorParam[Long] {
def addInPlace(t1: Long, t2: Long) = t1 + t2
def zero(initialValue: Long) = 0L
}
- implicit object FloatAccumulatorParam extends AccumulatorParam[Float] {
+ @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " +
+ "backward compatibility.", "1.3.0")
+ object FloatAccumulatorParam extends AccumulatorParam[Float] {
def addInPlace(t1: Float, t2: Float) = t1 + t2
def zero(initialValue: Float) = 0f
}
- // TODO: Add AccumulatorParams for other types, e.g. lists and strings
+ // The following deprecated functions have already been moved to `object RDD` to
+ // make the compiler find them automatically. They are still kept here for backward compatibility
+ // and just call the corresponding functions in `object RDD`.
- implicit def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)])
+ @deprecated("Replaced by implicit functions in the RDD companion object. This is " +
+ "kept here only for backward compatibility.", "1.3.0")
+ def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)])
(implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = {
- new PairRDDFunctions(rdd)
+ RDD.rddToPairRDDFunctions(rdd)
}
- implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd)
+ @deprecated("Replaced by implicit functions in the RDD companion object. This is " +
+ "kept here only for backward compatibility.", "1.3.0")
+ def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = RDD.rddToAsyncRDDActions(rdd)
- implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag](
- rdd: RDD[(K, V)]) =
- new SequenceFileRDDFunctions(rdd)
+ @deprecated("Replaced by implicit functions in the RDD companion object. This is " +
+ "kept here only for backward compatibility.", "1.3.0")
+ def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag](
+ rdd: RDD[(K, V)]) = {
+ val kf = implicitly[K => Writable]
+ val vf = implicitly[V => Writable]
+ // Set the Writable class to null and `SequenceFileRDDFunctions` will use Reflection to get it
+ implicit val keyWritableFactory = new WritableFactory[K](_ => null, kf)
+ implicit val valueWritableFactory = new WritableFactory[V](_ => null, vf)
+ RDD.rddToSequenceFileRDDFunctions(rdd)
+ }
- implicit def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag](
+ @deprecated("Replaced by implicit functions in the RDD companion object. This is " +
+ "kept here only for backward compatibility.", "1.3.0")
+ def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag](
rdd: RDD[(K, V)]) =
- new OrderedRDDFunctions[K, V, (K, V)](rdd)
+ RDD.rddToOrderedRDDFunctions(rdd)
- implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd)
+ @deprecated("Replaced by implicit functions in the RDD companion object. This is " +
+ "kept here only for backward compatibility.", "1.3.0")
+ def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = RDD.doubleRDDToDoubleRDDFunctions(rdd)
- implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) =
- new DoubleRDDFunctions(rdd.map(x => num.toDouble(x)))
+ @deprecated("Replaced by implicit functions in the RDD companion object. This is " +
+ "kept here only for backward compatibility.", "1.3.0")
+ def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) =
+ RDD.numericRDDToDoubleRDDFunctions(rdd)
- // Implicit conversions to common Writable types, for saveAsSequenceFile
+ // The following deprecated functions have already been moved to `object WritableFactory` to
+ // make the compiler find them automatically. They are still kept here for backward compatibility.
- implicit def intToIntWritable(i: Int) = new IntWritable(i)
+ @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " +
+ "kept here only for backward compatibility.", "1.3.0")
+ implicit def intToIntWritable(i: Int): IntWritable = new IntWritable(i)
- implicit def longToLongWritable(l: Long) = new LongWritable(l)
+ @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " +
+ "kept here only for backward compatibility.", "1.3.0")
+ implicit def longToLongWritable(l: Long): LongWritable = new LongWritable(l)
- implicit def floatToFloatWritable(f: Float) = new FloatWritable(f)
+ @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " +
+ "kept here only for backward compatibility.", "1.3.0")
+ implicit def floatToFloatWritable(f: Float): FloatWritable = new FloatWritable(f)
- implicit def doubleToDoubleWritable(d: Double) = new DoubleWritable(d)
+ @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " +
+ "kept here only for backward compatibility.", "1.3.0")
+ implicit def doubleToDoubleWritable(d: Double): DoubleWritable = new DoubleWritable(d)
- implicit def boolToBoolWritable (b: Boolean) = new BooleanWritable(b)
+ @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " +
+ "kept here only for backward compatibility.", "1.3.0")
+ implicit def boolToBoolWritable (b: Boolean): BooleanWritable = new BooleanWritable(b)
- implicit def bytesToBytesWritable (aob: Array[Byte]) = new BytesWritable(aob)
+ @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " +
+ "kept here only for backward compatibility.", "1.3.0")
+ implicit def bytesToBytesWritable (aob: Array[Byte]): BytesWritable = new BytesWritable(aob)
- implicit def stringToText(s: String) = new Text(s)
+ @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " +
+ "kept here only for backward compatibility.", "1.3.0")
+ implicit def stringToText(s: String): Text = new Text(s)
private implicit def arrayToArrayWritable[T <% Writable: ClassTag](arr: Traversable[T])
: ArrayWritable = {
@@ -1412,40 +1882,49 @@ object SparkContext extends Logging {
arr.map(x => anyToWritable(x)).toArray)
}
- // Helper objects for converting common types to Writable
- private def simpleWritableConverter[T, W <: Writable: ClassTag](convert: W => T)
- : WritableConverter[T] = {
- val wClass = classTag[W].runtimeClass.asInstanceOf[Class[W]]
- new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W]))
- }
+ // The following deprecated functions have already been moved to `object WritableConverter` to
+ // make the compiler find them automatically. They are still kept here for backward compatibility
+ // and just call the corresponding functions in `object WritableConverter`.
- implicit def intWritableConverter(): WritableConverter[Int] =
- simpleWritableConverter[Int, IntWritable](_.get)
+ @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " +
+ "backward compatibility.", "1.3.0")
+ def intWritableConverter(): WritableConverter[Int] =
+ WritableConverter.intWritableConverter()
- implicit def longWritableConverter(): WritableConverter[Long] =
- simpleWritableConverter[Long, LongWritable](_.get)
+ @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " +
+ "backward compatibility.", "1.3.0")
+ def longWritableConverter(): WritableConverter[Long] =
+ WritableConverter.longWritableConverter()
- implicit def doubleWritableConverter(): WritableConverter[Double] =
- simpleWritableConverter[Double, DoubleWritable](_.get)
+ @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " +
+ "backward compatibility.", "1.3.0")
+ def doubleWritableConverter(): WritableConverter[Double] =
+ WritableConverter.doubleWritableConverter()
- implicit def floatWritableConverter(): WritableConverter[Float] =
- simpleWritableConverter[Float, FloatWritable](_.get)
+ @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " +
+ "backward compatibility.", "1.3.0")
+ def floatWritableConverter(): WritableConverter[Float] =
+ WritableConverter.floatWritableConverter()
- implicit def booleanWritableConverter(): WritableConverter[Boolean] =
- simpleWritableConverter[Boolean, BooleanWritable](_.get)
+ @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " +
+ "backward compatibility.", "1.3.0")
+ def booleanWritableConverter(): WritableConverter[Boolean] =
+ WritableConverter.booleanWritableConverter()
- implicit def bytesWritableConverter(): WritableConverter[Array[Byte]] = {
- simpleWritableConverter[Array[Byte], BytesWritable](bw =>
- // getBytes method returns array which is longer then data to be returned
- Arrays.copyOfRange(bw.getBytes, 0, bw.getLength)
- )
- }
+ @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " +
+ "backward compatibility.", "1.3.0")
+ def bytesWritableConverter(): WritableConverter[Array[Byte]] =
+ WritableConverter.bytesWritableConverter()
- implicit def stringWritableConverter(): WritableConverter[String] =
- simpleWritableConverter[String, Text](_.toString)
+ @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " +
+ "backward compatibility.", "1.3.0")
+ def stringWritableConverter(): WritableConverter[String] =
+ WritableConverter.stringWritableConverter()
- implicit def writableWritableConverter[T <: Writable]() =
- new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T])
+ @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " +
+ "backward compatibility.", "1.3.0")
+ def writableWritableConverter[T <: Writable](): WritableConverter[T] =
+ WritableConverter.writableWritableConverter()
/**
* Find the JAR from which a given class was loaded, to make it easy for users to pass
@@ -1501,8 +1980,13 @@ object SparkContext extends Logging {
res
}
- /** Creates a task scheduler based on a given master URL. Extracted for testing. */
- private def createTaskScheduler(sc: SparkContext, master: String): TaskScheduler = {
+ /**
+ * Create a task scheduler based on a given master URL.
+ * Return a 2-tuple of the scheduler backend and the task scheduler.
+ */
+ private def createTaskScheduler(
+ sc: SparkContext,
+ master: String): (SchedulerBackend, TaskScheduler) = {
// Regular expression used for local[N] and local[*] master formats
val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r
// Regular expression for local[N, maxRetries], used in tests with failing tasks
@@ -1524,16 +2008,19 @@ object SparkContext extends Logging {
val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
val backend = new LocalBackend(scheduler, 1)
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case LOCAL_N_REGEX(threads) =>
def localCpuCount = Runtime.getRuntime.availableProcessors()
// local[*] estimates the number of cores on the machine; local[N] uses exactly N threads.
val threadCount = if (threads == "*") localCpuCount else threads.toInt
+ if (threadCount <= 0) {
+ throw new SparkException(s"Asked to run locally with $threadCount threads")
+ }
val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
val backend = new LocalBackend(scheduler, threadCount)
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
def localCpuCount = Runtime.getRuntime.availableProcessors()
@@ -1543,14 +2030,14 @@ object SparkContext extends Logging {
val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true)
val backend = new LocalBackend(scheduler, threadCount)
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case SPARK_REGEX(sparkUrl) =>
val scheduler = new TaskSchedulerImpl(sc)
val masterUrls = sparkUrl.split(",").map("spark://" + _)
val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls)
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
// Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang.
@@ -1570,7 +2057,7 @@ object SparkContext extends Logging {
backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
localCluster.stop()
}
- scheduler
+ (backend, scheduler)
case "yarn-standalone" | "yarn-cluster" =>
if (master == "yarn-standalone") {
@@ -1599,12 +2086,12 @@ object SparkContext extends Logging {
}
}
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case "yarn-client" =>
val scheduler = try {
val clazz =
- Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler")
+ Class.forName("org.apache.spark.scheduler.cluster.YarnScheduler")
val cons = clazz.getConstructor(classOf[SparkContext])
cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl]
@@ -1626,7 +2113,7 @@ object SparkContext extends Logging {
}
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case mesosUrl @ MESOS_REGEX(_) =>
MesosNativeLibrary.load()
@@ -1639,13 +2126,13 @@ object SparkContext extends Logging {
new MesosSchedulerBackend(scheduler, sc, url)
}
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case SIMR_REGEX(simrUrl) =>
val scheduler = new TaskSchedulerImpl(sc)
val backend = new SimrSchedulerBackend(scheduler, sc, simrUrl)
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case _ =>
throw new SparkException("Could not parse Master URL: '" + master + "'")
@@ -1664,3 +2151,89 @@ private[spark] class WritableConverter[T](
val writableClass: ClassTag[T] => Class[_ <: Writable],
val convert: Writable => T)
extends Serializable
+
+object WritableConverter {
+
+ // Helper objects for converting common types to Writable
+ private[spark] def simpleWritableConverter[T, W <: Writable: ClassTag](convert: W => T)
+ : WritableConverter[T] = {
+ val wClass = classTag[W].runtimeClass.asInstanceOf[Class[W]]
+ new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W]))
+ }
+
+ // The following implicit functions were in SparkContext before 1.3 and users had to
+ // `import SparkContext._` to enable them. Now we move them here to make the compiler find
+ // them automatically. However, we still keep the old functions in SparkContext for backward
+ // compatibility and forward to the following functions directly.
+
+ implicit def intWritableConverter(): WritableConverter[Int] =
+ simpleWritableConverter[Int, IntWritable](_.get)
+
+ implicit def longWritableConverter(): WritableConverter[Long] =
+ simpleWritableConverter[Long, LongWritable](_.get)
+
+ implicit def doubleWritableConverter(): WritableConverter[Double] =
+ simpleWritableConverter[Double, DoubleWritable](_.get)
+
+ implicit def floatWritableConverter(): WritableConverter[Float] =
+ simpleWritableConverter[Float, FloatWritable](_.get)
+
+ implicit def booleanWritableConverter(): WritableConverter[Boolean] =
+ simpleWritableConverter[Boolean, BooleanWritable](_.get)
+
+ implicit def bytesWritableConverter(): WritableConverter[Array[Byte]] = {
+ simpleWritableConverter[Array[Byte], BytesWritable] { bw =>
+      // the getBytes method returns an array that may be longer than the data to be returned
+ Arrays.copyOfRange(bw.getBytes, 0, bw.getLength)
+ }
+ }
+
+ implicit def stringWritableConverter(): WritableConverter[String] =
+ simpleWritableConverter[String, Text](_.toString)
+
+ implicit def writableWritableConverter[T <: Writable](): WritableConverter[T] =
+ new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T])
+}
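With these implicit converters in scope, sequenceFile can infer the Writable types from the requested Scala types; a small sketch (path is hypothetical):

    val pairs = sc.sequenceFile[Int, String]("hdfs://a-hdfs-path/data.seq")   // IntWritable/Text resolved implicitly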
+
+/**
+ * A class encapsulating how to convert some type T to Writable. It stores both the Writable class
+ * corresponding to T (e.g. IntWritable for Int) and a function for doing the conversion.
+ * The Writable class will be used in `SequenceFileRDDFunctions`.
+ */
+private[spark] class WritableFactory[T](
+ val writableClass: ClassTag[T] => Class[_ <: Writable],
+ val convert: T => Writable) extends Serializable
+
+object WritableFactory {
+
+ private[spark] def simpleWritableFactory[T: ClassTag, W <: Writable : ClassTag](convert: T => W)
+ : WritableFactory[T] = {
+ val writableClass = implicitly[ClassTag[W]].runtimeClass.asInstanceOf[Class[W]]
+ new WritableFactory[T](_ => writableClass, convert)
+ }
+
+ implicit def intWritableFactory: WritableFactory[Int] =
+ simpleWritableFactory(new IntWritable(_))
+
+ implicit def longWritableFactory: WritableFactory[Long] =
+ simpleWritableFactory(new LongWritable(_))
+
+ implicit def floatWritableFactory: WritableFactory[Float] =
+ simpleWritableFactory(new FloatWritable(_))
+
+ implicit def doubleWritableFactory: WritableFactory[Double] =
+ simpleWritableFactory(new DoubleWritable(_))
+
+ implicit def booleanWritableFactory: WritableFactory[Boolean] =
+ simpleWritableFactory(new BooleanWritable(_))
+
+ implicit def bytesWritableFactory: WritableFactory[Array[Byte]] =
+ simpleWritableFactory(new BytesWritable(_))
+
+ implicit def stringWritableFactory: WritableFactory[String] =
+ simpleWritableFactory(new Text(_))
+
+ implicit def writableWritableFactory[T <: Writable: ClassTag]: WritableFactory[T] =
+ simpleWritableFactory(w => w)
+
+}
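A sketch of the write-side counterpart: the implicit factories above let saveAsSequenceFile convert plain Scala key/value types to Writables (output path is hypothetical, and the pair-RDD implicits from `object RDD` are assumed to be in scope):

    sc.parallelize(Seq((1, "one"), (2, "two"))).saveAsSequenceFile("hdfs://a-hdfs-path/out")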
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index aba713cb4267a..f25db7f8de565 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -32,6 +32,7 @@ import org.apache.spark.api.python.PythonWorkerFactory
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.BlockTransferService
+import org.apache.spark.network.netty.NettyBlockTransferService
import org.apache.spark.network.nio.NioBlockTransferService
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.Serializer
@@ -68,6 +69,7 @@ class SparkEnv (
val shuffleMemoryManager: ShuffleMemoryManager,
val conf: SparkConf) extends Logging {
+ private[spark] var isStopped = false
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
// A general, soft-reference map for metadata needed during HadoopRDD split computation
@@ -75,6 +77,7 @@ class SparkEnv (
private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
private[spark] def stop() {
+ isStopped = true
pythonWorkers.foreach { case(key, worker) => worker.stop() }
Option(httpFileServer).foreach(_.stop())
mapOutputTracker.stop()
@@ -142,14 +145,64 @@ object SparkEnv extends Logging {
env
}
- private[spark] def create(
+ /**
+ * Create a SparkEnv for the driver.
+ */
+ private[spark] def createDriverEnv(
+ conf: SparkConf,
+ isLocal: Boolean,
+ listenerBus: LiveListenerBus): SparkEnv = {
+ assert(conf.contains("spark.driver.host"), "spark.driver.host is not set on the driver!")
+ assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!")
+ val hostname = conf.get("spark.driver.host")
+ val port = conf.get("spark.driver.port").toInt
+ create(
+ conf,
+ SparkContext.DRIVER_IDENTIFIER,
+ hostname,
+ port,
+ isDriver = true,
+ isLocal = isLocal,
+ listenerBus = listenerBus
+ )
+ }
+
+ /**
+ * Create a SparkEnv for an executor.
+ * In coarse-grained mode, the executor provides an actor system that is already instantiated.
+ */
+ private[spark] def createExecutorEnv(
+ conf: SparkConf,
+ executorId: String,
+ hostname: String,
+ port: Int,
+ numCores: Int,
+ isLocal: Boolean): SparkEnv = {
+ val env = create(
+ conf,
+ executorId,
+ hostname,
+ port,
+ isDriver = false,
+ isLocal = isLocal,
+ numUsableCores = numCores
+ )
+ SparkEnv.set(env)
+ env
+ }
+
+ /**
+ * Helper method to create a SparkEnv for a driver or an executor.
+ */
+ private def create(
conf: SparkConf,
executorId: String,
hostname: String,
port: Int,
isDriver: Boolean,
isLocal: Boolean,
- listenerBus: LiveListenerBus = null): SparkEnv = {
+ listenerBus: LiveListenerBus = null,
+ numUsableCores: Int = 0): SparkEnv = {
// Listener bus is only used on the driver
if (isDriver) {
@@ -157,14 +210,18 @@ object SparkEnv extends Logging {
}
val securityManager = new SecurityManager(conf)
- val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
- actorSystemName, hostname, port, conf, securityManager)
+
+ // Create the ActorSystem for Akka and get the port it binds to.
+ val (actorSystem, boundPort) = {
+ val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName
+ AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager)
+ }
// Figure out which port Akka actually bound to in case the original port is 0 or occupied.
- // This is so that we tell the executors the correct port to connect to.
if (isDriver) {
conf.set("spark.driver.port", boundPort.toString)
+ } else {
+ conf.set("spark.executor.port", boundPort.toString)
}
// Create an instance of the class with the given name, possibly initializing it with our conf
@@ -231,14 +288,22 @@ object SparkEnv extends Logging {
val shuffleMemoryManager = new ShuffleMemoryManager(conf)
- val blockTransferService = new NioBlockTransferService(conf, securityManager)
+ val blockTransferService =
+ conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match {
+ case "netty" =>
+ new NettyBlockTransferService(conf, securityManager, numUsableCores)
+ case "nio" =>
+ new NioBlockTransferService(conf, securityManager)
+ }
val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
"BlockManagerMaster",
new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver)
+ // NB: blockManager is not valid until initialize() is called later.
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
- serializer, conf, mapOutputTracker, shuffleManager, blockTransferService)
+ serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager,
+ numUsableCores)
val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
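A sketch of opting back into the NIO transfer service instead of the Netty default chosen above:

    val conf = new SparkConf().set("spark.shuffle.blockTransferService", "nio")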
@@ -247,7 +312,7 @@ object SparkEnv extends Logging {
val httpFileServer =
if (isDriver) {
val fileServerPort = conf.getInt("spark.fileserver.port", 0)
- val server = new HttpFileServer(securityManager, fileServerPort)
+ val server = new HttpFileServer(conf, securityManager, fileServerPort)
server.initialize()
conf.set("spark.fileserver.uri", server.serverUri)
server
@@ -261,6 +326,10 @@ object SparkEnv extends Logging {
// Then we can start the metrics system.
MetricsSystem.createMetricsSystem("driver", conf, securityManager)
} else {
+ // We need to set the executor ID before the MetricsSystem is created because sources and
+ // sinks specified in the metrics configuration file will want to incorporate this executor's
+ // ID into the metrics they report.
+ conf.set("spark.executor.id", executorId)
val ms = MetricsSystem.createMetricsSystem("executor", conf, securityManager)
ms.start()
ms
@@ -270,7 +339,7 @@ object SparkEnv extends Logging {
// this is a temporary directory; in distributed mode, this is the executor's current working
// directory.
val sparkFilesDir: String = if (isDriver) {
- Utils.createTempDir().getAbsolutePath
+ Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath
} else {
"."
}
@@ -330,7 +399,7 @@ object SparkEnv extends Logging {
val sparkProperties = (conf.getAll ++ schedulerMode).sorted
// System properties that are not java classpaths
- val systemProperties = System.getProperties.iterator.toSeq
+ val systemProperties = Utils.getSystemProperties.toSeq
val otherProperties = systemProperties.filter { case (k, _) =>
k != "java.class.path" && !k.startsWith("spark.")
}.sorted
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index 376e69cd997d5..40237596570de 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -26,6 +26,7 @@ import org.apache.hadoop.mapred._
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
+import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.rdd.HadoopRDD
/**
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
deleted file mode 100644
index 65003b6ac6a0a..0000000000000
--- a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
+++ /dev/null
@@ -1,146 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import java.io.IOException
-import javax.security.auth.callback.Callback
-import javax.security.auth.callback.CallbackHandler
-import javax.security.auth.callback.NameCallback
-import javax.security.auth.callback.PasswordCallback
-import javax.security.auth.callback.UnsupportedCallbackException
-import javax.security.sasl.RealmCallback
-import javax.security.sasl.RealmChoiceCallback
-import javax.security.sasl.Sasl
-import javax.security.sasl.SaslClient
-import javax.security.sasl.SaslException
-
-import scala.collection.JavaConversions.mapAsJavaMap
-
-/**
- * Implements SASL Client logic for Spark
- */
-private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logging {
-
- /**
- * Used to respond to server's counterpart, SaslServer with SASL tokens
- * represented as byte arrays.
- *
- * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
- * configurable in the future.
- */
- private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST),
- null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
- new SparkSaslClientCallbackHandler(securityMgr))
-
- /**
- * Used to initiate SASL handshake with server.
- * @return response to challenge if needed
- */
- def firstToken(): Array[Byte] = {
- synchronized {
- val saslToken: Array[Byte] =
- if (saslClient != null && saslClient.hasInitialResponse()) {
- logDebug("has initial response")
- saslClient.evaluateChallenge(new Array[Byte](0))
- } else {
- new Array[Byte](0)
- }
- saslToken
- }
- }
-
- /**
- * Determines whether the authentication exchange has completed.
- * @return true is complete, otherwise false
- */
- def isComplete(): Boolean = {
- synchronized {
- if (saslClient != null) saslClient.isComplete() else false
- }
- }
-
- /**
- * Respond to server's SASL token.
- * @param saslTokenMessage contains server's SASL token
- * @return client's response SASL token
- */
- def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = {
- synchronized {
- if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0)
- }
- }
-
- /**
- * Disposes of any system resources or security-sensitive information the
- * SaslClient might be using.
- */
- def dispose() {
- synchronized {
- if (saslClient != null) {
- try {
- saslClient.dispose()
- } catch {
- case e: SaslException => // ignored
- } finally {
- saslClient = null
- }
- }
- }
- }
-
- /**
- * Implementation of javax.security.auth.callback.CallbackHandler
- * that works with share secrets.
- */
- private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends
- CallbackHandler {
-
- private val userName: String =
- SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes("utf-8"))
- private val secretKey = securityMgr.getSecretKey()
- private val userPassword: Array[Char] = SparkSaslServer.encodePassword(
- if (secretKey != null) secretKey.getBytes("utf-8") else "".getBytes("utf-8"))
-
- /**
- * Implementation used to respond to SASL request from the server.
- *
- * @param callbacks objects that indicate what credential information the
- * server's SaslServer requires from the client.
- */
- override def handle(callbacks: Array[Callback]) {
- logDebug("in the sasl client callback handler")
- callbacks foreach {
- case nc: NameCallback => {
- logDebug("handle: SASL client callback: setting username: " + userName)
- nc.setName(userName)
- }
- case pc: PasswordCallback => {
- logDebug("handle: SASL client callback: setting userPassword")
- pc.setPassword(userPassword)
- }
- case rc: RealmCallback => {
- logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText())
- rc.setText(rc.getDefaultText())
- }
- case cb: RealmChoiceCallback => {}
- case cb: Callback => throw
- new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback")
- }
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
deleted file mode 100644
index f6b0a9132aca4..0000000000000
--- a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
+++ /dev/null
@@ -1,174 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import javax.security.auth.callback.Callback
-import javax.security.auth.callback.CallbackHandler
-import javax.security.auth.callback.NameCallback
-import javax.security.auth.callback.PasswordCallback
-import javax.security.auth.callback.UnsupportedCallbackException
-import javax.security.sasl.AuthorizeCallback
-import javax.security.sasl.RealmCallback
-import javax.security.sasl.Sasl
-import javax.security.sasl.SaslException
-import javax.security.sasl.SaslServer
-import scala.collection.JavaConversions.mapAsJavaMap
-import org.apache.commons.net.util.Base64
-
-/**
- * Encapsulates SASL server logic
- */
-private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging {
-
- /**
- * Actual SASL work done by this object from javax.security.sasl.
- */
- private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null,
- SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
- new SparkSaslDigestCallbackHandler(securityMgr))
-
- /**
- * Determines whether the authentication exchange has completed.
- * @return true is complete, otherwise false
- */
- def isComplete(): Boolean = {
- synchronized {
- if (saslServer != null) saslServer.isComplete() else false
- }
- }
-
- /**
- * Used to respond to server SASL tokens.
- * @param token Server's SASL token
- * @return response to send back to the server.
- */
- def response(token: Array[Byte]): Array[Byte] = {
- synchronized {
- if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0)
- }
- }
-
- /**
- * Disposes of any system resources or security-sensitive information the
- * SaslServer might be using.
- */
- def dispose() {
- synchronized {
- if (saslServer != null) {
- try {
- saslServer.dispose()
- } catch {
- case e: SaslException => // ignore
- } finally {
- saslServer = null
- }
- }
- }
- }
-
- /**
- * Implementation of javax.security.auth.callback.CallbackHandler
- * for SASL DIGEST-MD5 mechanism
- */
- private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager)
- extends CallbackHandler {
-
- private val userName: String =
- SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes("utf-8"))
-
- override def handle(callbacks: Array[Callback]) {
- logDebug("In the sasl server callback handler")
- callbacks foreach {
- case nc: NameCallback => {
- logDebug("handle: SASL server callback: setting username")
- nc.setName(userName)
- }
- case pc: PasswordCallback => {
- logDebug("handle: SASL server callback: setting userPassword")
- val password: Array[Char] =
- SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes("utf-8"))
- pc.setPassword(password)
- }
- case rc: RealmCallback => {
- logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText())
- rc.setText(rc.getDefaultText())
- }
- case ac: AuthorizeCallback => {
- val authid = ac.getAuthenticationID()
- val authzid = ac.getAuthorizationID()
- if (authid.equals(authzid)) {
- logDebug("set auth to true")
- ac.setAuthorized(true)
- } else {
- logDebug("set auth to false")
- ac.setAuthorized(false)
- }
- if (ac.isAuthorized()) {
- logDebug("sasl server is authorized")
- ac.setAuthorizedID(authzid)
- }
- }
- case cb: Callback => throw
- new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback")
- }
- }
- }
-}
-
-private[spark] object SparkSaslServer {
-
- /**
- * This is passed as the server name when creating the sasl client/server.
- * This could be changed to be configurable in the future.
- */
- val SASL_DEFAULT_REALM = "default"
-
- /**
- * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
- * configurable in the future.
- */
- val DIGEST = "DIGEST-MD5"
-
- /**
- * The quality of protection is just "auth". This means that we are doing
- * authentication only, we are not supporting integrity or privacy protection of the
- * communication channel after authentication. This could be changed to be configurable
- * in the future.
- */
- val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH ->"true")
-
- /**
- * Encode a byte[] identifier as a Base64-encoded string.
- *
- * @param identifier identifier to encode
- * @return Base64-encoded string
- */
- def encodeIdentifier(identifier: Array[Byte]): String = {
- new String(Base64.encodeBase64(identifier), "utf-8")
- }
-
- /**
- * Encode a password as a base64-encoded char[] array.
- * @param password as a byte array.
- * @return password as a char array.
- */
- def encodePassword(password: Array[Byte]): Array[Char] = {
- new String(Base64.encodeBase64(password), "utf-8").toCharArray()
- }
-}
-
diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala
new file mode 100644
index 0000000000000..edbdda8a0bcb6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * Low-level status reporting APIs for monitoring job and stage progress.
+ *
+ * These APIs intentionally provide very weak consistency semantics; consumers of these APIs should
+ * be prepared to handle empty / missing information. For example, a job's stage ids may be known
+ * but the status API may not have any information about the details of those stages, so
+ * `getStageInfo` could potentially return `None` for a valid stage id.
+ *
+ * To limit memory usage, these APIs only provide information on recent jobs / stages. These APIs
+ * will provide information for the last `spark.ui.retainedStages` stages and
+ * `spark.ui.retainedJobs` jobs.
+ *
+ * NOTE: this class's constructor should be considered private and may be subject to change.
+ */
+class SparkStatusTracker private[spark] (sc: SparkContext) {
+
+ private val jobProgressListener = sc.jobProgressListener
+
+ /**
+ * Return a list of all known jobs in a particular job group. If `jobGroup` is `null`, then
+ * returns all known jobs that are not associated with a job group.
+ *
+ * The returned list may contain running, failed, and completed jobs, and may vary across
+ * invocations of this method. This method does not guarantee the order of the elements in
+ * its result.
+ */
+ def getJobIdsForGroup(jobGroup: String): Array[Int] = {
+ jobProgressListener.synchronized {
+ val jobData = jobProgressListener.jobIdToData.valuesIterator
+ jobData.filter(_.jobGroup.orNull == jobGroup).map(_.jobId).toArray
+ }
+ }
+
+ /**
+ * Returns an array containing the ids of all active stages.
+ *
+ * This method does not guarantee the order of the elements in its result.
+ */
+ def getActiveStageIds(): Array[Int] = {
+ jobProgressListener.synchronized {
+ jobProgressListener.activeStages.values.map(_.stageId).toArray
+ }
+ }
+
+ /**
+ * Returns an array containing the ids of all active jobs.
+ *
+ * This method does not guarantee the order of the elements in its result.
+ */
+ def getActiveJobIds(): Array[Int] = {
+ jobProgressListener.synchronized {
+ jobProgressListener.activeJobs.values.map(_.jobId).toArray
+ }
+ }
+
+ /**
+ * Returns job information, or `None` if the job info could not be found or was garbage collected.
+ */
+ def getJobInfo(jobId: Int): Option[SparkJobInfo] = {
+ jobProgressListener.synchronized {
+ jobProgressListener.jobIdToData.get(jobId).map { data =>
+ new SparkJobInfoImpl(jobId, data.stageIds.toArray, data.status)
+ }
+ }
+ }
+
+ /**
+ * Returns stage information, or `None` if the stage info could not be found or was
+ * garbage collected.
+ */
+ def getStageInfo(stageId: Int): Option[SparkStageInfo] = {
+ jobProgressListener.synchronized {
+ for (
+ info <- jobProgressListener.stageIdToInfo.get(stageId);
+ data <- jobProgressListener.stageIdToData.get((stageId, info.attemptId))
+ ) yield {
+ new SparkStageInfoImpl(
+ stageId,
+ info.attemptId,
+ info.submissionTime.getOrElse(0),
+ info.name,
+ info.numTasks,
+ data.numActiveTasks,
+ data.numCompleteTasks,
+ data.numFailedTasks)
+ }
+ }
+ }
+}
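A hedged usage sketch of the tracker defined above, assuming SparkContext exposes it as `sc.statusTracker` (the Java wrapper later in this patch constructs its tracker from a SparkContext in the same way); the job being polled and the app name are placeholders:

{{{
import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global

import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(
  new SparkConf().setAppName("status-demo").setMaster("local[2]"))
val tracker = sc.statusTracker

// Kick off a job asynchronously so there is something to observe.
Future { sc.parallelize(1 to 1000000, 8).map(_ * 2).count() }

// Weak consistency: every call below may legitimately return an empty array or None.
for (jobId <- tracker.getActiveJobIds(); job <- tracker.getJobInfo(jobId)) {
  println(s"job $jobId: ${job.status}")
  for (stageId <- job.stageIds; stage <- tracker.getStageInfo(stageId)) {
    println(s"  stage $stageId: ${stage.numCompletedTasks}/${stage.numTasks} tasks complete")
  }
}
sc.stop()
}}}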
diff --git a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala
new file mode 100644
index 0000000000000..e5c7c8d0db578
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+private class SparkJobInfoImpl (
+ val jobId: Int,
+ val stageIds: Array[Int],
+ val status: JobExecutionStatus)
+ extends SparkJobInfo
+
+private class SparkStageInfoImpl(
+ val stageId: Int,
+ val currentAttemptId: Int,
+ val submissionTime: Long,
+ val name: String,
+ val numTasks: Int,
+ val numActiveTasks: Int,
+ val numCompletedTasks: Int,
+ val numFailedTasks: Int)
+ extends SparkStageInfo
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
new file mode 100644
index 0000000000000..7d7fe1a446313
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io.Serializable
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.util.TaskCompletionListener
+
+
+object TaskContext {
+ /**
+ * Return the currently active TaskContext. This can be called inside of
+ * user functions to access contextual information about running tasks.
+ */
+ def get(): TaskContext = taskContext.get
+
+ private val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext]
+
+ // Note: protected[spark] instead of private[spark] to prevent the following two from
+ // showing up in JavaDoc.
+ /**
+ * Set the thread local TaskContext. Internal to Spark.
+ */
+ protected[spark] def setTaskContext(tc: TaskContext): Unit = taskContext.set(tc)
+
+ /**
+ * Unset the thread local TaskContext. Internal to Spark.
+ */
+ protected[spark] def unset(): Unit = taskContext.remove()
+}
+
+
+/**
+ * Contextual information about a task which can be read or mutated during
+ * execution. To access the TaskContext for a running task, use:
+ * {{{
+ * org.apache.spark.TaskContext.get()
+ * }}}
+ */
+abstract class TaskContext extends Serializable {
+ // Note: TaskContext must NOT define a get method. Otherwise it will prevent the Scala compiler
+ // from generating a static get method (based on the companion object's get method).
+
+ // Note: Update JavaTaskContextCompileCheck when new methods are added to this class.
+
+ // Note: getters in this class are defined with parentheses to maintain backward compatibility.
+
+ /**
+ * Returns true if the task has completed.
+ */
+ def isCompleted(): Boolean
+
+ /**
+ * Returns true if the task has been killed.
+ */
+ def isInterrupted(): Boolean
+
+ @deprecated("use isRunningLocally", "1.2.0")
+ def runningLocally(): Boolean
+
+ /**
+ * Returns true if the task is running locally in the driver program.
+ * @return
+ */
+ def isRunningLocally(): Boolean
+
+ /**
+ * Adds a (Java friendly) listener to be executed on task completion.
+ * This will be called in all situations - success, failure, or cancellation.
+ * An example use is for HadoopRDD to register a callback to close the input stream.
+ */
+ def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext
+
+ /**
+ * Adds a listener in the form of a Scala closure to be executed on task completion.
+ * This will be called in all situations - success, failure, or cancellation.
+ * An example use is for HadoopRDD to register a callback to close the input stream.
+ */
+ def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext
+
+ /**
+ * Adds a callback function to be executed on task completion. An example use
+ * is for HadoopRDD to register a callback to close the input stream.
+ * Will be called in any situation - success, failure, or cancellation.
+ *
+ * @param f Callback function.
+ */
+ @deprecated("use addTaskCompletionListener", "1.2.0")
+ def addOnCompleteCallback(f: () => Unit)
+
+ /**
+ * The ID of the stage that this task belongs to.
+ */
+ def stageId(): Int
+
+ /**
+ * The ID of the RDD partition that is computed by this task.
+ */
+ def partitionId(): Int
+
+ /**
+ * How many times this task has been attempted. The first task attempt will be assigned
+ * attemptNumber = 0, and subsequent attempts will have increasing attempt numbers.
+ */
+ def attemptNumber(): Int
+
+ @deprecated("use attemptNumber", "1.3.0")
+ def attemptId(): Long
+
+ /**
+ * An ID that is unique to this task attempt (within the same SparkContext, no two task attempts
+ * will share the same attempt ID). This is roughly equivalent to Hadoop's TaskAttemptID.
+ */
+ def taskAttemptId(): Long
+
+ /** ::DeveloperApi:: */
+ @DeveloperApi
+ def taskMetrics(): TaskMetrics
+}
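A small sketch of the thread-local accessor and completion-listener hook defined above; the data and app name are placeholders:

{{{
import org.apache.spark.{SparkConf, SparkContext, TaskContext}

val sc = new SparkContext(
  new SparkConf().setAppName("taskcontext-demo").setMaster("local[2]"))

val tagged = sc.parallelize(1 to 10, 2).mapPartitions { iter =>
  val ctx = TaskContext.get()               // thread-local; only meaningful inside a task
  ctx.addTaskCompletionListener { (c: TaskContext) =>
    // Runs on success, failure, or cancellation; the hook HadoopRDD uses to close streams.
    println(s"stage ${c.stageId()} / partition ${c.partitionId()} finished")
  }
  iter.map(x => (ctx.partitionId(), ctx.attemptNumber(), x))
}
tagged.collect().foreach(println)
sc.stop()
}}}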
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index afd2b85d33a77..337c8e4ebebcd 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -22,14 +22,19 @@ import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerExce
import scala.collection.mutable.ArrayBuffer
-private[spark] class TaskContextImpl(val stageId: Int,
+private[spark] class TaskContextImpl(
+ val stageId: Int,
val partitionId: Int,
- val attemptId: Long,
+ override val taskAttemptId: Long,
+ override val attemptNumber: Int,
val runningLocally: Boolean = false,
val taskMetrics: TaskMetrics = TaskMetrics.empty)
extends TaskContext
with Logging {
+ // For backwards-compatibility; this method is now deprecated as of 1.3.0.
+ override def attemptId(): Long = taskAttemptId
+
// List of callback functions to execute when the task completes.
@transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
@@ -82,10 +87,10 @@ private[spark] class TaskContextImpl(val stageId: Int,
interrupted = true
}
- override def isCompleted: Boolean = completed
+ override def isCompleted(): Boolean = completed
- override def isRunningLocally: Boolean = runningLocally
+ override def isRunningLocally(): Boolean = runningLocally
- override def isInterrupted: Boolean = interrupted
+ override def isInterrupted(): Boolean = interrupted
}
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 8f0c5e78416c2..af5fd8e0ac00c 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -69,11 +69,13 @@ case class FetchFailed(
bmAddress: BlockManagerId, // Note that bmAddress can be null
shuffleId: Int,
mapId: Int,
- reduceId: Int)
+ reduceId: Int,
+ message: String)
extends TaskFailedReason {
override def toErrorString: String = {
val bmAddressString = if (bmAddress == null) "null" else bmAddress.toString
- s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId)"
+ s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId, " +
+ s"message=\n$message\n)"
}
}
@@ -81,15 +83,48 @@ case class FetchFailed(
* :: DeveloperApi ::
* Task failed due to a runtime exception. This is the most common failure case and also captures
* user program exceptions.
+ *
+ * `stackTrace` contains the stack trace of the exception itself. It still exists for backward
+ * compatibility. It's better to use `this(e: Throwable, metrics: Option[TaskMetrics])` to
+ * create `ExceptionFailure` as it will handle the backward compatibility properly.
+ *
+ * `fullStackTrace` is a better representation of the stack trace because it contains the whole
+ * stack trace including the exception and its causes.
*/
@DeveloperApi
case class ExceptionFailure(
className: String,
description: String,
stackTrace: Array[StackTraceElement],
+ fullStackTrace: String,
metrics: Option[TaskMetrics])
extends TaskFailedReason {
- override def toErrorString: String = Utils.exceptionString(className, description, stackTrace)
+
+ private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) {
+ this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics)
+ }
+
+ override def toErrorString: String =
+ if (fullStackTrace == null) {
+ // fullStackTrace is added in 1.2.0
+ // If fullStackTrace is null, use the old error string for backward compatibility
+ exceptionString(className, description, stackTrace)
+ } else {
+ fullStackTrace
+ }
+
+ /**
+ * Return a nice string representation of the exception, including the stack trace.
+ * Note: It does not include the exception's causes, and is only used for backward compatibility.
+ */
+ private def exceptionString(
+ className: String,
+ description: String,
+ stackTrace: Array[StackTraceElement]): String = {
+ val desc = if (description == null) "" else description
+ val st = if (stackTrace == null) "" else stackTrace.map(" " + _).mkString("\n")
+ s"$className: $desc\n$st"
+ }
}
/**
@@ -117,8 +152,8 @@ case object TaskKilled extends TaskFailedReason {
* the task crashed the JVM.
*/
@DeveloperApi
-case object ExecutorLostFailure extends TaskFailedReason {
- override def toErrorString: String = "ExecutorLostFailure (executor lost)"
+case class ExecutorLostFailure(execId: String) extends TaskFailedReason {
+ override def toErrorString: String = s"ExecutorLostFailure (executor ${execId} lost)"
}
/**
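To make the backward-compatibility behaviour above concrete, a sketch that builds the failure record both ways; the trace string is hand-rolled here because the patch's own helper is `private[spark]`, and the exceptions are placeholders:

{{{
import org.apache.spark.ExceptionFailure

val cause = new IllegalStateException("disk full")          // placeholder
val e = new RuntimeException("task failed", cause)

// A 1.2+ record carries the full trace, causes included.
val full = ExceptionFailure(e.getClass.getName, e.getMessage, e.getStackTrace,
  fullStackTrace = e.toString + "\n" + e.getStackTrace.map("  " + _).mkString("\n"),
  metrics = None)

// A record written by an older version deserializes with fullStackTrace == null,
// so toErrorString falls back to the short, cause-free form.
val legacy = ExceptionFailure(e.getClass.getName, e.getMessage, e.getStackTrace,
  fullStackTrace = null, metrics = None)

println(full.toErrorString)
println(legacy.toErrorString)
}}}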
diff --git a/core/src/main/scala/org/apache/spark/TaskNotSerializableException.scala b/core/src/main/scala/org/apache/spark/TaskNotSerializableException.scala
new file mode 100644
index 0000000000000..9df61062e1f85
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/TaskNotSerializableException.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * Exception thrown when a task cannot be serialized.
+ */
+private[spark] class TaskNotSerializableException(error: Throwable) extends Exception(error)
diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index e72826dc25f41..34078142f5385 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -23,8 +23,8 @@ import java.util.jar.{JarEntry, JarOutputStream}
import scala.collection.JavaConversions._
+import com.google.common.io.{ByteStreams, Files}
import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}
-import com.google.common.io.Files
import org.apache.spark.util.Utils
@@ -64,12 +64,7 @@ private[spark] object TestUtils {
jarStream.putNextEntry(jarEntry)
val in = new FileInputStream(file)
- val buffer = new Array[Byte](10240)
- var nRead = 0
- while (nRead <= 0) {
- nRead = in.read(buffer, 0, buffer.length)
- jarStream.write(buffer, 0, nRead)
- }
+ ByteStreams.copy(in, jarStream)
in.close()
}
jarStream.close()
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index a6123bd108c11..8e8f7f6c4fda2 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -114,7 +114,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja
* Return an RDD with the elements from `this` that are not in `other`.
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
- * RDD will be <= us.
+ * RDD will be &lt;= us.
*/
def subtract(other: JavaDoubleRDD): JavaDoubleRDD =
fromRDD(srdd.subtract(other))
@@ -233,11 +233,11 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja
* to the left except for the last which is closed
* e.g. for the array
* [1,10,20,50] the buckets are [1,10) [10,20) [20,50]
- * e.g 1<=x<10 , 10<=x<20, 20<=x<50
+ * e.g 1&lt;=x&lt;10 , 10&lt;=x&lt;20, 20&lt;=x&lt;50
* And on the input of 1 and 50 we would have a histogram of 1,0,0
*
* Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched
- * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets
+ * from an O(log n) insertion to O(1) per element. (where n = # buckets) if you set evenBuckets
* to true.
* buckets must be sorted and not contain any duplicates.
* buckets array must be at least two elements
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index c38b96528d037..7af3538262fd6 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -32,13 +32,13 @@ import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
import org.apache.spark.{HashPartitioner, Partitioner}
import org.apache.spark.Partitioner._
-import org.apache.spark.SparkContext.rddToPairRDDFunctions
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap
import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, PairFunction}
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.rdd.{OrderedRDDFunctions, RDD}
+import org.apache.spark.rdd.RDD.rddToPairRDDFunctions
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -392,7 +392,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* Return an RDD with the elements from `this` that are not in `other`.
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
- * RDD will be <= us.
+ * RDD will be &lt;= us.
*/
def subtract(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] =
fromRDD(rdd.subtract(other))
@@ -413,7 +413,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* Return an RDD with the pairs from `this` whose keys are not in `other`.
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
- * RDD will be <= us.
+ * RDD will be &lt;= us.
*/
def subtractByKey[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, V] = {
implicit val ctag: ClassTag[W] = fakeClassTag
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index efb8978f7ce12..0f91c942ecd50 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -28,7 +28,6 @@ import com.google.common.base.Optional
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.spark._
-import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaPairRDD._
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
@@ -39,6 +38,10 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
+/**
+ * Defines operations common to several Java RDD implementations.
+ * Note that this trait is not intended to be implemented by user code.
+ */
trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def wrapRDD(rdd: RDD[T]): This
@@ -212,8 +215,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
- def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JIterable[T]] = {
- implicit val ctagK: ClassTag[K] = fakeClassTag
+ def groupBy[U](f: JFunction[T, U]): JavaPairRDD[U, JIterable[T]] = {
+ // The type parameter is U instead of K in order to work around a compiler bug; see SPARK-4459
+ implicit val ctagK: ClassTag[U] = fakeClassTag
implicit val ctagV: ClassTag[JList[T]] = fakeClassTag
JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(fakeClassTag)))
}
@@ -222,10 +226,11 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
- def groupBy[K](f: JFunction[T, K], numPartitions: Int): JavaPairRDD[K, JIterable[T]] = {
- implicit val ctagK: ClassTag[K] = fakeClassTag
+ def groupBy[U](f: JFunction[T, U], numPartitions: Int): JavaPairRDD[U, JIterable[T]] = {
+ // The type parameter is U instead of K in order to work around a compiler bug; see SPARK-4459
+ implicit val ctagK: ClassTag[U] = fakeClassTag
implicit val ctagV: ClassTag[JList[T]] = fakeClassTag
- JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(fakeClassTag[K])))
+ JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(fakeClassTag[U])))
}
/**
@@ -343,6 +348,19 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
*/
def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f)
+ /**
+ * Reduces the elements of this RDD in a multi-level tree pattern.
+ *
+ * @param depth suggested depth of the tree
+ * @see [[org.apache.spark.api.java.JavaRDDLike#reduce]]
+ */
+ def treeReduce(f: JFunction2[T, T, T], depth: Int): T = rdd.treeReduce(f, depth)
+
+ /**
+ * [[org.apache.spark.api.java.JavaRDDLike#treeReduce]] with suggested depth 2.
+ */
+ def treeReduce(f: JFunction2[T, T, T]): T = treeReduce(f, 2)
+
/**
* Aggregate the elements of each partition, and then the results for all the partitions, using a
* given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
@@ -364,6 +382,30 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
combOp: JFunction2[U, U, U]): U =
rdd.aggregate(zeroValue)(seqOp, combOp)(fakeClassTag[U])
+ /**
+ * Aggregates the elements of this RDD in a multi-level tree pattern.
+ *
+ * @param depth suggested depth of the tree
+ * @see [[org.apache.spark.api.java.JavaRDDLike#aggregate]]
+ */
+ def treeAggregate[U](
+ zeroValue: U,
+ seqOp: JFunction2[U, T, U],
+ combOp: JFunction2[U, U, U],
+ depth: Int): U = {
+ rdd.treeAggregate(zeroValue)(seqOp, combOp, depth)(fakeClassTag[U])
+ }
+
+ /**
+ * [[org.apache.spark.api.java.JavaRDDLike#treeAggregate]] with suggested depth 2.
+ */
+ def treeAggregate[U](
+ zeroValue: U,
+ seqOp: JFunction2[U, T, U],
+ combOp: JFunction2[U, U, U]): U = {
+ treeAggregate(zeroValue, seqOp, combOp, 2)
+ }
+
/**
* Return the number of elements in the RDD.
*/
@@ -434,6 +476,12 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
*/
def first(): T = rdd.first()
+ /**
+ * @return true if and only if the RDD contains no elements at all. Note that an RDD
+ * may be empty even when it has at least 1 partition.
+ */
+ def isEmpty(): Boolean = rdd.isEmpty()
+
/**
* Save this RDD as a text file, using string representations of elements.
*/
@@ -459,8 +507,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
/**
* Creates tuples of the elements in this RDD by applying `f`.
*/
- def keyBy[K](f: JFunction[T, K]): JavaPairRDD[K, T] = {
- implicit val ctag: ClassTag[K] = fakeClassTag
+ def keyBy[U](f: JFunction[T, U]): JavaPairRDD[U, T] = {
+ // The type parameter is U instead of K in order to work around a compiler bug; see SPARK-4459
+ implicit val ctag: ClassTag[U] = fakeClassTag
JavaPairRDD.fromRDD(rdd.keyBy(f))
}
@@ -493,9 +542,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
- * Returns the top K elements from this RDD as defined by
+ * Returns the top k (largest) elements from this RDD as defined by
* the specified Comparator[T].
- * @param num the number of top elements to return
+ * @param num k, the number of top elements to return
* @param comp the comparator that defines the order
* @return an array of top elements
*/
@@ -507,9 +556,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
- * Returns the top K elements from this RDD using the
+ * Returns the top k (largest) elements from this RDD using the
* natural ordering for T.
- * @param num the number of top elements to return
+ * @param num k, the number of top elements to return
* @return an array of top elements
*/
def top(num: Int): JList[T] = {
@@ -518,9 +567,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
- * Returns the first K elements from this RDD as defined by
+ * Returns the first k (smallest) elements from this RDD as defined by
* the specified Comparator[T] and maintains the order.
- * @param num the number of top elements to return
+ * @param num k, the number of elements to return
* @param comp the comparator that defines the order
* @return an array of top elements
*/
@@ -552,9 +601,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
- * Returns the first K elements from this RDD using the
+ * Returns the first k (smallest) elements from this RDD using the
* natural ordering for T while maintain the order.
- * @param num the number of top elements to return
+ * @param num k, the number of top elements to return
* @return an array of top elements
*/
def takeOrdered(num: Int): JList[T] = {
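The new `treeReduce` / `treeAggregate` wrappers above delegate to the underlying RDD methods; a sketch of the tree-shaped combine they expose (placeholder data, local master):

{{{
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(
  new SparkConf().setAppName("tree-agg-demo").setMaster("local[4]"))
val data = sc.parallelize(1L to 1000000L, 64)

// Same result as reduce(_ + _), but partial results are merged in rounds on the
// executors, so the driver only sees a handful of partially combined values.
val sum = data.treeReduce(_ + _, depth = 2)

// treeAggregate keeps the (sum, count) pair small at every level of the tree.
val (total, count) = data.treeAggregate((0L, 0L))(
  seqOp  = { case ((s, n), x) => (s + x, n + 1) },
  combOp = { case ((s1, n1), (s2, n2)) => (s1 + s2, n1 + n2) },
  depth  = 3)
println(s"sum=$sum mean=${total.toDouble / count}")
sc.stop()
}}}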
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index 791d853a015a1..97f5c9f257e09 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -28,11 +28,13 @@ import scala.reflect.ClassTag
import com.google.common.base.Optional
import org.apache.hadoop.conf.Configuration
+import org.apache.spark.input.PortableDataStream
import org.apache.hadoop.mapred.{InputFormat, JobConf}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.spark._
-import org.apache.spark.SparkContext.{DoubleAccumulatorParam, IntAccumulatorParam}
+import org.apache.spark.AccumulatorParam._
+import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD}
@@ -40,6 +42,9 @@ import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD}
/**
* A Java-friendly version of [[org.apache.spark.SparkContext]] that returns
* [[org.apache.spark.api.java.JavaRDD]]s and works with Java collections instead of Scala ones.
+ *
+ * Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before
+ * creating a new one. This limitation may eventually be removed; see SPARK-2243 for more details.
*/
class JavaSparkContext(val sc: SparkContext)
extends JavaSparkContextVarargsWorkaround with Closeable {
@@ -103,6 +108,8 @@ class JavaSparkContext(val sc: SparkContext)
private[spark] val env = sc.env
+ def statusTracker = new JavaSparkStatusTracker(sc)
+
def isLocal: java.lang.Boolean = sc.isLocal
def sparkUser: String = sc.sparkUser
@@ -183,6 +190,8 @@ class JavaSparkContext(val sc: SparkContext)
def textFile(path: String, minPartitions: Int): JavaRDD[String] =
sc.textFile(path, minPartitions)
+
+
/**
* Read a directory of text files from HDFS, a local file system (available on all nodes), or any
* Hadoop-supported file system URI. Each file is read as a single record and returned in a
@@ -196,7 +205,10 @@ class JavaSparkContext(val sc: SparkContext)
* hdfs://a-hdfs-path/part-nnnnn
* }}}
*
- * Do `JavaPairRDD rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")`,
+ * Do
+ * {{{
+ * JavaPairRDD rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")
+ * }}}
*
*
then `rdd` contains
* {{{
@@ -223,6 +235,84 @@ class JavaSparkContext(val sc: SparkContext)
def wholeTextFiles(path: String): JavaPairRDD[String, String] =
new JavaPairRDD(sc.wholeTextFiles(path))
+ /**
+ * Read a directory of binary files from HDFS, a local file system (available on all nodes),
+ * or any Hadoop-supported file system URI as a byte array. Each file is read as a single
+ * record and returned in a key-value pair, where the key is the path of each file,
+ * the value is the content of each file.
+ *
+ * For example, if you have the following files:
+ * {{{
+ * hdfs://a-hdfs-path/part-00000
+ * hdfs://a-hdfs-path/part-00001
+ * ...
+ * hdfs://a-hdfs-path/part-nnnnn
+ * }}}
+ *
+ * Do
+ * `JavaPairRDD rdd = sparkContext.binaryFiles("hdfs://a-hdfs-path")`,
+ *
+ * then `rdd` contains
+ * {{{
+ * (a-hdfs-path/part-00000, its content)
+ * (a-hdfs-path/part-00001, its content)
+ * ...
+ * (a-hdfs-path/part-nnnnn, its content)
+ * }}}
+ *
+ * @note Small files are preferred; very large files may cause bad performance.
+ *
+ * @param minPartitions A suggested minimum number of partitions for the input data.
+ */
+ def binaryFiles(path: String, minPartitions: Int): JavaPairRDD[String, PortableDataStream] =
+ new JavaPairRDD(sc.binaryFiles(path, minPartitions))
+
+ /**
+ * :: Experimental ::
+ *
+ * Read a directory of binary files from HDFS, a local file system (available on all nodes),
+ * or any Hadoop-supported file system URI as a byte array. Each file is read as a single
+ * record and returned in a key-value pair, where the key is the path of each file,
+ * the value is the content of each file.
+ *
+ * For example, if you have the following files:
+ * {{{
+ * hdfs://a-hdfs-path/part-00000
+ * hdfs://a-hdfs-path/part-00001
+ * ...
+ * hdfs://a-hdfs-path/part-nnnnn
+ * }}}
+ *
+ * Do
+ * `JavaPairRDD rdd = sparkContext.binaryFiles("hdfs://a-hdfs-path")`,
+ *
+ * then `rdd` contains
+ * {{{
+ * (a-hdfs-path/part-00000, its content)
+ * (a-hdfs-path/part-00001, its content)
+ * ...
+ * (a-hdfs-path/part-nnnnn, its content)
+ * }}}
+ *
+ * @note Small files are preferred; very large files may cause bad performance.
+ */
+ @Experimental
+ def binaryFiles(path: String): JavaPairRDD[String, PortableDataStream] =
+ new JavaPairRDD(sc.binaryFiles(path, defaultMinPartitions))
+
+ /**
+ * :: Experimental ::
+ *
+ * Load data from a flat binary file, assuming the length of each record is constant.
+ *
+ * @param path Directory to the input data files
+ * @return An RDD of data with values, represented as byte arrays
+ */
+ @Experimental
+ def binaryRecords(path: String, recordLength: Int): JavaRDD[Array[Byte]] = {
+ new JavaRDD(sc.binaryRecords(path, recordLength))
+ }
+
/** Get an RDD for a Hadoop SequenceFile with given key and value types.
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
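A hedged sketch of the binary-input helpers added above; the paths and record length are placeholders, not values from this patch:

{{{
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(
  new SparkConf().setAppName("binary-input-demo").setMaster("local[2]"))

// Whole files: one (path, stream) pair per file; the stream is only materialized on demand.
val files = sc.binaryFiles("hdfs://a-hdfs-path")
val sizes = files.map { case (path, stream) => (path, stream.toArray().length) }
sizes.collect().foreach { case (path, n) => println(s"$path -> $n bytes") }

// Fixed-length records: every element is exactly recordLength bytes.
val records = sc.binaryRecords("hdfs://a-hdfs-path/records.bin", recordLength = 16)
println(s"records: ${records.count()}")
sc.stop()
}}}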
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala
new file mode 100644
index 0000000000000..3300cad9efbab
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java
+
+import org.apache.spark.{SparkStageInfo, SparkJobInfo, SparkContext}
+
+/**
+ * Low-level status reporting APIs for monitoring job and stage progress.
+ *
+ * These APIs intentionally provide very weak consistency semantics; consumers of these APIs should
+ * be prepared to handle empty / missing information. For example, a job's stage ids may be known
+ * but the status API may not have any information about the details of those stages, so
+ * `getStageInfo` could potentially return `null` for a valid stage id.
+ *
+ * To limit memory usage, these APIs only provide information on recent jobs / stages. These APIs
+ * will provide information for the last `spark.ui.retainedStages` stages and
+ * `spark.ui.retainedJobs` jobs.
+ *
+ * NOTE: this class's constructor should be considered private and may be subject to change.
+ */
+class JavaSparkStatusTracker private[spark] (sc: SparkContext) {
+
+ /**
+ * Return a list of all known jobs in a particular job group. If `jobGroup` is `null`, then
+ * returns all known jobs that are not associated with a job group.
+ *
+ * The returned list may contain running, failed, and completed jobs, and may vary across
+ * invocations of this method. This method does not guarantee the order of the elements in
+ * its result.
+ */
+ def getJobIdsForGroup(jobGroup: String): Array[Int] = sc.statusTracker.getJobIdsForGroup(jobGroup)
+
+ /**
+ * Returns an array containing the ids of all active stages.
+ *
+ * This method does not guarantee the order of the elements in its result.
+ */
+ def getActiveStageIds(): Array[Int] = sc.statusTracker.getActiveStageIds()
+
+ /**
+ * Returns an array containing the ids of all active jobs.
+ *
+ * This method does not guarantee the order of the elements in its result.
+ */
+ def getActiveJobIds(): Array[Int] = sc.statusTracker.getActiveJobIds()
+
+ /**
+ * Returns job information, or `null` if the job info could not be found or was garbage collected.
+ */
+ def getJobInfo(jobId: Int): SparkJobInfo = sc.statusTracker.getJobInfo(jobId).orNull
+
+ /**
+ * Returns stage information, or `null` if the stage info could not be found or was
+ * garbage collected.
+ */
+ def getStageInfo(stageId: Int): SparkStageInfo = sc.statusTracker.getStageInfo(stageId).orNull
+}
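The Java-facing tracker above returns `null` where the Scala tracker returns `None`; a short sketch (written in Scala against the Java API, with placeholder names) of the null checks that implies:

{{{
import org.apache.spark.SparkConf
import org.apache.spark.api.java.JavaSparkContext

val jsc = new JavaSparkContext(
  new SparkConf().setAppName("java-status-demo").setMaster("local[2]"))
val tracker = jsc.statusTracker

for (jobId <- tracker.getActiveJobIds()) {
  val info = tracker.getJobInfo(jobId)     // null, not Option, once the job is evicted
  if (info != null) {
    println(s"job $jobId: ${info.status}")
  }
}
jsc.stop()
}}}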
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
index b52d0a5028e84..71b26737b8c02 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
@@ -19,7 +19,8 @@ package org.apache.spark.api.java
import com.google.common.base.Optional
-import scala.collection.convert.Wrappers.MapWrapper
+import java.{util => ju}
+import scala.collection.mutable
private[spark] object JavaUtils {
def optionToOptional[T](option: Option[T]): Optional[T] =
@@ -32,7 +33,64 @@ private[spark] object JavaUtils {
def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]) =
new SerializableMapWrapper(underlying)
+ // Implementation is copied from scala.collection.convert.Wrappers.MapWrapper,
+ // but implements java.io.Serializable. It can't just be subclassed to make it
+ // Serializable since the MapWrapper class has no no-arg constructor. This class
+ // doesn't need a no-arg constructor though.
class SerializableMapWrapper[A, B](underlying: collection.Map[A, B])
- extends MapWrapper(underlying) with java.io.Serializable
+ extends ju.AbstractMap[A, B] with java.io.Serializable { self =>
+ override def size = underlying.size
+
+ override def get(key: AnyRef): B = try {
+ underlying get key.asInstanceOf[A] match {
+ case None => null.asInstanceOf[B]
+ case Some(v) => v
+ }
+ } catch {
+ case ex: ClassCastException => null.asInstanceOf[B]
+ }
+
+ override def entrySet: ju.Set[ju.Map.Entry[A, B]] = new ju.AbstractSet[ju.Map.Entry[A, B]] {
+ def size = self.size
+
+ def iterator = new ju.Iterator[ju.Map.Entry[A, B]] {
+ val ui = underlying.iterator
+ var prev : Option[A] = None
+
+ def hasNext = ui.hasNext
+
+ def next() = {
+ val (k, v) = ui.next
+ prev = Some(k)
+ new ju.Map.Entry[A, B] {
+ import scala.util.hashing.byteswap32
+ def getKey = k
+ def getValue = v
+ def setValue(v1 : B) = self.put(k, v1)
+ override def hashCode = byteswap32(k.hashCode) + (byteswap32(v.hashCode) << 16)
+ override def equals(other: Any) = other match {
+ case e: ju.Map.Entry[_, _] => k == e.getKey && v == e.getValue
+ case _ => false
+ }
+ }
+ }
+
+ def remove() {
+ prev match {
+ case Some(k) =>
+ underlying match {
+ case mm: mutable.Map[A, _] =>
+ mm remove k
+ prev = None
+ case _ =>
+ throw new UnsupportedOperationException("remove")
+ }
+ case _ =>
+ throw new IllegalStateException("next must be called at least once before remove")
+ }
+ }
+ }
+ }
+ }
}
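A small illustration of the serialization problem the copied-in wrapper solves; the round-trip helper is generic and not part of the patch:

{{{
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Write an object out and read it back with plain Java serialization.
def roundTrip(value: AnyRef): AnyRef = {
  val buffer = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(buffer)
  out.writeObject(value)
  out.close()
  new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray)).readObject()
}

import scala.collection.JavaConversions.mapAsJavaMap
val plain = mapAsJavaMap(Map("a" -> 1, "b" -> 2))   // Wrappers.MapWrapper: not Serializable
// roundTrip(plain)                                  // would throw NotSerializableException

// Maps handed back through the Java API go through mapAsSerializableJavaMap instead
// (see the import added to JavaPairRDD earlier in this patch), so the same round trip
// succeeds on those.
}}}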
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
index 49dc95f349eac..c9181a29d4756 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
@@ -61,8 +61,7 @@ private[python] object Converter extends Logging {
* Other objects are passed through without conversion.
*/
private[python] class WritableToJavaConverter(
- conf: Broadcast[SerializableWritable[Configuration]],
- batchSize: Int) extends Converter[Any, Any] {
+ conf: Broadcast[SerializableWritable[Configuration]]) extends Converter[Any, Any] {
/**
* Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or
@@ -94,8 +93,7 @@ private[python] class WritableToJavaConverter(
map.put(convertWritable(k), convertWritable(v))
}
map
- case w: Writable =>
- if (batchSize > 1) WritableUtils.clone(w, conf.value.value) else w
+ case w: Writable => WritableUtils.clone(w, conf.value.value)
case other => other
}
}
@@ -140,6 +138,11 @@ private[python] class JavaToWritableConverter extends Converter[Any, Writable] {
mapWritable.put(convertToWritable(k), convertToWritable(v))
}
mapWritable
+ case array: Array[Any] => {
+ val arrayWriteable = new ArrayWritable(classOf[Writable])
+ arrayWriteable.set(array.map(convertToWritable(_)))
+ arrayWriteable
+ }
case other => throw new SparkException(
s"Data of type ${other.getClass.getName} cannot be used")
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 29ca751519abd..b89effc16d36d 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -19,22 +19,21 @@ package org.apache.spark.api.python
import java.io._
import java.net._
-import java.nio.charset.Charset
-import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
+import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections}
+
+import org.apache.spark.input.PortableDataStream
import scala.collection.JavaConversions._
-import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.language.existentials
-import net.razorvine.pickle.{Pickler, Unpickler}
+import com.google.common.base.Charsets.UTF_8
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat}
import org.apache.spark._
-import org.apache.spark.SparkContext._
import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
@@ -47,7 +46,7 @@ private[spark] class PythonRDD(
pythonIncludes: JList[String],
preservePartitoning: Boolean,
pythonExec: String,
- broadcastVars: JList[Broadcast[Array[Byte]]],
+ broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: Accumulator[JList[Array[Byte]]])
extends RDD[Array[Byte]](parent) {
@@ -68,16 +67,16 @@ private[spark] class PythonRDD(
envVars += ("SPARK_REUSE_WORKER" -> "1")
}
val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
+ // Whether the worker has been released into the idle pool
+ @volatile var released = false
// Start a thread to feed the process input from our parent's iterator
val writerThread = new WriterThread(env, worker, split, context)
- var complete_cleanly = false
context.addTaskCompletionListener { context =>
writerThread.shutdownOnTaskCompletion()
- if (reuse_worker && complete_cleanly) {
- env.releasePythonWorker(pythonExec, envVars.toMap, worker)
- } else {
+ writerThread.join()
+ if (!reuse_worker || !released) {
try {
worker.close()
} catch {
@@ -125,15 +124,15 @@ private[spark] class PythonRDD(
init, finish))
val memoryBytesSpilled = stream.readLong()
val diskBytesSpilled = stream.readLong()
- context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
- context.taskMetrics.diskBytesSpilled += diskBytesSpilled
+ context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
+ context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
read()
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
// Signals that an exception has been thrown in python
val exLength = stream.readInt()
val obj = new Array[Byte](exLength)
stream.readFully(obj)
- throw new PythonException(new String(obj, "utf-8"),
+ throw new PythonException(new String(obj, UTF_8),
writerThread.exception.getOrElse(null))
case SpecialLengths.END_OF_DATA_SECTION =>
// We've finished the data section of the output, but we can still
@@ -145,7 +144,13 @@ private[spark] class PythonRDD(
stream.readFully(update)
accumulator += Collections.singletonList(update)
}
- complete_cleanly = true
+ // Check whether the worker is ready to be re-used.
+ if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
+ if (reuse_worker) {
+ env.releasePythonWorker(pythonExec, envVars.toMap, worker)
+ released = true
+ }
+ }
null
}
} catch {
@@ -154,6 +159,10 @@ private[spark] class PythonRDD(
logDebug("Exception thrown after task interruption", e)
throw new TaskKilledException
+ case e: Exception if env.isStopped =>
+ logDebug("Exception thrown after context is stopped", e)
+ null // exit silently
+
case e: Exception if writerThread.exception.isDefined =>
logError("Python worker exited unexpectedly (crashed)", e)
logError("This may have been caused by a prior exception:", writerThread.exception.get)
@@ -223,8 +232,7 @@ private[spark] class PythonRDD(
if (!oldBids.contains(broadcast.id)) {
// send new broadcast
dataOut.writeLong(broadcast.id)
- dataOut.writeInt(broadcast.value.length)
- dataOut.write(broadcast.value)
+ PythonRDD.writeUTF(broadcast.value.path, dataOut)
oldBids.add(broadcast.id)
}
}
@@ -235,6 +243,7 @@ private[spark] class PythonRDD(
// Data values
PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+ dataOut.writeInt(SpecialLengths.END_OF_STREAM)
dataOut.flush()
} catch {
case e: Exception if context.isCompleted || context.isInterrupted =>
@@ -306,10 +315,11 @@ private object SpecialLengths {
val END_OF_DATA_SECTION = -1
val PYTHON_EXCEPTION_THROWN = -2
val TIMING_DATA = -3
+ val END_OF_STREAM = -4
+ val NULL = -5
}
private[spark] object PythonRDD extends Logging {
- val UTF8 = Charset.forName("UTF-8")
// remember the broadcasts sent to each worker
private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
@@ -360,56 +370,30 @@ private[spark] object PythonRDD extends Logging {
}
}
- def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = {
- val file = new DataInputStream(new FileInputStream(filename))
- try {
- val length = file.readInt()
- val obj = new Array[Byte](length)
- file.readFully(obj)
- sc.broadcast(obj)
- } finally {
- file.close()
- }
+ def readBroadcastFromFile(sc: JavaSparkContext, path: String): Broadcast[PythonBroadcast] = {
+ sc.broadcast(new PythonBroadcast(path))
}
def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
- // The right way to implement this would be to use TypeTags to get the full
- // type of T. Since I don't want to introduce breaking changes throughout the
- // entire Spark API, I have to use this hacky approach:
- if (iter.hasNext) {
- val first = iter.next()
- val newIter = Seq(first).iterator ++ iter
- first match {
- case arr: Array[Byte] =>
- newIter.asInstanceOf[Iterator[Array[Byte]]].foreach { bytes =>
- dataOut.writeInt(bytes.length)
- dataOut.write(bytes)
- }
- case string: String =>
- newIter.asInstanceOf[Iterator[String]].foreach { str =>
- writeUTF(str, dataOut)
- }
- case pair: Tuple2[_, _] =>
- pair._1 match {
- case bytePair: Array[Byte] =>
- newIter.asInstanceOf[Iterator[Tuple2[Array[Byte], Array[Byte]]]].foreach { pair =>
- dataOut.writeInt(pair._1.length)
- dataOut.write(pair._1)
- dataOut.writeInt(pair._2.length)
- dataOut.write(pair._2)
- }
- case stringPair: String =>
- newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair =>
- writeUTF(pair._1, dataOut)
- writeUTF(pair._2, dataOut)
- }
- case other =>
- throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass)
- }
- case other =>
- throw new SparkException("Unexpected element type " + first.getClass)
- }
+
+ def write(obj: Any): Unit = obj match {
+ case null =>
+ dataOut.writeInt(SpecialLengths.NULL)
+ case arr: Array[Byte] =>
+ dataOut.writeInt(arr.length)
+ dataOut.write(arr)
+ case str: String =>
+ writeUTF(str, dataOut)
+ case stream: PortableDataStream =>
+ write(stream.toArray())
+ case (key, value) =>
+ write(key)
+ write(value)
+ case other =>
+ throw new SparkException("Unexpected element type " + other.getClass)
}
+
+ iter.foreach(write)
}
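The rewritten writeIteratorToStream collapses the old per-type switch into one recursive write. A minimal, standalone sketch of the framing it emits (illustrative only; the NULL sentinel and writeUTF helper mirror SpecialLengths and PythonRDD.writeUTF above):

import java.io.{ByteArrayOutputStream, DataOutputStream}

object FramingSketch {
  private val NULL = -5  // mirrors SpecialLengths.NULL added in this patch

  private def writeUTF(str: String, dataOut: DataOutputStream): Unit = {
    val bytes = str.getBytes("UTF-8")
    dataOut.writeInt(bytes.length)
    dataOut.write(bytes)
  }

  def main(args: Array[String]): Unit = {
    val buf = new ByteArrayOutputStream()
    val out = new DataOutputStream(buf)
    writeUTF("key", out)                               // a String element
    out.writeInt(3); out.write(Array[Byte](1, 2, 3))   // an Array[Byte] element
    out.writeInt(NULL)                                 // a null element: sentinel, no payload
    out.flush()
    println(s"wrote ${buf.size()} framed bytes")       // (4+3) + (4+3) + 4 = 18
  }
}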
/**
@@ -434,7 +418,7 @@ private[spark] object PythonRDD extends Logging {
val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration()))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -460,7 +444,7 @@ private[spark] object PythonRDD extends Logging {
Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -486,7 +470,7 @@ private[spark] object PythonRDD extends Logging {
None, inputFormatClass, keyClass, valueClass, conf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -529,7 +513,7 @@ private[spark] object PythonRDD extends Logging {
Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -555,7 +539,7 @@ private[spark] object PythonRDD extends Logging {
None, inputFormatClass, keyClass, valueClass, conf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -577,7 +561,7 @@ private[spark] object PythonRDD extends Logging {
}
def writeUTF(str: String, dataOut: DataOutputStream) {
- val bytes = str.getBytes(UTF8)
+ val bytes = str.getBytes(UTF_8)
dataOut.writeInt(bytes.length)
dataOut.write(bytes)
}
@@ -738,109 +722,11 @@ private[spark] object PythonRDD extends Logging {
converted.saveAsHadoopDataset(new JobConf(conf))
}
}
-
-
- /**
- * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
- */
- @deprecated("PySpark does not use it anymore", "1.1")
- def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
- pyRDD.rdd.mapPartitions { iter =>
- val unpickle = new Unpickler
- SerDeUtil.initialize()
- iter.flatMap { row =>
- unpickle.loads(row) match {
- // in case of objects are pickled in batch mode
- case objs: JArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
- // not in batch mode
- case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
- }
- }
- }
- }
-
- /**
- * Convert an RDD of serialized Python tuple to Array (no recursive conversions).
- * It is only used by pyspark.sql.
- */
- def pythonToJavaArray(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Array[_]] = {
-
- def toArray(obj: Any): Array[_] = {
- obj match {
- case objs: JArrayList[_] =>
- objs.toArray
- case obj if obj.getClass.isArray =>
- obj.asInstanceOf[Array[_]].toArray
- }
- }
-
- pyRDD.rdd.mapPartitions { iter =>
- val unpickle = new Unpickler
- iter.flatMap { row =>
- val obj = unpickle.loads(row)
- if (batched) {
- obj.asInstanceOf[JArrayList[_]].map(toArray)
- } else {
- Seq(toArray(obj))
- }
- }
- }.toJavaRDD()
- }
-
- private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
- private val pickle = new Pickler()
- private var batch = 1
- private val buffer = new mutable.ArrayBuffer[Any]
-
- override def hasNext(): Boolean = iter.hasNext
-
- override def next(): Array[Byte] = {
- while (iter.hasNext && buffer.length < batch) {
- buffer += iter.next()
- }
- val bytes = pickle.dumps(buffer.toArray)
- val size = bytes.length
- // let 1M < size < 10M
- if (size < 1024 * 1024) {
- batch *= 2
- } else if (size > 1024 * 1024 * 10 && batch > 1) {
- batch /= 2
- }
- buffer.clear()
- bytes
- }
- }
-
- /**
- * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
- * PySpark.
- */
- def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
- jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
- }
-
- /**
- * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
- */
- def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
- pyRDD.rdd.mapPartitions { iter =>
- SerDeUtil.initialize()
- val unpickle = new Unpickler
- iter.flatMap { row =>
- val obj = unpickle.loads(row)
- if (batched) {
- obj.asInstanceOf[JArrayList[_]].asScala
- } else {
- Seq(obj)
- }
- }
- }.toJavaRDD()
- }
}
private
class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
- override def call(arr: Array[Byte]) : String = new String(arr, PythonRDD.UTF8)
+ override def call(arr: Array[Byte]) : String = new String(arr, UTF_8)
}
/**
@@ -895,3 +781,49 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort:
}
}
}
+
+/**
+ * A wrapper for a Python broadcast variable, whose data is written to disk by Python. After
+ * deserialization it writes the data back to disk, so that Python can read it from there.
+ */
+private[spark] class PythonBroadcast(@transient var path: String) extends Serializable {
+
+ /**
+ * Read data from disk, then copy it to `out`.
+ */
+ private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
+ val in = new FileInputStream(new File(path))
+ try {
+ Utils.copyStream(in, out)
+ } finally {
+ in.close()
+ }
+ }
+
+ /**
+ * Write the data to disk, using a randomly generated file name.
+ */
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
+ val dir = new File(Utils.getLocalDir(SparkEnv.get.conf))
+ val file = File.createTempFile("broadcast", "", dir)
+ path = file.getAbsolutePath
+ val out = new FileOutputStream(file)
+ try {
+ Utils.copyStream(in, out)
+ } finally {
+ out.close()
+ }
+ }
+
+ /**
+ * Delete the file once the object is GCed.
+ */
+ override def finalize() {
+ if (!path.isEmpty) {
+ val file = new File(path)
+ if (file.exists()) {
+ file.delete()
+ }
+ }
+ }
+}
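A rough sketch of the PythonBroadcast round trip (illustrative only; the deserialization half needs an initialized SparkEnv because readObject picks its directory from SparkEnv.get.conf):

import java.io._

// Serializing streams the backing file's bytes into the object stream.
val src = File.createTempFile("pybroadcast-demo", "")
val fos = new FileOutputStream(src)
fos.write("payload written by Python".getBytes("UTF-8"))
fos.close()

val bytes = new ByteArrayOutputStream()
val oos = new ObjectOutputStream(bytes)
oos.writeObject(new PythonBroadcast(src.getAbsolutePath))
oos.close()

// Deserializing (on an executor with SparkEnv available) dumps those bytes into a
// fresh temp file and repoints `path` at it, so the Python worker can read it back:
// val copy = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
//   .readObject().asInstanceOf[PythonBroadcast]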
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
index be5ebfa9219d3..acbaba6791850 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
@@ -17,11 +17,14 @@
package org.apache.spark.api.python
-import java.io.{File, InputStream, IOException, OutputStream}
+import java.io.File
+import java.util.{List => JList}
+import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.SparkContext
+import org.apache.spark.api.java.{JavaSparkContext, JavaRDD}
private[spark] object PythonUtils {
/** Get the PYTHONPATH for PySpark, either from SPARK_HOME, if it is set, or from our JAR */
@@ -39,4 +42,15 @@ private[spark] object PythonUtils {
def mergePythonPaths(paths: String*): String = {
paths.filter(_ != "").mkString(File.pathSeparator)
}
+
+ def generateRDDWithNull(sc: JavaSparkContext): JavaRDD[String] = {
+ sc.parallelize(List("a", null, "b"))
+ }
+
+ /**
+ * Convert a list of T into a Seq of T (for calling Scala APIs with varargs).
+ */
+ def toSeq[T](cols: JList[T]): Seq[T] = {
+ cols.toList.toSeq
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
index ebdc3533e0992..fb52a960e0765 100644
--- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
@@ -18,8 +18,13 @@
package org.apache.spark.api.python
import java.nio.ByteOrder
+import java.util.{ArrayList => JArrayList}
+
+import org.apache.spark.api.java.JavaRDD
import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+import scala.collection.mutable
import scala.util.Failure
import scala.util.Try
@@ -89,6 +94,76 @@ private[spark] object SerDeUtil extends Logging {
}
initialize()
+
+ /**
+ * Convert an RDD of Java objects to an RDD of Arrays (no recursive conversions).
+ * It is only used by pyspark.sql.
+ */
+ def toJavaArray(jrdd: JavaRDD[Any]): JavaRDD[Array[_]] = {
+ jrdd.rdd.map {
+ case objs: JArrayList[_] =>
+ objs.toArray
+ case obj if obj.getClass.isArray =>
+ obj.asInstanceOf[Array[_]].toArray
+ }.toJavaRDD()
+ }
+
+ /**
+ * Choose the batch size based on the size of the pickled objects.
+ */
+ private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
+ private val pickle = new Pickler()
+ private var batch = 1
+ private val buffer = new mutable.ArrayBuffer[Any]
+
+ override def hasNext: Boolean = iter.hasNext
+
+ override def next(): Array[Byte] = {
+ while (iter.hasNext && buffer.length < batch) {
+ buffer += iter.next()
+ }
+ val bytes = pickle.dumps(buffer.toArray)
+ val size = bytes.length
+ // let 1M < size < 10M
+ if (size < 1024 * 1024) {
+ batch *= 2
+ } else if (size > 1024 * 1024 * 10 && batch > 1) {
+ batch /= 2
+ }
+ buffer.clear()
+ bytes
+ }
+ }
+
+ /**
+ * Convert an RDD of Java objects to an RDD of serialized Python objects that is usable
+ * by PySpark.
+ */
+ private[spark] def javaToPython(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = {
+ jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
+ }
+
+ /**
+ * Convert an RDD of serialized Python objects to an RDD of objects that is usable by PySpark.
+ */
+ def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
+ pyRDD.rdd.mapPartitions { iter =>
+ initialize()
+ val unpickle = new Unpickler
+ iter.flatMap { row =>
+ val obj = unpickle.loads(row)
+ if (batched) {
+ obj match {
+ case array: Array[Any] => array.toSeq
+ case _ => obj.asInstanceOf[JArrayList[_]].asScala
+ }
+ } else {
+ Seq(obj)
+ }
+ }
+ }.toJavaRDD()
+ }
+
private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = {
val pickle = new Pickler
val kt = Try {
@@ -127,18 +202,22 @@ private[spark] object SerDeUtil extends Logging {
* representation is serialized
*/
def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = {
- val (keyFailed, valueFailed) = checkPickle(rdd.first())
+ val (keyFailed, valueFailed) = rdd.take(1) match {
+ case Array() => (false, false)
+ case Array(first) => checkPickle(first)
+ }
+
rdd.mapPartitions { iter =>
- val pickle = new Pickler
val cleaned = iter.map { case (k, v) =>
val key = if (keyFailed) k.toString else k
val value = if (valueFailed) v.toString else v
Array[Any](key, value)
}
- if (batchSize > 1) {
- cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched)))
+ if (batchSize == 0) {
+ new AutoBatchedPickler(cleaned)
} else {
- cleaned.map(pickle.dumps(_))
+ val pickle = new Pickler
+ cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched)))
}
}
}
@@ -146,36 +225,24 @@ private[spark] object SerDeUtil extends Logging {
/**
* Convert an RDD of serialized Python tuple (K, V) to RDD[(K, V)].
*/
- def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batchSerialized: Boolean): RDD[(K, V)] = {
+ def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batched: Boolean): RDD[(K, V)] = {
def isPair(obj: Any): Boolean = {
- Option(obj.getClass.getComponentType).map(!_.isPrimitive).getOrElse(false) &&
+ Option(obj.getClass.getComponentType).exists(!_.isPrimitive) &&
obj.asInstanceOf[Array[_]].length == 2
}
- pyRDD.mapPartitions { iter =>
- initialize()
- val unpickle = new Unpickler
- val unpickled =
- if (batchSerialized) {
- iter.flatMap { batch =>
- unpickle.loads(batch) match {
- case objs: java.util.List[_] => collectionAsScalaIterable(objs)
- case other => throw new SparkException(
- s"Unexpected type ${other.getClass.getName} for batch serialized Python RDD")
- }
- }
- } else {
- iter.map(unpickle.loads(_))
- }
- unpickled.map {
- case obj if isPair(obj) =>
- // we only accept (K, V)
- val arr = obj.asInstanceOf[Array[_]]
- (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V])
- case other => throw new SparkException(
- s"RDD element of type ${other.getClass.getName} cannot be used")
- }
+
+ val rdd = pythonToJava(pyRDD, batched).rdd
+ rdd.take(1) match {
+ case Array(obj) if isPair(obj) =>
+ // we only accept (K, V)
+ case Array() =>
+ // we also accept empty collections
+ case Array(other) => throw new SparkException(
+ s"RDD element of type ${other.getClass.getName} cannot be used")
+ }
+ rdd.map { obj =>
+ val arr = obj.asInstanceOf[Array[_]]
+ (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V])
}
}
-
}
-
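For reference, the size-driven batching that AutoBatchedPickler applies (and that pairRDDToPython now opts into with batchSize == 0) can be summarized in isolation; this is a sketch, not code from the patch:

// Grow the batch while pickled output stays under 1 MB; shrink it (never below 1)
// once a batch pickles to more than 10 MB. Any positive batchSize still uses the
// fixed grouped(batchSize) path instead.
def nextBatchSize(batch: Int, pickledBytes: Int): Int =
  if (pickledBytes < 1024 * 1024) batch * 2
  else if (pickledBytes > 1024 * 1024 * 10 && batch > 1) batch / 2
  else batch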
diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
index d11db978b842e..cf289fb3ae39f 100644
--- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
@@ -18,7 +18,8 @@
package org.apache.spark.api.python
import java.io.{DataOutput, DataInput}
-import java.nio.charset.Charset
+
+import com.google.common.base.Charsets.UTF_8
import org.apache.hadoop.io._
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat
@@ -106,7 +107,6 @@ private[python] class WritableToDoubleArrayConverter extends Converter[Any, Arra
* given directory (probably a temp directory)
*/
object WriteInputFormatTestDataGenerator {
- import SparkContext._
def main(args: Array[String]) {
val path = args(0)
@@ -136,7 +136,7 @@ object WriteInputFormatTestDataGenerator {
sc.parallelize(intKeys).saveAsSequenceFile(intPath)
sc.parallelize(intKeys.map{ case (k, v) => (k.toDouble, v) }).saveAsSequenceFile(doublePath)
sc.parallelize(intKeys.map{ case (k, v) => (k.toString, v) }).saveAsSequenceFile(textPath)
- sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes(Charset.forName("UTF-8"))) }
+ sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes(UTF_8)) }
).saveAsSequenceFile(bytesPath)
val bools = Seq((1, true), (2, true), (2, false), (3, true), (2, false), (1, false))
sc.parallelize(bools).saveAsSequenceFile(boolPath)
@@ -175,11 +175,11 @@ object WriteInputFormatTestDataGenerator {
// Create test data for arbitrary custom writable TestWritable
val testClass = Seq(
- ("1", TestWritable("test1", 123, 54.0)),
- ("2", TestWritable("test2", 456, 8762.3)),
- ("1", TestWritable("test3", 123, 423.1)),
- ("3", TestWritable("test56", 456, 423.5)),
- ("2", TestWritable("test2", 123, 5435.2))
+ ("1", TestWritable("test1", 1, 1.0)),
+ ("2", TestWritable("test2", 2, 2.3)),
+ ("3", TestWritable("test3", 3, 3.1)),
+ ("5", TestWritable("test56", 5, 5.5)),
+ ("4", TestWritable("test4", 4, 4.2))
)
val rdd = sc.parallelize(testClass, numSlices = 2).map{ case (k, v) => (new Text(k), v) }
rdd.saveAsNewAPIHadoopFile(classPath,
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
index 15fd30e65761d..a5ea478f231d7 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -20,6 +20,8 @@ package org.apache.spark.broadcast
import java.io.Serializable
import org.apache.spark.SparkException
+import org.apache.spark.Logging
+import org.apache.spark.util.Utils
import scala.reflect.ClassTag
@@ -37,7 +39,7 @@ import scala.reflect.ClassTag
*
* {{{
* scala> val broadcastVar = sc.broadcast(Array(1, 2, 3))
- * broadcastVar: spark.Broadcast[Array[Int]] = spark.Broadcast(b5c40191-a864-4c7d-b9bf-d87e1a4e787c)
+ * broadcastVar: org.apache.spark.broadcast.Broadcast[Array[Int]] = Broadcast(0)
*
* scala> broadcastVar.value
* res0: Array[Int] = Array(1, 2, 3)
@@ -52,7 +54,7 @@ import scala.reflect.ClassTag
* @param id A unique identifier for the broadcast variable.
* @tparam T Type of the data contained in the broadcast variable.
*/
-abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
+abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable with Logging {
/**
* Flag signifying whether the broadcast variable is valid
@@ -60,6 +62,8 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
*/
@volatile private var _isValid = true
+ private var _destroySite = ""
+
/** Get the broadcasted value. */
def value: T = {
assertValid()
@@ -84,13 +88,26 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
doUnpersist(blocking)
}
+
+ /**
+ * Destroy all data and metadata related to this broadcast variable. Use this with caution;
+ * once a broadcast variable has been destroyed, it cannot be used again.
+ * This method blocks until destroy has completed
+ */
+ def destroy() {
+ destroy(blocking = true)
+ }
+
/**
* Destroy all data and metadata related to this broadcast variable. Use this with caution;
* once a broadcast variable has been destroyed, it cannot be used again.
+ * @param blocking Whether to block until destroy has completed
*/
private[spark] def destroy(blocking: Boolean) {
assertValid()
_isValid = false
+ _destroySite = Utils.getCallSite().shortForm
+ logInfo("Destroying %s (from %s)".format(toString, _destroySite))
doDestroy(blocking)
}
@@ -124,7 +141,8 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
/** Check if this broadcast is valid. If not valid, exception is thrown. */
protected def assertValid() {
if (!_isValid) {
- throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString))
+ throw new SparkException(
+ "Attempted to use %s after it was destroyed (%s) ".format(toString, _destroySite))
}
}
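A short usage sketch of the now-public destroy() (assumes an active SparkContext named sc; not taken from the patch):

val data = sc.broadcast(Array(1, 2, 3))
println(data.value.sum)   // 6
data.destroy()            // blocking: removes all data and metadata for this broadcast
// Any later access fails with a SparkException that reports the destroy call site,
// e.g. "Attempted to use Broadcast(0) after it was destroyed (...)"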
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 4cd4f4f96fd16..1444c0dd3d2d6 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -72,13 +72,13 @@ private[spark] class HttpBroadcast[T: ClassTag](
}
/** Used by the JVM when serializing this object. */
- private def writeObject(out: ObjectOutputStream) {
+ private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
assertValid()
out.defaultWriteObject()
}
/** Used by the JVM when deserializing this object. */
- private def readObject(in: ObjectInputStream) {
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
in.defaultReadObject()
HttpBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(blockId) match {
@@ -151,9 +151,10 @@ private[broadcast] object HttpBroadcast extends Logging {
}
private def createServer(conf: SparkConf) {
- broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf))
+ broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf), "broadcast")
val broadcastPort = conf.getInt("spark.broadcast.port", 0)
- server = new HttpServer(broadcastDir, securityManager, broadcastPort, "HTTP broadcast server")
+ server =
+ new HttpServer(conf, broadcastDir, securityManager, broadcastPort, "HTTP broadcast server")
server.start()
serverUri = server.uri
logInfo("Broadcast server started at " + serverUri)
@@ -191,11 +192,14 @@ private[broadcast] object HttpBroadcast extends Logging {
logDebug("broadcast security enabled")
val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager)
uc = newuri.toURL.openConnection()
+ uc.setConnectTimeout(httpReadTimeout)
uc.setAllowUserInteraction(false)
} else {
logDebug("broadcast not using security")
uc = new URL(url).openConnection()
+ uc.setConnectTimeout(httpReadTimeout)
}
+ Utils.setupSecureURLConnection(uc, securityManager)
val in = {
uc.setReadTimeout(httpReadTimeout)
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 99af2e9608ea7..94142d33369c7 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -28,7 +28,7 @@ import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
-import org.apache.spark.util.ByteBufferInputStream
+import org.apache.spark.util.{ByteBufferInputStream, Utils}
import org.apache.spark.util.io.ByteArrayChunkOutputStream
/**
@@ -56,11 +56,13 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
/**
- * Value of the broadcast object. On driver, this is set directly by the constructor.
- * On executors, this is reconstructed by [[readObject]], which builds this value by reading
- * blocks from the driver and/or other executors.
+ * Value of the broadcast object on executors. This is reconstructed by [[readBroadcastBlock]],
+ * which builds this value by reading blocks from the driver and/or other executors.
+ *
+ * On the driver, if the value is required, it is read lazily from the block manager.
*/
- @transient private var _value: T = obj
+ @transient private lazy val _value: T = readBroadcastBlock()
+
/** The compression codec to use, or None if compression is disabled */
@transient private var compressionCodec: Option[CompressionCodec] = _
/** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */
@@ -79,22 +81,24 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
private val broadcastId = BroadcastBlockId(id)
/** Total number of blocks this broadcast variable contains. */
- private val numBlocks: Int = writeBlocks()
+ private val numBlocks: Int = writeBlocks(obj)
- override protected def getValue() = _value
+ override protected def getValue() = {
+ _value
+ }
/**
* Divide the object into multiple blocks and put those blocks in the block manager.
- *
+ * @param value the object to divide
* @return number of blocks this broadcast variable is divided into
*/
- private def writeBlocks(): Int = {
+ private def writeBlocks(value: T): Int = {
// Store a copy of the broadcast variable in the driver so that tasks run on the driver
// do not create a duplicate copy of the broadcast variable's value.
- SparkEnv.get.blockManager.putSingle(broadcastId, _value, StorageLevel.MEMORY_AND_DISK,
+ SparkEnv.get.blockManager.putSingle(broadcastId, value, StorageLevel.MEMORY_AND_DISK,
tellMaster = false)
val blocks =
- TorrentBroadcast.blockifyObject(_value, blockSize, SparkEnv.get.serializer, compressionCodec)
+ TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
blocks.zipWithIndex.foreach { case (block, i) =>
SparkEnv.get.blockManager.putBytes(
BroadcastBlockId(id, "piece" + i),
@@ -152,36 +156,35 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
}
/** Used by the JVM when serializing this object. */
- private def writeObject(out: ObjectOutputStream) {
+ private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
assertValid()
out.defaultWriteObject()
}
- /** Used by the JVM when deserializing this object. */
- private def readObject(in: ObjectInputStream) {
- in.defaultReadObject()
+ private def readBroadcastBlock(): T = Utils.tryOrIOException {
TorrentBroadcast.synchronized {
setConf(SparkEnv.get.conf)
SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match {
case Some(x) =>
- _value = x.asInstanceOf[T]
+ x.asInstanceOf[T]
case None =>
logInfo("Started reading broadcast variable " + id)
- val start = System.nanoTime()
+ val startTimeMs = System.currentTimeMillis()
val blocks = readBlocks()
- val time = (System.nanoTime() - start) / 1e9
- logInfo("Reading broadcast variable " + id + " took " + time + " s")
+ logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))
- _value =
- TorrentBroadcast.unBlockifyObject[T](blocks, SparkEnv.get.serializer, compressionCodec)
+ val obj = TorrentBroadcast.unBlockifyObject[T](
+ blocks, SparkEnv.get.serializer, compressionCodec)
// Store the merged copy in BlockManager so other tasks on this executor don't
// need to re-fetch it.
SparkEnv.get.blockManager.putSingle(
- broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+ broadcastId, obj, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+ obj
}
}
}
+
}
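The essential change is that the broadcast value is no longer carried inside the serialized Broadcast object; each JVM reconstructs it on first access. A stripped-down sketch of that lazy-read pattern (illustrative only, not the Spark class):

class LazyBlockValue[T](fetch: () => T) extends Serializable {
  // Not serialized: every JVM re-reads the value from local storage or peers
  // the first time it is touched.
  @transient private lazy val value: T = fetch()
  def get: T = value
}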
diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala
index 65a1a8fd7e929..ae55b4ff40b74 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala
@@ -28,5 +28,14 @@ private[spark] class ApplicationDescription(
val user = System.getProperty("user.name", "")
+ def copy(
+ name: String = name,
+ maxCores: Option[Int] = maxCores,
+ memoryPerSlave: Int = memoryPerSlave,
+ command: Command = command,
+ appUiUrl: String = appUiUrl,
+ eventLogDir: Option[String] = eventLogDir): ApplicationDescription =
+ new ApplicationDescription(name, maxCores, memoryPerSlave, command, appUiUrl, eventLogDir)
+
override def toString: String = "ApplicationDescription(" + name + ")"
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala
index f2687ce6b42b4..38b3da0b13756 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -39,7 +39,8 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf)
val timeout = AkkaUtils.askTimeout(conf)
override def preStart() = {
- masterActor = context.actorSelection(Master.toAkkaUrl(driverArgs.master))
+ masterActor = context.actorSelection(
+ Master.toAkkaUrl(driverArgs.master, AkkaUtils.protocol(context.system)))
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
@@ -160,6 +161,8 @@ object Client {
val (actorSystem, _) = AkkaUtils.createActorSystem(
"driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf))
+ // Verify driverArgs.master is a valid url so that we can use it in ClientActor safely
+ Master.toAkkaUrl(driverArgs.master, AkkaUtils.protocol(actorSystem))
actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf))
actorSystem.awaitTermination()
diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
index 39150deab863c..e5873ce724b9f 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
@@ -17,11 +17,13 @@
package org.apache.spark.deploy
+import java.net.{URI, URISyntaxException}
+
import scala.collection.mutable.ListBuffer
import org.apache.log4j.Level
-import org.apache.spark.util.MemoryParam
+import org.apache.spark.util.{IntParam, MemoryParam}
/**
* Command-line parser for the driver client.
@@ -49,8 +51,8 @@ private[spark] class ClientArguments(args: Array[String]) {
parse(args.toList)
def parse(args: List[String]): Unit = args match {
- case ("--cores" | "-c") :: value :: tail =>
- cores = value.toInt
+ case ("--cores" | "-c") :: IntParam(value) :: tail =>
+ cores = value
parse(tail)
case ("--memory" | "-m") :: MemoryParam(value) :: tail =>
@@ -73,7 +75,8 @@ private[spark] class ClientArguments(args: Array[String]) {
if (!ClientArguments.isValidJarUrl(_jarUrl)) {
println(s"Jar url '${_jarUrl}' is not in valid format.")
- println(s"Must be a jar file path in URL format (e.g. hdfs://XX.jar, file://XX.jar)")
+ println(s"Must be a jar file path in URL format " +
+ "(e.g. hdfs://host:port/XX.jar, file:///XX.jar)")
printUsageAndExit(-1)
}
@@ -114,5 +117,12 @@ private[spark] class ClientArguments(args: Array[String]) {
}
object ClientArguments {
- def isValidJarUrl(s: String): Boolean = s.matches("(.+):(.+)jar")
+ def isValidJarUrl(s: String): Boolean = {
+ try {
+ val uri = new URI(s)
+ uri.getScheme != null && uri.getPath != null && uri.getPath.endsWith(".jar")
+ } catch {
+ case _: URISyntaxException => false
+ }
+ }
}
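The loose regex is replaced with real URI parsing; a few cases the new check distinguishes (sketch):

ClientArguments.isValidJarUrl("hdfs://host:8020/path/app.jar")   // true: scheme + .jar path
ClientArguments.isValidJarUrl("file:///tmp/app.jar")             // true
ClientArguments.isValidJarUrl("/tmp/app.jar")                    // false: no scheme
ClientArguments.isValidJarUrl("hdfs://host:8020/app")            // false: path is not a jar
ClientArguments.isValidJarUrl("hdfs://bad host/app.jar")         // false: URISyntaxException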
diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
index b9dd8557ee904..243d8edb72ed3 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -88,10 +88,14 @@ private[deploy] object DeployMessages {
case class KillDriver(driverId: String) extends DeployMessage
+ case class ApplicationFinished(id: String)
+
// Worker internal
case object WorkDirCleanup // Sent to Worker actor periodically for cleaning up app folders
+ case object ReregisterWithMaster // used when a worker attempts to reconnect to a master
+
// AppClient to Master
case class RegisterApplication(appDescription: ApplicationDescription)
@@ -173,4 +177,5 @@ private[deploy] object DeployMessages {
// Liveness checks in various places
case object SendHeartbeat
+
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala b/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala
index 58c95dc4f9116..b056a19ce6598 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala
@@ -25,5 +25,13 @@ private[spark] class DriverDescription(
val command: Command)
extends Serializable {
+ def copy(
+ jarUrl: String = jarUrl,
+ mem: Int = mem,
+ cores: Int = cores,
+ supervise: Boolean = supervise,
+ command: Command = command): DriverDescription =
+ new DriverDescription(jarUrl, mem, cores, supervise, command)
+
override def toString: String = s"DriverDescription (${command.mainClass})"
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
index af94b05ce3847..53e18c4bcec23 100644
--- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
@@ -26,7 +26,7 @@ import org.apache.spark.api.python.PythonUtils
import org.apache.spark.util.{RedirectThread, Utils}
/**
- * A main class used by spark-submit to launch Python applications. It executes python as a
+ * A main class used to launch Python applications. It executes python as a
* subprocess and then has it connect back to the JVM to access system properties, etc.
*/
object PythonRunner {
@@ -87,8 +87,8 @@ object PythonRunner {
// Strip the URI scheme from the path
formattedPath =
new URI(formattedPath).getScheme match {
- case Utils.windowsDrive(d) if windows => formattedPath
case null => formattedPath
+ case Utils.windowsDrive(d) if windows => formattedPath
case _ => new URI(formattedPath).getPath
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index fe0ad9ebbca12..d68854214ef06 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -17,15 +17,20 @@
package org.apache.spark.deploy
+import java.lang.reflect.Method
import java.security.PrivilegedExceptionAction
import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.FileSystem.Statistics
import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}
import org.apache.hadoop.security.Credentials
import org.apache.hadoop.security.UserGroupInformation
import org.apache.spark.{Logging, SparkContext, SparkConf, SparkException}
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.Utils
import scala.collection.JavaConversions._
@@ -121,6 +126,71 @@ class SparkHadoopUtil extends Logging {
UserGroupInformation.loginUserFromKeytab(principalName, keytabFilename)
}
+ /**
+ * Returns a function that can be called to find Hadoop FileSystem bytes read. If
+ * getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will
+ * return the bytes read on r since t. Reflection is required because thread-level FileSystem
+ * statistics are only available as of Hadoop 2.5 (see HADOOP-10688).
+ * Returns None if the required method can't be found.
+ */
+ private[spark] def getFSBytesReadOnThreadCallback(): Option[() => Long] = {
+ try {
+ val threadStats = getFileSystemThreadStatistics()
+ val getBytesReadMethod = getFileSystemThreadStatisticsMethod("getBytesRead")
+ val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum
+ val baselineBytesRead = f()
+ Some(() => f() - baselineBytesRead)
+ } catch {
+ case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => {
+ logDebug("Couldn't find method for retrieving thread-level FileSystem input data", e)
+ None
+ }
+ }
+ }
+
+ /**
+ * Returns a function that can be called to find Hadoop FileSystem bytes written. If
+ * getFSBytesWrittenOnThreadCallback is called from thread r at time t, the returned callback will
+ * return the bytes written on r since t. Reflection is required because thread-level FileSystem
+ * statistics are only available as of Hadoop 2.5 (see HADOOP-10688).
+ * Returns None if the required method can't be found.
+ */
+ private[spark] def getFSBytesWrittenOnThreadCallback(): Option[() => Long] = {
+ try {
+ val threadStats = getFileSystemThreadStatistics()
+ val getBytesWrittenMethod = getFileSystemThreadStatisticsMethod("getBytesWritten")
+ val f = () => threadStats.map(getBytesWrittenMethod.invoke(_).asInstanceOf[Long]).sum
+ val baselineBytesWritten = f()
+ Some(() => f() - baselineBytesWritten)
+ } catch {
+ case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => {
+ logDebug("Couldn't find method for retrieving thread-level FileSystem output data", e)
+ None
+ }
+ }
+ }
+
+ private def getFileSystemThreadStatistics(): Seq[AnyRef] = {
+ val stats = FileSystem.getAllStatistics()
+ stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics"))
+ }
+
+ private def getFileSystemThreadStatisticsMethod(methodName: String): Method = {
+ val statisticsDataClass =
+ Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData")
+ statisticsDataClass.getDeclaredMethod(methodName)
+ }
+
+ /**
+ * Uses reflection to get the Configuration from a JobContext/TaskAttemptContext. If we called
+ * `JobContext/TaskAttemptContext.getConfiguration` directly, different byte code would be
+ * generated for Hadoop 1.+ and Hadoop 2.+, because JobContext/TaskAttemptContext is a class
+ * in Hadoop 1.+ but an interface in Hadoop 2.+.
+ */
+ def getConfigurationFromJobContext(context: JobContext): Configuration = {
+ val method = context.getClass.getMethod("getConfiguration")
+ method.invoke(context).asInstanceOf[Configuration]
+ }
}
object SparkHadoopUtil {
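A hedged usage sketch of the new byte-counting callbacks (they are private[spark], so real callers live inside Spark, e.g. the Hadoop RDD and writer code paths):

// Snapshot this thread's FileSystem bytes-read counter now...
val maybeBytesRead: Option[() => Long] =
  SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()

// ... perform some HDFS reads on the same thread ...

// ... then ask how much was read since the snapshot. On Hadoop versions before 2.5
// the reflection target is missing, the Option is None, and we fall back to 0.
val bytesReadDelta: Long = maybeBytesRead.map(_()).getOrElse(0L)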
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index f97bf67fa5a3b..8bbfcd2997dc6 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -23,6 +23,19 @@ import java.net.URL
import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
+import org.apache.hadoop.fs.Path
+
+import org.apache.ivy.Ivy
+import org.apache.ivy.core.LogOptions
+import org.apache.ivy.core.module.descriptor.{DefaultExcludeRule, DefaultDependencyDescriptor, DefaultModuleDescriptor}
+import org.apache.ivy.core.module.id.{ModuleId, ArtifactId, ModuleRevisionId}
+import org.apache.ivy.core.report.ResolveReport
+import org.apache.ivy.core.resolve.{IvyNode, ResolveOptions}
+import org.apache.ivy.core.retrieve.RetrieveOptions
+import org.apache.ivy.core.settings.IvySettings
+import org.apache.ivy.plugins.matcher.GlobPatternMatcher
+import org.apache.ivy.plugins.resolver.{ChainResolver, IBiblioResolver}
+
import org.apache.spark.executor.ExecutorURLClassLoader
import org.apache.spark.util.Utils
@@ -134,19 +147,38 @@ object SparkSubmit {
}
}
+ val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER
+
+ // Require all python files to be local, so we can add them to the PYTHONPATH
+ // In YARN cluster mode, python files are distributed as regular files, which can be non-local
+ if (args.isPython && !isYarnCluster) {
+ if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) {
+ printErrorAndExit(s"Only local python files are supported: $args.primaryResource")
+ }
+ val nonLocalPyFiles = Utils.nonLocalPaths(args.pyFiles).mkString(",")
+ if (nonLocalPyFiles.nonEmpty) {
+ printErrorAndExit(s"Only local additional python files are supported: $nonLocalPyFiles")
+ }
+ }
+
// The following modes are not supported or applicable
(clusterManager, deployMode) match {
case (MESOS, CLUSTER) =>
printErrorAndExit("Cluster deploy mode is currently not supported for Mesos clusters.")
- case (_, CLUSTER) if args.isPython =>
- printErrorAndExit("Cluster deploy mode is currently not supported for python applications.")
+ case (STANDALONE, CLUSTER) if args.isPython =>
+ printErrorAndExit("Cluster deploy mode is currently not supported for python " +
+ "applications on standalone clusters.")
case (_, CLUSTER) if isShell(args.primaryResource) =>
printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.")
+ case (_, CLUSTER) if isSqlShell(args.mainClass) =>
+ printErrorAndExit("Cluster deploy mode is not applicable to Spark SQL shell.")
+ case (_, CLUSTER) if isThriftServer(args.mainClass) =>
+ printErrorAndExit("Cluster deploy mode is not applicable to Spark Thrift server.")
case _ =>
}
// If we're running a python app, set the main class to our specific python runner
- if (args.isPython) {
+ if (args.isPython && deployMode == CLIENT) {
if (args.primaryResource == PYSPARK_SHELL) {
args.mainClass = "py4j.GatewayServer"
args.childArgs = ArrayBuffer("--die-on-broken-pipe", "0")
@@ -158,13 +190,33 @@ object SparkSubmit {
args.files = mergeFileLists(args.files, args.primaryResource)
}
args.files = mergeFileLists(args.files, args.pyFiles)
- // Format python file paths properly before adding them to the PYTHONPATH
- sysProps("spark.submit.pyFiles") = PythonRunner.formatPaths(args.pyFiles).mkString(",")
+ if (args.pyFiles != null) {
+ sysProps("spark.submit.pyFiles") = args.pyFiles
+ }
+ }
+
+ // In yarn-cluster mode for a python app, add primary resource and pyFiles to files
+ // that can be distributed with the job
+ if (args.isPython && isYarnCluster) {
+ args.files = mergeFileLists(args.files, args.primaryResource)
+ args.files = mergeFileLists(args.files, args.pyFiles)
}
// Special flag to avoid deprecation warnings at the client
sysProps("SPARK_SUBMIT") = "true"
+ // Resolve maven dependencies if there are any and add classpath to jars
+ val resolvedMavenCoordinates =
+ SparkSubmitUtils.resolveMavenCoordinates(
+ args.packages, Option(args.repositories), Option(args.ivyRepoPath))
+ if (!resolvedMavenCoordinates.trim.isEmpty) {
+ if (args.jars == null || args.jars.trim.isEmpty) {
+ args.jars = resolvedMavenCoordinates
+ } else {
+ args.jars += s",$resolvedMavenCoordinates"
+ }
+ }
+
// A list of rules to map each argument to system properties or command-line options in
// each deploy mode; we iterate through these below
val options = List[OptionAssigner](
@@ -173,6 +225,7 @@ object SparkSubmit {
OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"),
OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"),
OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"),
+ OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars.ivy"),
OptionAssigner(args.driverMemory, ALL_CLUSTER_MGRS, CLIENT,
sysProp = "spark.driver.memory"),
OptionAssigner(args.driverExtraClassPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES,
@@ -184,6 +237,7 @@ object SparkSubmit {
// Standalone cluster only
OptionAssigner(args.jars, STANDALONE, CLUSTER, sysProp = "spark.jars"),
+ OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy"),
OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, clOption = "--memory"),
OptionAssigner(args.driverCores, STANDALONE, CLUSTER, clOption = "--cores"),
@@ -197,6 +251,7 @@ object SparkSubmit {
// Yarn cluster only
OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name"),
OptionAssigner(args.driverMemory, YARN, CLUSTER, clOption = "--driver-memory"),
+ OptionAssigner(args.driverCores, YARN, CLUSTER, clOption = "--driver-cores"),
OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"),
OptionAssigner(args.numExecutors, YARN, CLUSTER, clOption = "--num-executors"),
OptionAssigner(args.executorMemory, YARN, CLUSTER, clOption = "--executor-memory"),
@@ -239,7 +294,6 @@ object SparkSubmit {
// Add the application jar automatically so the user doesn't have to call sc.addJar
// For YARN cluster mode, the jar is already distributed on each node as "app.jar"
// For python files, the primary resource is already distributed as a regular file
- val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER
if (!isYarnCluster && !args.isPython) {
var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty)
if (isUserJar(args.primaryResource)) {
@@ -264,24 +318,58 @@ object SparkSubmit {
// In yarn-cluster mode, use yarn.Client as a wrapper around the user class
if (isYarnCluster) {
childMainClass = "org.apache.spark.deploy.yarn.Client"
- if (args.primaryResource != SPARK_INTERNAL) {
- childArgs += ("--jar", args.primaryResource)
+ if (args.isPython) {
+ val mainPyFile = new Path(args.primaryResource).getName
+ childArgs += ("--primary-py-file", mainPyFile)
+ if (args.pyFiles != null) {
+ // These files will be distributed to each machine's working directory, so strip the
+ // path prefix
+ val pyFilesNames = args.pyFiles.split(",").map(p => (new Path(p)).getName).mkString(",")
+ childArgs += ("--py-files", pyFilesNames)
+ }
+ childArgs += ("--class", "org.apache.spark.deploy.PythonRunner")
+ } else {
+ if (args.primaryResource != SPARK_INTERNAL) {
+ childArgs += ("--jar", args.primaryResource)
+ }
+ childArgs += ("--class", args.mainClass)
}
- childArgs += ("--class", args.mainClass)
if (args.childArgs != null) {
args.childArgs.foreach { arg => childArgs += ("--arg", arg) }
}
}
- // Properties given with --conf are superceded by other options, but take precedence over
- // properties in the defaults file.
+ // Load any properties specified through --conf and the default properties file
for ((k, v) <- args.sparkProperties) {
sysProps.getOrElseUpdate(k, v)
}
- // Read from default spark properties, if any
- for ((k, v) <- args.defaultSparkProperties) {
- sysProps.getOrElseUpdate(k, v)
+ // Ignore invalid spark.driver.host in cluster modes.
+ if (deployMode == CLUSTER) {
+ sysProps -= ("spark.driver.host")
+ }
+
+ // Resolve paths in certain spark properties
+ val pathConfigs = Seq(
+ "spark.jars",
+ "spark.files",
+ "spark.yarn.jar",
+ "spark.yarn.dist.files",
+ "spark.yarn.dist.archives")
+ pathConfigs.foreach { config =>
+ // Replace old URIs with resolved URIs, if they exist
+ sysProps.get(config).foreach { oldValue =>
+ sysProps(config) = Utils.resolveURIs(oldValue)
+ }
+ }
+
+ // Resolve and format python file paths properly before adding them to the PYTHONPATH.
+ // The resolving part is redundant in the case of --py-files, but necessary if the user
+ // explicitly sets `spark.submit.pyFiles` in his/her default properties file.
+ sysProps.get("spark.submit.pyFiles").foreach { pyFiles =>
+ val resolvedPyFiles = Utils.resolveURIs(pyFiles)
+ val formattedPyFiles = PythonRunner.formatPaths(resolvedPyFiles).mkString(",")
+ sysProps("spark.submit.pyFiles") = formattedPyFiles
}
(childArgs, childClasspath, sysProps, childMainClass)
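Utils.resolveURIs (used above for the path configs and pyFiles) normalizes every comma-separated entry; a quick sketch of its effect on a mixed list:

// Bare local paths gain an absolute file: URI; entries that already carry a scheme
// (hdfs:, http:, local:, ...) pass through unchanged.
val resolved = Utils.resolveURIs("my.jar,hdfs://nn:8020/libs/dep.jar")
// e.g. "file:/current/working/dir/my.jar,hdfs://nn:8020/libs/dep.jar"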
@@ -321,12 +409,17 @@ object SparkSubmit {
case e: ClassNotFoundException =>
e.printStackTrace(printStream)
if (childMainClass.contains("thriftserver")) {
- println(s"Failed to load main class $childMainClass.")
- println("You need to build Spark with -Phive.")
+ printStream.println(s"Failed to load main class $childMainClass.")
+ printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.")
}
System.exit(CLASS_NOT_FOUND_EXIT_STATUS)
}
+ // SPARK-4170
+ if (classOf[scala.App].isAssignableFrom(mainClass)) {
+ printWarning("Subclasses of scala.App may not work correctly. Use a main() method instead.")
+ }
+
val mainMethod = mainClass.getMethod("main", new Array[String](0).getClass)
if (!Modifier.isStatic(mainMethod.getModifiers)) {
throw new IllegalStateException("The main method in the given main class must be static")
@@ -370,6 +463,20 @@ object SparkSubmit {
primaryResource == SPARK_SHELL || primaryResource == PYSPARK_SHELL
}
+ /**
+ * Return whether the given main class represents a sql shell.
+ */
+ private[spark] def isSqlShell(mainClass: String): Boolean = {
+ mainClass == "org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver"
+ }
+
+ /**
+ * Return whether the given main class represents a thrift server.
+ */
+ private[spark] def isThriftServer(mainClass: String): Boolean = {
+ mainClass == "org.apache.spark.sql.hive.thriftserver.HiveThriftServer2"
+ }
+
/**
* Return whether the given primary resource requires running python.
*/
@@ -393,6 +500,194 @@ object SparkSubmit {
}
}
+/** Provides utility functions to be used inside SparkSubmit. */
+private[spark] object SparkSubmitUtils {
+
+ // Exposed for testing
+ private[spark] var printStream = SparkSubmit.printStream
+
+ /**
+ * Represents a Maven Coordinate
+ * @param groupId the groupId of the coordinate
+ * @param artifactId the artifactId of the coordinate
+ * @param version the version of the coordinate
+ */
+ private[spark] case class MavenCoordinate(groupId: String, artifactId: String, version: String)
+
+ /**
+ * Extracts maven coordinates from a comma-delimited string
+ * @param coordinates Comma-delimited string of maven coordinates
+ * @return Sequence of Maven coordinates
+ */
+ private[spark] def extractMavenCoordinates(coordinates: String): Seq[MavenCoordinate] = {
+ coordinates.split(",").map { p =>
+ val splits = p.split(":")
+ require(splits.length == 3, s"Provided Maven Coordinates must be in the form " +
+ s"'groupId:artifactId:version'. The coordinate provided is: $p")
+ require(splits(0) != null && splits(0).trim.nonEmpty, s"The groupId cannot be null or " +
+ s"be whitespace. The groupId provided is: ${splits(0)}")
+ require(splits(1) != null && splits(1).trim.nonEmpty, s"The artifactId cannot be null or " +
+ s"be whitespace. The artifactId provided is: ${splits(1)}")
+ require(splits(2) != null && splits(2).trim.nonEmpty, s"The version cannot be null or " +
+ s"be whitespace. The version provided is: ${splits(2)}")
+ new MavenCoordinate(splits(0), splits(1), splits(2))
+ }
+ }
+
+ /**
+ * Creates a ChainResolver from an optional comma-delimited string of remote repositories.
+ * @param remoteRepos Comma-delimited string of remote repositories
+ * @return A ChainResolver used by Ivy to search for and resolve dependencies.
+ */
+ private[spark] def createRepoResolvers(remoteRepos: Option[String]): ChainResolver = {
+ // We need a chain resolver if we want to check multiple repositories
+ val cr = new ChainResolver
+ cr.setName("list")
+
+ // the biblio resolver resolves POM declared dependencies
+ val br: IBiblioResolver = new IBiblioResolver
+ br.setM2compatible(true)
+ br.setUsepoms(true)
+ br.setName("central")
+ cr.add(br)
+
+ val repositoryList = remoteRepos.getOrElse("")
+ // add any other remote repositories other than maven central
+ if (repositoryList.trim.nonEmpty) {
+ repositoryList.split(",").zipWithIndex.foreach { case (repo, i) =>
+ val brr: IBiblioResolver = new IBiblioResolver
+ brr.setM2compatible(true)
+ brr.setUsepoms(true)
+ brr.setRoot(repo)
+ brr.setName(s"repo-${i + 1}")
+ cr.add(brr)
+ printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}")
+ }
+ }
+ cr
+ }
+
+ /**
+ * Output a comma-delimited list of paths for the downloaded jars to be added to the classpath
+ * (will append to jars in SparkSubmit). The name of the jar is given
+ * after a '!' by Ivy. It also sometimes contains '(bundle)' after '.jar'. Remove that as well.
+ * @param artifacts Sequence of dependencies that were resolved and retrieved
+ * @param cacheDirectory directory where jars are cached
+ * @return a comma-delimited list of paths for the dependencies
+ */
+ private[spark] def resolveDependencyPaths(
+ artifacts: Array[AnyRef],
+ cacheDirectory: File): String = {
+ artifacts.map { artifactInfo =>
+ val artifactString = artifactInfo.toString
+ val jarName = artifactString.drop(artifactString.lastIndexOf("!") + 1)
+ cacheDirectory.getAbsolutePath + File.separator +
+ jarName.substring(0, jarName.lastIndexOf(".jar") + 4)
+ }.mkString(",")
+ }
+
+ /** Adds the given maven coordinates to Ivy's module descriptor. */
+ private[spark] def addDependenciesToIvy(
+ md: DefaultModuleDescriptor,
+ artifacts: Seq[MavenCoordinate],
+ ivyConfName: String): Unit = {
+ artifacts.foreach { mvn =>
+ val ri = ModuleRevisionId.newInstance(mvn.groupId, mvn.artifactId, mvn.version)
+ val dd = new DefaultDependencyDescriptor(ri, false, false)
+ dd.addDependencyConfiguration(ivyConfName, ivyConfName)
+ printStream.println(s"${dd.getDependencyId} added as a dependency")
+ md.addDependency(dd)
+ }
+ }
+
+ /** Creates a default module descriptor (also used in tests). Values are dummy strings. */
+ private[spark] def getModuleDescriptor = DefaultModuleDescriptor.newDefaultInstance(
+ ModuleRevisionId.newInstance("org.apache.spark", "spark-submit-parent", "1.0"))
+
+ /**
+ * Resolves any dependencies that were supplied through maven coordinates
+ * @param coordinates Comma-delimited string of maven coordinates
+ * @param remoteRepos Comma-delimited string of remote repositories other than maven central
+ * @param ivyPath The path to the local ivy repository
+ * @return The comma-delimited path to the jars of the given maven artifacts including their
+ * transitive dependencies
+ */
+ private[spark] def resolveMavenCoordinates(
+ coordinates: String,
+ remoteRepos: Option[String],
+ ivyPath: Option[String],
+ isTest: Boolean = false): String = {
+ if (coordinates == null || coordinates.trim.isEmpty) {
+ ""
+ } else {
+ val artifacts = extractMavenCoordinates(coordinates)
+ // Default configuration name for ivy
+ val ivyConfName = "default"
+ // set ivy settings for location of cache
+ val ivySettings: IvySettings = new IvySettings
+ // Directories for caching downloads through ivy and storing the jars when maven coordinates
+ // are supplied to spark-submit
+ val alternateIvyCache = ivyPath.getOrElse("")
+ val packagesDirectory: File =
+ if (alternateIvyCache.trim.isEmpty) {
+ new File(ivySettings.getDefaultIvyUserDir, "jars")
+ } else {
+ ivySettings.setDefaultCache(new File(alternateIvyCache, "cache"))
+ new File(alternateIvyCache, "jars")
+ }
+ printStream.println(
+ s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}")
+ printStream.println(s"The jars for the packages stored in: $packagesDirectory")
+ // create a pattern matcher
+ ivySettings.addMatcher(new GlobPatternMatcher)
+ // create the dependency resolvers
+ val repoResolver = createRepoResolvers(remoteRepos)
+ ivySettings.addResolver(repoResolver)
+ ivySettings.setDefaultResolver(repoResolver.getName)
+
+ val ivy = Ivy.newInstance(ivySettings)
+ // Set resolve options to download transitive dependencies as well
+ val resolveOptions = new ResolveOptions
+ resolveOptions.setTransitive(true)
+ val retrieveOptions = new RetrieveOptions
+ // Turn downloading and logging off for testing
+ if (isTest) {
+ resolveOptions.setDownload(false)
+ resolveOptions.setLog(LogOptions.LOG_QUIET)
+ retrieveOptions.setLog(LogOptions.LOG_QUIET)
+ } else {
+ resolveOptions.setDownload(true)
+ }
+
+ // A Module descriptor must be specified. Entries are dummy strings
+ val md = getModuleDescriptor
+ md.setDefaultConf(ivyConfName)
+
+ // Add an exclusion rule for Spark
+ val sparkArtifacts = new ArtifactId(new ModuleId("org.apache.spark", "*"), "*", "*", "*")
+ val sparkDependencyExcludeRule =
+ new DefaultExcludeRule(sparkArtifacts, ivySettings.getMatcher("glob"), null)
+ sparkDependencyExcludeRule.addConfiguration(ivyConfName)
+
+ // Exclude any Spark dependencies, and add all supplied maven artifacts as dependencies
+ md.addExcludeRule(sparkDependencyExcludeRule)
+ addDependenciesToIvy(md, artifacts, ivyConfName)
+
+ // resolve dependencies
+ val rr: ResolveReport = ivy.resolve(md, resolveOptions)
+ if (rr.hasError) {
+ throw new RuntimeException(rr.getAllProblemMessages.toString)
+ }
+ // retrieve all resolved dependencies
+ ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId,
+ packagesDirectory.getAbsolutePath + File.separator + "[artifact](-[classifier]).[ext]",
+ retrieveOptions.setConfs(Array(ivyConfName)))
+
+ resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory)
+ }
+ }
+}
+
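A rough usage sketch of the new --packages plumbing (the coordinates below are made up):

// Parse "groupId:artifactId:version" triples; malformed entries fail fast with require().
val coords = SparkSubmitUtils.extractMavenCoordinates(
  "org.example:example-lib:1.0.0,org.example:other-lib:2.1.0")

// Resolving through Ivy (Maven Central plus any extra repositories) downloads the jars
// and returns a comma-delimited classpath including transitive dependencies:
// val jarPaths = SparkSubmitUtils.resolveMavenCoordinates(
//   "org.example:example-lib:1.0.0", remoteRepos = None, ivyPath = None)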
/**
* Provides an indirection layer for passing arguments as system properties or flags to
* the user's driver program or to downstream launcher tools.
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 72a452e0aefb5..5cadc534f4baa 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -17,9 +17,9 @@
package org.apache.spark.deploy
+import java.net.URI
import java.util.jar.JarFile
-import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
import org.apache.spark.util.Utils
@@ -50,6 +50,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
var name: String = null
var childArgs: ArrayBuffer[String] = new ArrayBuffer[String]()
var jars: String = null
+ var packages: String = null
+ var repositories: String = null
+ var ivyRepoPath: String = null
var verbose: Boolean = false
var isPython: Boolean = false
var pyFiles: String = null
@@ -72,57 +75,92 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
defaultProperties
}
- // Respect SPARK_*_MEMORY for cluster mode
- driverMemory = sys.env.get("SPARK_DRIVER_MEMORY").orNull
- executorMemory = sys.env.get("SPARK_EXECUTOR_MEMORY").orNull
-
+ // Set parameters from command line arguments
parseOpts(args.toList)
- mergeSparkProperties()
+ // Populate `sparkProperties` map from properties file
+ mergeDefaultSparkProperties()
+ // Use `sparkProperties` map along with env vars to fill in any missing parameters
+ loadEnvironmentArguments()
+
checkRequiredArguments()
/**
- * Fill in any undefined values based on the default properties file or options passed in through
- * the '--conf' flag.
+ * Merge values from the default properties file with those specified through --conf.
+ * When this is called, `sparkProperties` is already filled with configs from the latter.
*/
- private def mergeSparkProperties(): Unit = {
+ private def mergeDefaultSparkProperties(): Unit = {
// Use common defaults file, if not specified by user
propertiesFile = Option(propertiesFile).getOrElse(Utils.getDefaultPropertiesFile(env))
+ // Honor --conf before the defaults file
+ defaultSparkProperties.foreach { case (k, v) =>
+ if (!sparkProperties.contains(k)) {
+ sparkProperties(k) = v
+ }
+ }
+ }
- val properties = HashMap[String, String]()
- properties.putAll(defaultSparkProperties)
- properties.putAll(sparkProperties)
-
- // Use properties file as fallback for values which have a direct analog to
- // arguments in this script.
- master = Option(master).orElse(properties.get("spark.master")).orNull
- executorMemory = Option(executorMemory).orElse(properties.get("spark.executor.memory")).orNull
- executorCores = Option(executorCores).orElse(properties.get("spark.executor.cores")).orNull
+ /**
+ * Load arguments from environment variables, Spark properties etc.
+ */
+ private def loadEnvironmentArguments(): Unit = {
+ master = Option(master)
+ .orElse(sparkProperties.get("spark.master"))
+ .orElse(env.get("MASTER"))
+ .orNull
+ driverMemory = Option(driverMemory)
+ .orElse(sparkProperties.get("spark.driver.memory"))
+ .orElse(env.get("SPARK_DRIVER_MEMORY"))
+ .orNull
+ driverCores = Option(driverCores)
+ .orElse(sparkProperties.get("spark.driver.cores"))
+ .orNull
+ executorMemory = Option(executorMemory)
+ .orElse(sparkProperties.get("spark.executor.memory"))
+ .orElse(env.get("SPARK_EXECUTOR_MEMORY"))
+ .orNull
+ executorCores = Option(executorCores)
+ .orElse(sparkProperties.get("spark.executor.cores"))
+ .orNull
totalExecutorCores = Option(totalExecutorCores)
- .orElse(properties.get("spark.cores.max"))
+ .orElse(sparkProperties.get("spark.cores.max"))
.orNull
- name = Option(name).orElse(properties.get("spark.app.name")).orNull
- jars = Option(jars).orElse(properties.get("spark.jars")).orNull
-
- // This supports env vars in older versions of Spark
- master = Option(master).orElse(env.get("MASTER")).orNull
+ name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull
+ jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull
+ ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull
deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull
+ numExecutors = Option(numExecutors)
+ .getOrElse(sparkProperties.get("spark.executor.instances").orNull)
// Try to set main class from JAR if no --class argument is given
if (mainClass == null && !isPython && primaryResource != null) {
- try {
- val jar = new JarFile(primaryResource)
- // Note that this might still return null if no main-class is set; we catch that later
- mainClass = jar.getManifest.getMainAttributes.getValue("Main-Class")
- } catch {
- case e: Exception =>
- SparkSubmit.printErrorAndExit("Cannot load main class from JAR: " + primaryResource)
- return
+ val uri = new URI(primaryResource)
+ val uriScheme = uri.getScheme()
+
+ uriScheme match {
+ case "file" =>
+ try {
+ val jar = new JarFile(uri.getPath)
+ // Note that this might still return null if no main-class is set; we catch that later
+ mainClass = jar.getManifest.getMainAttributes.getValue("Main-Class")
+ } catch {
+ case e: Exception =>
+ SparkSubmit.printErrorAndExit(s"Cannot load main class from JAR $primaryResource")
+ }
+ case _ =>
+ SparkSubmit.printErrorAndExit(
+ s"Cannot load main class from JAR $primaryResource with URI $uriScheme. " +
+ "Please specify a class through --class.")
}
}
// Global defaults. These should be kept to a minimum to avoid confusing behavior.
master = Option(master).getOrElse("local[*]")
+ // In YARN mode, app name can be set via SPARK_YARN_APP_NAME (see SPARK-5222)
+ if (master.startsWith("yarn")) {
+ name = Option(name).orElse(env.get("SPARK_YARN_APP_NAME")).orNull
+ }
+
// Set name from main class if not given
name = Option(name).orElse(Option(mainClass)).orNull
if (name == null && primaryResource != null) {
@@ -131,7 +169,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
}
/** Ensure that required fields exist. Call this only once all defaults are loaded. */
- private def checkRequiredArguments() = {
+ private def checkRequiredArguments(): Unit = {
if (args.length == 0) {
printUsageAndExit(-1)
}
@@ -145,18 +183,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
SparkSubmit.printErrorAndExit("--py-files given but primary resource is not a Python script")
}
- // Require all python files to be local, so we can add them to the PYTHONPATH
- if (isPython) {
- if (Utils.nonLocalPaths(primaryResource).nonEmpty) {
- SparkSubmit.printErrorAndExit(s"Only local python files are supported: $primaryResource")
- }
- val nonLocalPyFiles = Utils.nonLocalPaths(pyFiles).mkString(",")
- if (nonLocalPyFiles.nonEmpty) {
- SparkSubmit.printErrorAndExit(
- s"Only local additional python files are supported: $nonLocalPyFiles")
- }
- }
-
if (master.startsWith("yarn")) {
val hasHadoopEnv = env.contains("HADOOP_CONF_DIR") || env.contains("YARN_CONF_DIR")
if (!hasHadoopEnv && !Utils.isTesting) {
@@ -166,7 +192,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
}
}
- override def toString = {
+ override def toString = {
s"""Parsed arguments:
| master $master
| deployMode $deployMode
@@ -174,7 +200,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
| executorCores $executorCores
| totalExecutorCores $totalExecutorCores
| propertiesFile $propertiesFile
- | extraSparkProperties $sparkProperties
| driverMemory $driverMemory
| driverCores $driverCores
| driverExtraClassPath $driverExtraClassPath
@@ -191,14 +216,20 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
| name $name
| childArgs [${childArgs.mkString(" ")}]
| jars $jars
+ | packages $packages
+ | repositories $repositories
| verbose $verbose
|
- |Default properties from $propertiesFile:
- |${defaultSparkProperties.mkString(" ", "\n ", "\n")}
+ |Spark properties used, including those specified through
+ | --conf and those from the properties file $propertiesFile:
+ |${sparkProperties.mkString(" ", "\n ", "\n")}
""".stripMargin
}
- /** Fill in values by parsing user options. */
+ /**
+ * Fill in values by parsing user options.
+ * NOTE: Any changes here must be reflected in YarnClientSchedulerBackend.
+ */
private def parseOpts(opts: Seq[String]): Unit = {
val EQ_SEPARATED_OPT="""(--[^=]+)=(.+)""".r
@@ -293,6 +324,14 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
jars = Utils.resolveURIs(value)
parse(tail)
+ case ("--packages") :: value :: tail =>
+ packages = value
+ parse(tail)
+
+ case ("--repositories") :: value :: tail =>
+ repositories = value
+ parse(tail)
+
case ("--conf" | "-c") :: value :: tail =>
value.split("=", 2).toSeq match {
case Seq(k, v) => sparkProperties(k) = v
@@ -327,7 +366,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
}
}
- private def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
+ private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = {
val outStream = SparkSubmit.printStream
if (unknownParam != null) {
outStream.println("Unknown/unsupported param " + unknownParam)
@@ -343,6 +382,13 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
| --name NAME A name of your application.
| --jars JARS Comma-separated list of local jars to include on the driver
| and executor classpaths.
+ | --packages Comma-separated list of maven coordinates of jars to include
+ | on the driver and executor classpaths. Will search the local
+ | maven repo, then maven central and any additional remote
+ | repositories given by --repositories. The format for the
+ | coordinates should be groupId:artifactId:version.
+ | --repositories Comma-separated list of additional remote repositories to
+ | search for the maven coordinates given with --packages.
| --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to place
| on the PYTHONPATH for Python apps.
| --files FILES Comma-separated list of files to be placed in the working
@@ -372,11 +418,14 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
| --total-executor-cores NUM Total cores for all executors.
|
| YARN-only:
+ | --driver-cores NUM Number of cores used by the driver, only in cluster mode
+ | (Default: 1).
| --executor-cores NUM Number of cores per executor (Default: 1).
| --queue QUEUE_NAME The YARN queue to submit to (Default: "default").
| --num-executors NUM Number of executors to launch (Default: 2).
| --archives ARCHIVES Comma separated list of archives to be extracted into the
- | working directory of each executor.""".stripMargin
+ | working directory of each executor.
+ """.stripMargin
)
SparkSubmit.exitFn()
}
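
The loadEnvironmentArguments changes above establish a consistent precedence: an explicit command-line flag wins over a Spark property (set through --conf or the properties file), which in turn wins over a legacy environment variable. A small, self-contained sketch of that Option.orElse chaining, using spark.driver.memory as the example; the "512m" fallback is made up for illustration and is not a default this patch sets.

```scala
// Illustrative only: the Option.orElse chaining used above, isolated so the
// precedence (explicit flag > Spark property > env var > default) is visible.
object PrecedenceExample {
  def resolveDriverMemory(
      fromFlag: String,
      sparkProperties: Map[String, String],
      env: Map[String, String]): String = {
    Option(fromFlag)
      .orElse(sparkProperties.get("spark.driver.memory"))
      .orElse(env.get("SPARK_DRIVER_MEMORY"))
      .getOrElse("512m") // hypothetical fallback, not part of this patch
  }

  def main(args: Array[String]): Unit = {
    println(resolveDriverMemory(null, Map.empty, Map("SPARK_DRIVER_MEMORY" -> "2g"))) // 2g
    println(resolveDriverMemory("4g", Map("spark.driver.memory" -> "1g"), Map.empty)) // 4g
  }
}
```
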
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
index 0125330589da5..2eab9981845e8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
@@ -82,17 +82,8 @@ private[spark] object SparkSubmitDriverBootstrapper {
.orElse(confDriverMemory)
.getOrElse(defaultDriverMemory)
- val newLibraryPath =
- if (submitLibraryPath.isDefined) {
- // SPARK_SUBMIT_LIBRARY_PATH is already captured in JAVA_OPTS
- ""
- } else {
- confLibraryPath.map("-Djava.library.path=" + _).getOrElse("")
- }
-
val newClasspath =
if (submitClasspath.isDefined) {
- // SPARK_SUBMIT_CLASSPATH is already captured in CLASSPATH
classpath
} else {
classpath + confClasspath.map(sys.props("path.separator") + _).getOrElse("")
@@ -114,7 +105,6 @@ private[spark] object SparkSubmitDriverBootstrapper {
val command: Seq[String] =
Seq(runner) ++
Seq("-cp", newClasspath) ++
- Seq(newLibraryPath) ++
filteredJavaOpts ++
Seq(s"-Xms$newDriverMemory", s"-Xmx$newDriverMemory") ++
Seq("org.apache.spark.deploy.SparkSubmit") ++
@@ -130,8 +120,25 @@ private[spark] object SparkSubmitDriverBootstrapper {
// Start the driver JVM
val filteredCommand = command.filter(_.nonEmpty)
val builder = new ProcessBuilder(filteredCommand)
+ val env = builder.environment()
+
+ if (submitLibraryPath.isEmpty && confLibraryPath.nonEmpty) {
+ val libraryPaths = confLibraryPath ++ sys.env.get(Utils.libraryPathEnvName)
+ env.put(Utils.libraryPathEnvName, libraryPaths.mkString(sys.props("path.separator")))
+ }
+
val process = builder.start()
+ // If we kill an app while it's running, its sub-process should be killed too.
+ Runtime.getRuntime().addShutdownHook(new Thread() {
+ override def run() = {
+ if (process != null) {
+ process.destroy()
+ process.waitFor()
+ }
+ }
+ })
+
// Redirect stdout and stderr from the child JVM
val stdoutThread = new RedirectThread(process.getInputStream, System.out, "redirect stdout")
val stderrThread = new RedirectThread(process.getErrorStream, System.err, "redirect stderr")
@@ -142,14 +149,16 @@ private[spark] object SparkSubmitDriverBootstrapper {
// subprocess there already reads directly from our stdin, so we should avoid spawning a
// thread that contends with the subprocess in reading from System.in.
val isWindows = Utils.isWindows
- val isPySparkShell = sys.env.contains("PYSPARK_SHELL")
+ val isSubprocess = sys.env.contains("IS_SUBPROCESS")
if (!isWindows) {
- val stdinThread = new RedirectThread(System.in, process.getOutputStream, "redirect stdin")
+ val stdinThread = new RedirectThread(System.in, process.getOutputStream, "redirect stdin",
+ propagateEof = true)
stdinThread.start()
- // For the PySpark shell, Spark submit itself runs as a python subprocess, and so this JVM
- // should terminate on broken pipe, which signals that the parent process has exited. In
- // Windows, the termination logic for the PySpark shell is handled in java_gateway.py
- if (isPySparkShell) {
+ // Spark submit (JVM) may run as a subprocess, and so this JVM should terminate on
+ // broken pipe, signaling that the parent process has exited. This is the case if the
+ // application is launched directly from python, as in the PySpark shell. In Windows,
+ // the termination logic is handled in java_gateway.py
+ if (isSubprocess) {
stdinThread.join()
process.destroy()
}
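
Two behaviours introduced above are easy to miss: the library path is now passed to the child through the process environment instead of a -Djava.library.path flag, and a JVM shutdown hook destroys the child driver process so that killing the bootstrapper does not leave it orphaned. A standalone sketch of the shutdown-hook pattern follows; the "sleep 30" child command is only a placeholder.

```scala
// Standalone sketch of the shutdown-hook pattern: if this JVM exits while the
// child is still running, destroy the child and wait for it to die.
object ShutdownHookSketch {
  def main(args: Array[String]): Unit = {
    val process = new ProcessBuilder("sleep", "30").start() // placeholder child command

    Runtime.getRuntime.addShutdownHook(new Thread() {
      override def run(): Unit = {
        process.destroy()
        process.waitFor()
      }
    })

    // Killing this JVM (e.g. with Ctrl-C) now also terminates the sleeping child.
    process.waitFor()
  }
}
```
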
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
index 98a93d1fcb2a3..ffe940fbda2fb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
@@ -26,7 +26,7 @@ import akka.actor._
import akka.pattern.ask
import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent}
-import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.Master
@@ -47,6 +47,8 @@ private[spark] class AppClient(
conf: SparkConf)
extends Logging {
+ val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem)))
+
val REGISTRATION_TIMEOUT = 20.seconds
val REGISTRATION_RETRIES = 3
@@ -75,9 +77,9 @@ private[spark] class AppClient(
}
def tryRegisterAllMasters() {
- for (masterUrl <- masterUrls) {
- logInfo("Connecting to master " + masterUrl + "...")
- val actor = context.actorSelection(Master.toAkkaUrl(masterUrl))
+ for (masterAkkaUrl <- masterAkkaUrls) {
+ logInfo("Connecting to master " + masterAkkaUrl + "...")
+ val actor = context.actorSelection(masterAkkaUrl)
actor ! RegisterApplication(appDescription)
}
}
@@ -103,20 +105,15 @@ private[spark] class AppClient(
}
def changeMaster(url: String) {
+ // activeMasterUrl is a valid Spark url since we receive it from master.
activeMasterUrl = url
- master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl))
- masterAddress = activeMasterUrl match {
- case Master.sparkUrlRegex(host, port) =>
- Address("akka.tcp", Master.systemName, host, port.toInt)
- case x =>
- throw new SparkException("Invalid spark URL: " + x)
- }
+ master = context.actorSelection(
+ Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(actorSystem)))
+ masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(actorSystem))
}
private def isPossibleMaster(remoteUrl: Address) = {
- masterUrls.map(s => Master.toAkkaUrl(s))
- .map(u => AddressFromURIString(u).hostPort)
- .contains(remoteUrl.hostPort)
+ masterAkkaUrls.map(AddressFromURIString(_).hostPort).contains(remoteUrl.hostPort)
}
override def receiveWithLogging = {
@@ -134,6 +131,7 @@ private[spark] class AppClient(
val fullId = appId + "/" + id
logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort,
cores))
+ master ! ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None)
listener.executorAdded(fullId, workerId, hostPort, cores, memory)
case ExecutorUpdated(id, state, message, exitStatus) =>
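
The AppClient change above precomputes masterAkkaUrls once and stops re-parsing spark:// URLs with a regex inside changeMaster. As a rough illustration of the spark:// to Akka actor-path translation involved, here is a sketch; the exact format ("akka.tcp://sparkMaster@host:port/user/Master") is an assumption based on names appearing elsewhere in this patch.

```scala
// Rough sketch of the spark:// -> Akka actor path translation; treat the
// resulting URL format as an assumption, not the authoritative implementation.
object AkkaUrlSketch {
  private val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r

  def toAkkaUrl(sparkUrl: String, protocol: String = "akka.tcp"): String =
    sparkUrl match {
      case sparkUrlRegex(host, port) => s"$protocol://sparkMaster@$host:$port/user/Master"
      case other => throw new IllegalArgumentException(s"Invalid master URL: $other")
    }

  def main(args: Array[String]): Unit = {
    println(toAkkaUrl("spark://localhost:7077"))
    // akka.tcp://sparkMaster@localhost:7077/user/Master
  }
}
```
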
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
index fbe39b27649f6..553bf3cb945ab 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
@@ -25,7 +25,8 @@ private[spark] case class ApplicationHistoryInfo(
startTime: Long,
endTime: Long,
lastUpdated: Long,
- sparkUser: String)
+ sparkUser: String,
+ completed: Boolean = false)
private[spark] abstract class ApplicationHistoryProvider {
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index 481f6c93c6a8d..0ae45f4ad9130 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -17,34 +17,41 @@
package org.apache.spark.deploy.history
-import java.io.FileNotFoundException
+import java.io.{BufferedInputStream, FileNotFoundException, InputStream}
import scala.collection.mutable
import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.fs.permission.AccessControlException
import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.io.CompressionCodec
import org.apache.spark.scheduler._
import org.apache.spark.ui.SparkUI
import org.apache.spark.util.Utils
+/**
+ * A class that provides application history from event logs stored in the file system.
+ * This provider checks for new finished applications in the background periodically and
+ * renders the history application UI by parsing the associated event logs.
+ */
private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHistoryProvider
with Logging {
+ import FsHistoryProvider._
+
private val NOT_STARTED = ""
// Interval between each check for event log updates
private val UPDATE_INTERVAL_MS = conf.getInt("spark.history.fs.updateInterval",
conf.getInt("spark.history.updateInterval", 10)) * 1000
- private val logDir = conf.get("spark.history.fs.logDirectory", null)
- private val resolvedLogDir = Option(logDir)
- .map { d => Utils.resolveURI(d) }
- .getOrElse { throw new IllegalArgumentException("Logging directory must be specified.") }
+ private val logDir = conf.getOption("spark.history.fs.logDirectory")
+ .map { d => Utils.resolveURI(d).toString }
+ .getOrElse(DEFAULT_LOG_DIR)
- private val fs = Utils.getHadoopFileSystem(resolvedLogDir,
- SparkHadoopUtil.get.newConfiguration(conf))
+ private val fs = Utils.getHadoopFileSystem(logDir, SparkHadoopUtil.get.newConfiguration(conf))
// A timestamp of when the disk was last accessed to check for log updates
private var lastLogCheckTimeMs = -1L
@@ -59,6 +66,12 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
@volatile private var applications: mutable.LinkedHashMap[String, FsApplicationHistoryInfo]
= new mutable.LinkedHashMap()
+ // Constants used to parse Spark 1.0.0 log directories.
+ private[history] val LOG_PREFIX = "EVENT_LOG_"
+ private[history] val SPARK_VERSION_PREFIX = "SPARK_VERSION_"
+ private[history] val COMPRESSION_CODEC_PREFIX = "COMPRESSION_CODEC_"
+ private[history] val APPLICATION_COMPLETE = "APPLICATION_COMPLETE"
+
/**
* A background thread that periodically checks for event log updates on disk.
*
@@ -85,21 +98,28 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
initialize()
- private def initialize() {
+ private def initialize(): Unit = {
// Validate the log directory.
- val path = new Path(resolvedLogDir)
+ val path = new Path(logDir)
if (!fs.exists(path)) {
- throw new IllegalArgumentException(
- "Logging directory specified does not exist: %s".format(resolvedLogDir))
+ var msg = s"Log directory specified does not exist: $logDir."
+ if (logDir == DEFAULT_LOG_DIR) {
+ msg += " Did you configure the correct one through spark.fs.history.logDirectory?"
+ }
+ throw new IllegalArgumentException(msg)
}
if (!fs.getFileStatus(path).isDir) {
throw new IllegalArgumentException(
- "Logging directory specified is not a directory: %s".format(resolvedLogDir))
+ "Logging directory specified is not a directory: %s".format(logDir))
}
checkForLogs()
- logCheckingThread.setDaemon(true)
- logCheckingThread.start()
+
+ // Disable the background thread during tests.
+ if (!conf.contains("spark.testing")) {
+ logCheckingThread.setDaemon(true)
+ logCheckingThread.start()
+ }
}
override def getListing() = applications.values
@@ -107,25 +127,26 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
override def getAppUI(appId: String): Option[SparkUI] = {
try {
applications.get(appId).map { info =>
- val (replayBus, appListener) = createReplayBus(fs.getFileStatus(
- new Path(logDir, info.logDir)))
+ val replayBus = new ReplayListenerBus()
val ui = {
val conf = this.conf.clone()
val appSecManager = new SecurityManager(conf)
- new SparkUI(conf, appSecManager, replayBus, appId,
+ SparkUI.createHistoryUI(conf, replayBus, appSecManager, appId,
s"${HistoryServer.UI_PATH_PREFIX}/$appId")
// Do not call ui.bind() to avoid creating a new server for each application
}
- replayBus.replay()
+ val appListener = new ApplicationEventListener()
+ replayBus.addListener(appListener)
+ val appInfo = replay(fs.getFileStatus(new Path(logDir, info.logPath)), replayBus)
- ui.setAppName(s"${appListener.appName.getOrElse(NOT_STARTED)} ($appId)")
+ ui.setAppName(s"${appInfo.name} ($appId)")
val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false)
ui.getSecurityManager.setAcls(uiAclsEnabled)
// make sure to set admin acls before view acls so they are properly picked up
ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse(""))
- ui.getSecurityManager.setViewAcls(appListener.sparkUser.getOrElse(NOT_STARTED),
+ ui.getSecurityManager.setViewAcls(appInfo.sparkUser,
appListener.viewAcls.getOrElse(""))
ui
}
@@ -134,53 +155,45 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
}
}
- override def getConfig(): Map[String, String] =
- Map("Event Log Location" -> resolvedLogDir.toString)
+ override def getConfig(): Map[String, String] = Map("Event log directory" -> logDir.toString)
/**
* Builds the application list based on the current contents of the log directory.
* Tries to reuse as much of the data already in memory as possible, by not reading
* applications that haven't been updated since last time the logs were checked.
*/
- private def checkForLogs() = {
+ private[history] def checkForLogs(): Unit = {
lastLogCheckTimeMs = getMonotonicTimeMs()
logDebug("Checking for logs. Time is now %d.".format(lastLogCheckTimeMs))
- try {
- val logStatus = fs.listStatus(new Path(resolvedLogDir))
- val logDirs = if (logStatus != null) logStatus.filter(_.isDir).toSeq else Seq[FileStatus]()
- // Load all new logs from the log directory. Only directories that have a modification time
- // later than the last known log directory will be loaded.
+ try {
var newLastModifiedTime = lastModifiedTime
- val logInfos = logDirs
- .filter { dir =>
- if (fs.isFile(new Path(dir.getPath(), EventLoggingListener.APPLICATION_COMPLETE))) {
- val modTime = getModificationTime(dir)
+ val statusList = Option(fs.listStatus(new Path(logDir))).map(_.toSeq)
+ .getOrElse(Seq[FileStatus]())
+ val logInfos = statusList
+ .filter { entry =>
+ try {
+ val modTime = getModificationTime(entry)
newLastModifiedTime = math.max(newLastModifiedTime, modTime)
- modTime > lastModifiedTime
- } else {
- false
+ modTime >= lastModifiedTime
+ } catch {
+ case e: AccessControlException =>
+ // Do not use "logInfo" since these messages can get pretty noisy if printed on
+ // every poll.
+ logDebug(s"No permission to read $entry, ignoring.")
+ false
}
}
- .flatMap { dir =>
+ .flatMap { entry =>
try {
- val (replayBus, appListener) = createReplayBus(dir)
- replayBus.replay()
- Some(new FsApplicationHistoryInfo(
- dir.getPath().getName(),
- appListener.appId.getOrElse(dir.getPath().getName()),
- appListener.appName.getOrElse(NOT_STARTED),
- appListener.startTime.getOrElse(-1L),
- appListener.endTime.getOrElse(-1L),
- getModificationTime(dir),
- appListener.sparkUser.getOrElse(NOT_STARTED)))
+ Some(replay(entry, new ReplayListenerBus()))
} catch {
case e: Exception =>
- logInfo(s"Failed to load application log data from $dir.", e)
+ logError(s"Failed to load application log data from $entry.", e)
None
}
}
- .sortBy { info => -info.endTime }
+ .sortBy { info => (-info.endTime, -info.startTime) }
lastModifiedTime = newLastModifiedTime
@@ -190,7 +203,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
if (!logInfos.isEmpty) {
val newApps = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]()
def addIfAbsent(info: FsApplicationHistoryInfo) = {
- if (!newApps.contains(info.id)) {
+ if (!newApps.contains(info.id) ||
+ newApps(info.id).logPath.endsWith(EventLoggingListener.IN_PROGRESS) &&
+ !info.logPath.endsWith(EventLoggingListener.IN_PROGRESS)) {
newApps += (info.id -> info)
}
}
@@ -210,46 +225,126 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
applications = newApps
}
} catch {
- case t: Throwable => logError("Exception in checking for event log updates", t)
+ case e: Exception => logError("Exception in checking for event log updates", e)
}
}
- private def createReplayBus(logDir: FileStatus): (ReplayListenerBus, ApplicationEventListener) = {
- val path = logDir.getPath()
- val elogInfo = EventLoggingListener.parseLoggingInfo(path, fs)
- val replayBus = new ReplayListenerBus(elogInfo.logPaths, fs, elogInfo.compressionCodec)
- val appListener = new ApplicationEventListener
- replayBus.addListener(appListener)
- (replayBus, appListener)
+ /**
+ * Replays the events in the specified log file and returns information about the associated
+ * application.
+ */
+ private def replay(eventLog: FileStatus, bus: ReplayListenerBus): FsApplicationHistoryInfo = {
+ val logPath = eventLog.getPath()
+ val (logInput, sparkVersion) =
+ if (isLegacyLogDirectory(eventLog)) {
+ openLegacyEventLog(logPath)
+ } else {
+ EventLoggingListener.openEventLog(logPath, fs)
+ }
+ try {
+ val appListener = new ApplicationEventListener
+ bus.addListener(appListener)
+ bus.replay(logInput, sparkVersion)
+ new FsApplicationHistoryInfo(
+ logPath.getName(),
+ appListener.appId.getOrElse(logPath.getName()),
+ appListener.appName.getOrElse(NOT_STARTED),
+ appListener.startTime.getOrElse(-1L),
+ appListener.endTime.getOrElse(-1L),
+ getModificationTime(eventLog),
+ appListener.sparkUser.getOrElse(NOT_STARTED),
+ isApplicationCompleted(eventLog))
+ } finally {
+ logInput.close()
+ }
}
- /** Return when this directory was last modified. */
- private def getModificationTime(dir: FileStatus): Long = {
- try {
- val logFiles = fs.listStatus(dir.getPath)
- if (logFiles != null && !logFiles.isEmpty) {
- logFiles.map(_.getModificationTime).max
- } else {
- dir.getModificationTime
+ /**
+ * Loads a legacy log directory. This assumes that the log directory contains a single event
+ * log file (along with other metadata files), which is the case for directories generated by
+ * the code in previous releases.
+ *
+ * @return 2-tuple of (input stream of the events, version of Spark which wrote the log)
+ */
+ private[history] def openLegacyEventLog(dir: Path): (InputStream, String) = {
+ val children = fs.listStatus(dir)
+ var eventLogPath: Path = null
+ var codecName: Option[String] = None
+ var sparkVersion: String = null
+
+ children.foreach { child =>
+ child.getPath().getName() match {
+ case name if name.startsWith(LOG_PREFIX) =>
+ eventLogPath = child.getPath()
+
+ case codec if codec.startsWith(COMPRESSION_CODEC_PREFIX) =>
+ codecName = Some(codec.substring(COMPRESSION_CODEC_PREFIX.length()))
+
+ case version if version.startsWith(SPARK_VERSION_PREFIX) =>
+ sparkVersion = version.substring(SPARK_VERSION_PREFIX.length())
+
+ case _ =>
}
- } catch {
- case t: Throwable =>
- logError("Exception in accessing modification time of %s".format(dir.getPath), t)
- -1L
+ }
+
+ if (eventLogPath == null || sparkVersion == null) {
+ throw new IllegalArgumentException(s"$dir is not a Spark application log directory.")
+ }
+
+ val codec = try {
+ codecName.map { c => CompressionCodec.createCodec(conf, c) }
+ } catch {
+ case e: Exception =>
+ throw new IllegalArgumentException(s"Unknown compression codec $codecName.")
+ }
+
+ val in = new BufferedInputStream(fs.open(eventLogPath))
+ (codec.map(_.compressedInputStream(in)).getOrElse(in), sparkVersion)
+ }
+
+ /**
+ * Return whether the specified event log path contains an old directory-based event log.
+ * Previously, the event log of an application comprised multiple files in a directory.
+ * As of Spark 1.3, these files are consolidated into a single one that replaces the directory.
+ * See SPARK-2261 for more detail.
+ */
+ private def isLegacyLogDirectory(entry: FileStatus): Boolean = entry.isDir()
+
+ private def getModificationTime(fsEntry: FileStatus): Long = {
+ if (fsEntry.isDir) {
+ fs.listStatus(fsEntry.getPath).map(_.getModificationTime()).max
+ } else {
+ fsEntry.getModificationTime()
}
}
/** Returns the system's monotonically increasing time. */
- private def getMonotonicTimeMs() = System.nanoTime() / (1000 * 1000)
+ private def getMonotonicTimeMs(): Long = System.nanoTime() / (1000 * 1000)
+
+ /**
+ * Return true when the application has completed.
+ */
+ private def isApplicationCompleted(entry: FileStatus): Boolean = {
+ if (isLegacyLogDirectory(entry)) {
+ fs.exists(new Path(entry.getPath(), APPLICATION_COMPLETE))
+ } else {
+ !entry.getPath().getName().endsWith(EventLoggingListener.IN_PROGRESS)
+ }
+ }
+
+}
+private object FsHistoryProvider {
+ val DEFAULT_LOG_DIR = "file:/tmp/spark-events"
}
private class FsApplicationHistoryInfo(
- val logDir: String,
+ val logPath: String,
id: String,
name: String,
startTime: Long,
endTime: Long,
lastUpdated: Long,
- sparkUser: String)
- extends ApplicationHistoryInfo(id, name, startTime, endTime, lastUpdated, sparkUser)
+ sparkUser: String,
+ completed: Boolean = true)
+ extends ApplicationHistoryInfo(id, name, startTime, endTime, lastUpdated, sparkUser, completed)
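
For context on the "legacy" format handled by openLegacyEventLog above: a pre-1.3 application wrote a directory whose file names carry the metadata, using the LOG_PREFIX, SPARK_VERSION_PREFIX, COMPRESSION_CODEC_PREFIX and APPLICATION_COMPLETE markers defined in this patch. A hypothetical sketch of parsing such a directory listing is below; the example file names, including the codec value, are purely illustrative.

```scala
// Hypothetical sketch: a pre-1.3 event log directory encodes its metadata in
// file names. The example names used in main() are illustrative only.
object LegacyLogDirSketch {
  case class LegacyLogInfo(
      eventLogFile: Option[String],
      sparkVersion: Option[String],
      codecName: Option[String],
      complete: Boolean)

  def parse(fileNames: Seq[String]): LegacyLogInfo = LegacyLogInfo(
    eventLogFile = fileNames.find(_.startsWith("EVENT_LOG_")),
    sparkVersion = fileNames.find(_.startsWith("SPARK_VERSION_"))
      .map(_.stripPrefix("SPARK_VERSION_")),
    codecName = fileNames.find(_.startsWith("COMPRESSION_CODEC_"))
      .map(_.stripPrefix("COMPRESSION_CODEC_")),
    complete = fileNames.contains("APPLICATION_COMPLETE"))

  def main(args: Array[String]): Unit = {
    println(parse(Seq(
      "EVENT_LOG_1", "SPARK_VERSION_1.0.2", "COMPRESSION_CODEC_lzf", "APPLICATION_COMPLETE")))
  }
}
```
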
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
index d25c29113d6da..e4e7bc2216014 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
@@ -26,12 +26,15 @@ import org.apache.spark.ui.{WebUIPage, UIUtils}
private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
private val pageSize = 20
+ private val plusOrMinus = 2
def render(request: HttpServletRequest): Seq[Node] = {
val requestedPage = Option(request.getParameter("page")).getOrElse("1").toInt
val requestedFirst = (requestedPage - 1) * pageSize
+ val requestedIncomplete =
+ Option(request.getParameter("showIncomplete")).getOrElse("false").toBoolean
- val allApps = parent.getApplicationList()
+ val allApps = parent.getApplicationList().filter(_.completed != requestedIncomplete)
val actualFirst = if (requestedFirst < allApps.size) requestedFirst else 0
val apps = allApps.slice(actualFirst, Math.min(actualFirst + pageSize, allApps.size))
@@ -39,6 +42,9 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
val last = Math.min(actualFirst + pageSize, allApps.size) - 1
val pageCount = allApps.size / pageSize + (if (allApps.size % pageSize > 0) 1 else 0)
+ val secondPageFromLeft = 2
+ val secondPageFromRight = pageCount - 1
+
val appTable = UIUtils.listingTable(appHeader, appRow, apps)
val providerConfig = parent.getProviderConfig()
val content =
@@ -48,19 +54,60 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
{providerConfig.map { case (k, v) =>
{k}: {v}
}}
{
+ // This displays the indices of pages that are within `plusOrMinus` pages of
+ // the current page. Regardless of where the current page is, this also links
+ // to the first and last page. If the current page +/- `plusOrMinus` is greater
+ // than the 2nd page from the first page or less than the 2nd page from the last
+ // page, `...` will be displayed.
if (allApps.size > 0) {
+ val leftSideIndices =
+ rangeIndices(actualPage - plusOrMinus until actualPage, 1 < _)
+ val rightSideIndices =
+ rangeIndices(actualPage + 1 to actualPage + plusOrMinus, _ < pageCount)
+
Did you specify the correct logging directory?
+ Please verify your setting of
+ spark.history.fs.logDirectory and whether you have the permissions to
+ access it. It is also possible that your application did not run to
+ completion or did not stop the SparkContext.
+
}
}
+
+ {
+ if (requestedIncomplete) {
+ "Back to completed applications"
+ } else {
+ "Show incomplete applications"
+ }
+ }
+
UIUtils.basicSparkPage(content, "History Server")
@@ -75,20 +122,32 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
"Spark User",
"Last Updated")
+ private def rangeIndices(range: Seq[Int], condition: Int => Boolean): Seq[Node] = {
+ range.filter(condition).map(nextPage => {nextPage} )
+ }
+
private def appRow(info: ApplicationHistoryInfo): Seq[Node] = {
val uiAddress = HistoryServer.UI_PATH_PREFIX + s"/${info.id}"
val startTime = UIUtils.formatDate(info.startTime)
- val endTime = UIUtils.formatDate(info.endTime)
- val duration = UIUtils.formatDuration(info.endTime - info.startTime)
+ val endTime = if (info.endTime > 0) UIUtils.formatDate(info.endTime) else "-"
+ val duration =
+ if (info.endTime > 0) UIUtils.formatDuration(info.endTime - info.startTime) else "-"
val lastUpdated = UIUtils.formatDate(info.lastUpdated)
}
+
+ private def makePageLink(linkPage: Int, showIncomplete: Boolean): String = {
+ "/?" + Array(
+ "page=" + linkPage,
+ "showIncomplete=" + showIncomplete
+ ).mkString("&")
+ }
}
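
The paging comments above boil down to: show the current page plus or minus `plusOrMinus` neighbours, always keep links to the first and last page, and print "..." where pages are skipped. A self-contained sketch of that window arithmetic, together with the makePageLink URL format used above:

```scala
// Self-contained sketch of the paging window: neighbours within `plusOrMinus`
// of the current page, clipped so the first and last page keep their own links.
object PagingWindowSketch {
  val plusOrMinus = 2

  def window(currentPage: Int, pageCount: Int): (Seq[Int], Seq[Int]) = {
    val leftSideIndices = (currentPage - plusOrMinus until currentPage).filter(1 < _)
    val rightSideIndices = (currentPage + 1 to currentPage + plusOrMinus).filter(_ < pageCount)
    (leftSideIndices, rightSideIndices)
  }

  def makePageLink(linkPage: Int, showIncomplete: Boolean): String =
    "/?" + Array("page=" + linkPage, "showIncomplete=" + showIncomplete).mkString("&")

  def main(args: Array[String]): Unit = {
    println(window(5, 10))                           // pages 3,4 on the left and 6,7 on the right
    println(makePageLink(6, showIncomplete = false)) // /?page=6&showIncomplete=false
  }
}
```
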
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
index ce00c0ffd21e0..fa9bfe5426b6c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
@@ -158,11 +158,12 @@ class HistoryServer(
/**
* The recommended way of starting and stopping a HistoryServer is through the scripts
- * start-history-server.sh and stop-history-server.sh. The path to a base log directory
- * is must be specified, while the requested UI port is optional. For example:
+ * start-history-server.sh and stop-history-server.sh. The path to a base log directory,
+ * as well as any other relevant history server configuration, should be specified via
+ * the $SPARK_HISTORY_OPTS environment variable. For example:
*
- * ./sbin/spark-history-server.sh /tmp/spark-events
- * ./sbin/spark-history-server.sh hdfs://1.2.3.4:9000/spark-events
+ * export SPARK_HISTORY_OPTS="-Dspark.history.fs.logDirectory=/tmp/spark-events"
+ * ./sbin/start-history-server.sh
*
* This launches the HistoryServer as a Spark daemon.
*/
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
index 5bce32a04d16d..b1270ade9f750 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
@@ -17,14 +17,13 @@
package org.apache.spark.deploy.history
-import org.apache.spark.SparkConf
+import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.util.Utils
/**
* Command-line parser for the history server.
*/
-private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]) {
- private var logDir: String = null
+private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]) extends Logging {
private var propertiesFile: String = null
parse(args.toList)
@@ -32,7 +31,8 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]
private def parse(args: List[String]): Unit = {
args match {
case ("--dir" | "-d") :: value :: tail =>
- logDir = value
+ logWarning("Setting log directory through the command line is deprecated as of " +
+ "Spark 1.1.0. Please set this through spark.history.fs.logDirectory instead.")
conf.set("spark.history.fs.logDirectory", value)
System.setProperty("spark.history.fs.logDirectory", value)
parse(tail)
@@ -78,9 +78,10 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]
| (default 50)
|FsHistoryProvider options:
|
- | spark.history.fs.logDirectory Directory where app logs are stored (required)
- | spark.history.fs.updateInterval How often to reload log data from storage (in seconds,
- | default 10)
+ | spark.history.fs.logDirectory Directory where app logs are stored
+ | (default: file:/tmp/spark-events)
+ | spark.history.fs.updateInterval How often to reload log data from storage
+ | (in seconds, default: 10)
|""".stripMargin)
System.exit(exitCode)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
index c3ca43f8d0734..ede0a9dbefb8d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
@@ -24,7 +24,9 @@ import scala.collection.mutable.ArrayBuffer
import akka.actor.ActorRef
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.deploy.ApplicationDescription
+import org.apache.spark.util.Utils
private[spark] class ApplicationInfo(
val startTime: Long,
@@ -36,8 +38,8 @@ private[spark] class ApplicationInfo(
extends Serializable {
@transient var state: ApplicationState.Value = _
- @transient var executors: mutable.HashMap[Int, ExecutorInfo] = _
- @transient var removedExecutors: ArrayBuffer[ExecutorInfo] = _
+ @transient var executors: mutable.HashMap[Int, ExecutorDesc] = _
+ @transient var removedExecutors: ArrayBuffer[ExecutorDesc] = _
@transient var coresGranted: Int = _
@transient var endTime: Long = _
@transient var appSource: ApplicationSource = _
@@ -46,19 +48,19 @@ private[spark] class ApplicationInfo(
init()
- private def readObject(in: java.io.ObjectInputStream): Unit = {
+ private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
in.defaultReadObject()
init()
}
private def init() {
state = ApplicationState.WAITING
- executors = new mutable.HashMap[Int, ExecutorInfo]
+ executors = new mutable.HashMap[Int, ExecutorDesc]
coresGranted = 0
endTime = -1L
appSource = new ApplicationSource(this)
nextExecutorId = 0
- removedExecutors = new ArrayBuffer[ExecutorInfo]
+ removedExecutors = new ArrayBuffer[ExecutorDesc]
}
private def newExecutorId(useID: Option[Int] = None): Int = {
@@ -73,14 +75,14 @@ private[spark] class ApplicationInfo(
}
}
- def addExecutor(worker: WorkerInfo, cores: Int, useID: Option[Int] = None): ExecutorInfo = {
- val exec = new ExecutorInfo(newExecutorId(useID), this, worker, cores, desc.memoryPerSlave)
+ def addExecutor(worker: WorkerInfo, cores: Int, useID: Option[Int] = None): ExecutorDesc = {
+ val exec = new ExecutorDesc(newExecutorId(useID), this, worker, cores, desc.memoryPerSlave)
executors(exec.id) = exec
coresGranted += cores
exec
}
- def removeExecutor(exec: ExecutorInfo) {
+ def removeExecutor(exec: ExecutorDesc) {
if (executors.contains(exec.id)) {
removedExecutors += executors(exec.id)
executors -= exec.id
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
index 80b570a44af18..9d3d7938c6ccb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
@@ -19,7 +19,9 @@ package org.apache.spark.deploy.master
import java.util.Date
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.deploy.DriverDescription
+import org.apache.spark.util.Utils
private[spark] class DriverInfo(
val startTime: Long,
@@ -36,7 +38,7 @@ private[spark] class DriverInfo(
init()
- private def readObject(in: java.io.ObjectInputStream): Unit = {
+ private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
in.defaultReadObject()
init()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala
new file mode 100644
index 0000000000000..5d620dfcabad5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
+
+private[spark] class ExecutorDesc(
+ val id: Int,
+ val application: ApplicationInfo,
+ val worker: WorkerInfo,
+ val cores: Int,
+ val memory: Int) {
+
+ var state = ExecutorState.LAUNCHING
+
+ /** Copy all state (non-val) variables from the given on-the-wire ExecutorDescription. */
+ def copyState(execDesc: ExecutorDescription) {
+ state = execDesc.state
+ }
+
+ def fullId: String = application.id + "/" + id
+
+ override def equals(other: Any): Boolean = {
+ other match {
+ case info: ExecutorDesc =>
+ fullId == info.fullId &&
+ worker.id == info.worker.id &&
+ cores == info.cores &&
+ memory == info.memory
+ case _ => false
+ }
+ }
+
+ override def toString: String = fullId
+
+ override def hashCode: Int = toString.hashCode()
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala
deleted file mode 100644
index d417070c51016..0000000000000
--- a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.deploy.master
-
-import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
-
-private[spark] class ExecutorInfo(
- val id: Int,
- val application: ApplicationInfo,
- val worker: WorkerInfo,
- val cores: Int,
- val memory: Int) {
-
- var state = ExecutorState.LAUNCHING
-
- /** Copy all state (non-val) variables from the given on-the-wire ExecutorDescription. */
- def copyState(execDesc: ExecutorDescription) {
- state = execDesc.state
- }
-
- def fullId: String = application.id + "/" + id
-
- override def equals(other: Any): Boolean = {
- other match {
- case info: ExecutorInfo =>
- fullId == info.fullId &&
- worker.id == info.worker.id &&
- cores == info.cores &&
- memory == info.memory
- case _ => false
- }
- }
-
- override def toString: String = fullId
-
- override def hashCode: Int = toString.hashCode()
-}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
index 08a99bbe68578..36a2e2c6a6349 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
@@ -19,10 +19,13 @@ package org.apache.spark.deploy.master
import java.io._
+import scala.reflect.ClassTag
+
import akka.serialization.Serialization
import org.apache.spark.Logging
+
/**
* Stores data in a single on-disk directory with one file per application and worker.
* Files are deleted when applications and workers are removed.
@@ -37,51 +40,24 @@ private[spark] class FileSystemPersistenceEngine(
new File(dir).mkdir()
- override def addApplication(app: ApplicationInfo) {
- val appFile = new File(dir + File.separator + "app_" + app.id)
- serializeIntoFile(appFile, app)
- }
-
- override def removeApplication(app: ApplicationInfo) {
- new File(dir + File.separator + "app_" + app.id).delete()
- }
-
- override def addDriver(driver: DriverInfo) {
- val driverFile = new File(dir + File.separator + "driver_" + driver.id)
- serializeIntoFile(driverFile, driver)
- }
-
- override def removeDriver(driver: DriverInfo) {
- new File(dir + File.separator + "driver_" + driver.id).delete()
- }
-
- override def addWorker(worker: WorkerInfo) {
- val workerFile = new File(dir + File.separator + "worker_" + worker.id)
- serializeIntoFile(workerFile, worker)
+ override def persist(name: String, obj: Object): Unit = {
+ serializeIntoFile(new File(dir + File.separator + name), obj)
}
- override def removeWorker(worker: WorkerInfo) {
- new File(dir + File.separator + "worker_" + worker.id).delete()
+ override def unpersist(name: String): Unit = {
+ new File(dir + File.separator + name).delete()
}
- override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
- val sortedFiles = new File(dir).listFiles().sortBy(_.getName)
- val appFiles = sortedFiles.filter(_.getName.startsWith("app_"))
- val apps = appFiles.map(deserializeFromFile[ApplicationInfo])
- val driverFiles = sortedFiles.filter(_.getName.startsWith("driver_"))
- val drivers = driverFiles.map(deserializeFromFile[DriverInfo])
- val workerFiles = sortedFiles.filter(_.getName.startsWith("worker_"))
- val workers = workerFiles.map(deserializeFromFile[WorkerInfo])
- (apps, drivers, workers)
+ override def read[T: ClassTag](prefix: String) = {
+ val files = new File(dir).listFiles().filter(_.getName.startsWith(prefix))
+ files.map(deserializeFromFile[T])
}
private def serializeIntoFile(file: File, value: AnyRef) {
val created = file.createNewFile()
if (!created) { throw new IllegalStateException("Could not create file: " + file) }
-
val serializer = serialization.findSerializerFor(value)
val serialized = serializer.toBinary(value)
-
val out = new FileOutputStream(file)
try {
out.write(serialized)
@@ -90,7 +66,7 @@ private[spark] class FileSystemPersistenceEngine(
}
}
- def deserializeFromFile[T](file: File)(implicit m: Manifest[T]): T = {
+ private def deserializeFromFile[T](file: File)(implicit m: ClassTag[T]): T = {
val fileData = new Array[Byte](file.length().asInstanceOf[Int])
val dis = new DataInputStream(new FileInputStream(file))
try {
@@ -98,9 +74,9 @@ private[spark] class FileSystemPersistenceEngine(
} finally {
dis.close()
}
-
val clazz = m.runtimeClass.asInstanceOf[Class[T]]
val serializer = serialization.serializerFor(clazz)
serializer.fromBinary(fileData).asInstanceOf[T]
}
+
}
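
The rewrite above replaces the typed addApplication/addWorker/addDriver methods with a generic, name-based persist/unpersist/read API keyed by a prefix such as "app_", "worker_" or "driver_". A hedged, in-memory sketch of how that generic shape composes; the InMemoryPersistenceEngine below is illustrative and is not one of the engines in this patch.

```scala
import scala.collection.mutable
import scala.reflect.ClassTag

// Illustrative only: an in-memory engine with the same name-based API shape.
class InMemoryPersistenceEngine {
  private val store = mutable.LinkedHashMap[String, Object]()

  def persist(name: String, obj: Object): Unit = store(name) = obj

  def unpersist(name: String): Unit = store -= name

  def read[T: ClassTag](prefix: String): Seq[T] =
    store.collect { case (name, value: T) if name.startsWith(prefix) => value }.toSeq
}

object InMemoryPersistenceEngineExample {
  def main(args: Array[String]): Unit = {
    val engine = new InMemoryPersistenceEngine
    engine.persist("app_1", "application one") // plain strings stand in for ApplicationInfo etc.
    engine.persist("worker_1", "worker one")
    println(engine.read[String]("app_"))       // only the entry whose key starts with "app_"
    engine.unpersist("app_1")
    println(engine.read[String]("app_"))       // empty
  }
}
```
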
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
index 4433a2ec29be6..cf77c86d760cf 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala
@@ -17,30 +17,27 @@
package org.apache.spark.deploy.master
-import akka.actor.{Actor, ActorRef}
-
-import org.apache.spark.deploy.master.MasterMessages.ElectedLeader
+import org.apache.spark.annotation.DeveloperApi
/**
- * A LeaderElectionAgent keeps track of whether the current Master is the leader, meaning it
- * is the only Master serving requests.
- * In addition to the API provided, the LeaderElectionAgent will use of the following messages
- * to inform the Master of leader changes:
- * [[org.apache.spark.deploy.master.MasterMessages.ElectedLeader ElectedLeader]]
- * [[org.apache.spark.deploy.master.MasterMessages.RevokedLeadership RevokedLeadership]]
+ * :: DeveloperApi ::
+ *
+ * A LeaderElectionAgent tracks the current master and is a common interface for all election agents.
*/
-private[spark] trait LeaderElectionAgent extends Actor {
- // TODO: LeaderElectionAgent does not necessary to be an Actor anymore, need refactoring.
- val masterActor: ActorRef
+@DeveloperApi
+trait LeaderElectionAgent {
+ val masterActor: LeaderElectable
+ def stop() {} // to avoid noops in implementations.
}
-/** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */
-private[spark] class MonarchyLeaderAgent(val masterActor: ActorRef) extends LeaderElectionAgent {
- override def preStart() {
- masterActor ! ElectedLeader
- }
+@DeveloperApi
+trait LeaderElectable {
+ def electedLeader()
+ def revokedLeadership()
+}
- override def receive = {
- case _ =>
- }
+/** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */
+private[spark] class MonarchyLeaderAgent(val masterActor: LeaderElectable)
+ extends LeaderElectionAgent {
+ masterActor.electedLeader()
}
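
With LeaderElectionAgent and LeaderElectable now public @DeveloperApi traits, leader election can be supplied from outside the Master. A minimal sketch of a custom agent, assuming the traits from this patch are on the classpath; AlwaysLeaderAgent is a hypothetical name and simply mirrors MonarchyLeaderAgent.

```scala
import org.apache.spark.deploy.master.{LeaderElectable, LeaderElectionAgent}

// Hypothetical custom agent: with no competing masters, the LeaderElectable we
// are handed is immediately elected. A real agent would watch an external
// coordination service and call revokedLeadership() when leadership is lost.
class AlwaysLeaderAgent(val masterActor: LeaderElectable) extends LeaderElectionAgent {
  masterActor.electedLeader()

  override def stop(): Unit = {
    // nothing to clean up for this trivial agent
  }
}
```
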
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 3b6bb9fe128a4..5eeb9fe526248 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -17,6 +17,7 @@
package org.apache.spark.deploy.master
+import java.io.FileNotFoundException
import java.net.URLEncoder
import java.text.SimpleDateFormat
import java.util.Date
@@ -30,7 +31,9 @@ import scala.util.Random
import akka.actor._
import akka.pattern.ask
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
+import akka.serialization.Serialization
import akka.serialization.SerializationExtension
+import org.apache.hadoop.fs.Path
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, DriverDescription,
@@ -50,18 +53,18 @@ private[spark] class Master(
port: Int,
webUiPort: Int,
val securityMgr: SecurityManager)
- extends Actor with ActorLogReceive with Logging {
+ extends Actor with ActorLogReceive with Logging with LeaderElectable {
import context.dispatcher // to use Akka's scheduler.schedule()
val conf = new SparkConf
+ val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
val WORKER_TIMEOUT = conf.getLong("spark.worker.timeout", 60) * 1000
val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200)
val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200)
val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15)
- val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "")
val RECOVERY_MODE = conf.get("spark.deploy.recoveryMode", "NONE")
val workers = new HashSet[WorkerInfo]
@@ -103,7 +106,7 @@ private[spark] class Master(
var persistenceEngine: PersistenceEngine = _
- var leaderElectionAgent: ActorRef = _
+ var leaderElectionAgent: LeaderElectionAgent = _
private var recoveryCompletionTask: Cancellable = _
@@ -120,6 +123,7 @@ private[spark] class Master(
override def preStart() {
logInfo("Starting Spark master at " + masterUrl)
+ logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}")
// Listen for remote client disconnection events, since they don't go through Akka's watch()
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
webUi.bind()
@@ -129,24 +133,32 @@ private[spark] class Master(
masterMetricsSystem.registerSource(masterSource)
masterMetricsSystem.start()
applicationMetricsSystem.start()
+ // Attach the master and app metrics servlet handler to the web ui after the metrics systems are
+ // started.
+ masterMetricsSystem.getServletHandlers.foreach(webUi.attachHandler)
+ applicationMetricsSystem.getServletHandlers.foreach(webUi.attachHandler)
- persistenceEngine = RECOVERY_MODE match {
+ val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match {
case "ZOOKEEPER" =>
logInfo("Persisting recovery state to ZooKeeper")
- new ZooKeeperPersistenceEngine(SerializationExtension(context.system), conf)
+ val zkFactory =
+ new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(context.system))
+ (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this))
case "FILESYSTEM" =>
- logInfo("Persisting recovery state to directory: " + RECOVERY_DIR)
- new FileSystemPersistenceEngine(RECOVERY_DIR, SerializationExtension(context.system))
+ val fsFactory =
+ new FileSystemRecoveryModeFactory(conf, SerializationExtension(context.system))
+ (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this))
+ case "CUSTOM" =>
+ val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory"))
+ val factory = clazz.getConstructor(conf.getClass, Serialization.getClass)
+ .newInstance(conf, SerializationExtension(context.system))
+ .asInstanceOf[StandaloneRecoveryModeFactory]
+ (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this))
case _ =>
- new BlackHolePersistenceEngine()
+ (new BlackHolePersistenceEngine(), new MonarchyLeaderAgent(this))
}
-
- leaderElectionAgent = RECOVERY_MODE match {
- case "ZOOKEEPER" =>
- context.actorOf(Props(classOf[ZooKeeperLeaderElectionAgent], self, masterUrl, conf))
- case _ =>
- context.actorOf(Props(classOf[MonarchyLeaderAgent], self))
- }
+ persistenceEngine = persistenceEngine_
+ leaderElectionAgent = leaderElectionAgent_
}
override def preRestart(reason: Throwable, message: Option[Any]) {
@@ -165,7 +177,15 @@ private[spark] class Master(
masterMetricsSystem.stop()
applicationMetricsSystem.stop()
persistenceEngine.close()
- context.stop(leaderElectionAgent)
+ leaderElectionAgent.stop()
+ }
+
+ override def electedLeader() {
+ self ! ElectedLeader
+ }
+
+ override def revokedLeadership() {
+ self ! RevokedLeadership
}
override def receiveWithLogging = {
@@ -498,7 +518,7 @@ private[spark] class Master(
val shuffledAliveWorkers = Random.shuffle(workers.toSeq.filter(_.state == WorkerState.ALIVE))
val numWorkersAlive = shuffledAliveWorkers.size
var curPos = 0
-
+
for (driver <- waitingDrivers.toList) { // iterate over a copy of waitingDrivers
// We assign workers to each waiting driver in a round-robin fashion. For each driver, we
// start from the last worker that was assigned a driver, and continue onwards until we have
@@ -561,7 +581,7 @@ private[spark] class Master(
}
}
- def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo) {
+ def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc) {
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
worker.addExecutor(exec)
worker.actor ! LaunchExecutor(masterUrl,
@@ -685,6 +705,11 @@ private[spark] class Master(
}
persistenceEngine.removeApplication(app)
schedule()
+
+ // Tell all workers that the application has finished, so they can clean up any app state.
+ workers.foreach { w =>
+ w.actor ! ApplicationFinished(app.id)
+ }
}
}
@@ -695,41 +720,51 @@ private[spark] class Master(
def rebuildSparkUI(app: ApplicationInfo): Boolean = {
val appName = app.desc.name
val notFoundBasePath = HistoryServer.UI_PATH_PREFIX + "/not-found"
- val eventLogDir = app.desc.eventLogDir.getOrElse {
- // Event logging is not enabled for this application
- app.desc.appUiUrl = notFoundBasePath
- return false
- }
-
- val appEventLogDir = EventLoggingListener.getLogDirPath(eventLogDir, app.id)
- val fileSystem = Utils.getHadoopFileSystem(appEventLogDir,
- SparkHadoopUtil.get.newConfiguration(conf))
- val eventLogInfo = EventLoggingListener.parseLoggingInfo(appEventLogDir, fileSystem)
- val eventLogPaths = eventLogInfo.logPaths
- val compressionCodec = eventLogInfo.compressionCodec
-
- if (eventLogPaths.isEmpty) {
- // Event logging is enabled for this application, but no event logs are found
- val title = s"Application history not found (${app.id})"
- var msg = s"No event logs found for application $appName in $appEventLogDir."
- logWarning(msg)
- msg += " Did you specify the correct logging directory?"
- msg = URLEncoder.encode(msg, "UTF-8")
- app.desc.appUiUrl = notFoundBasePath + s"?msg=$msg&title=$title"
- return false
- }
-
try {
- val replayBus = new ReplayListenerBus(eventLogPaths, fileSystem, compressionCodec)
- val ui = new SparkUI(new SparkConf, replayBus, appName + " (completed)",
- HistoryServer.UI_PATH_PREFIX + s"/${app.id}")
- replayBus.replay()
+ val eventLogFile = app.desc.eventLogDir
+ .map { dir => EventLoggingListener.getLogPath(dir, app.id) }
+ .getOrElse {
+ // Event logging is not enabled for this application
+ app.desc.appUiUrl = notFoundBasePath
+ return false
+ }
+
+ val fs = Utils.getHadoopFileSystem(eventLogFile, hadoopConf)
+
+ if (fs.exists(new Path(eventLogFile + EventLoggingListener.IN_PROGRESS))) {
+ // Event logging is enabled for this application, but the application is still in progress
+ val title = s"Application history not found (${app.id})"
+ var msg = s"Application $appName is still in progress."
+ logWarning(msg)
+ msg = URLEncoder.encode(msg, "UTF-8")
+ app.desc.appUiUrl = notFoundBasePath + s"?msg=$msg&title=$title"
+ return false
+ }
+
+ val (logInput, sparkVersion) = EventLoggingListener.openEventLog(new Path(eventLogFile), fs)
+ val replayBus = new ReplayListenerBus()
+ val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf),
+ appName + " (completed)", HistoryServer.UI_PATH_PREFIX + s"/${app.id}")
+ try {
+ replayBus.replay(logInput, sparkVersion)
+ } finally {
+ logInput.close()
+ }
appIdToUI(app.id) = ui
webUi.attachSparkUI(ui)
// Application UI is successfully rebuilt, so link the Master UI to it
- app.desc.appUiUrl = ui.getBasePath
+ app.desc.appUiUrl = ui.basePath
true
} catch {
+ case fnf: FileNotFoundException =>
+ // Event logging is enabled for this application, but no event logs are found
+ val title = s"Application history not found (${app.id})"
+ var msg = s"No event logs found for application $appName in ${app.desc.eventLogDir}."
+ logWarning(msg)
+ msg += " Did you specify the correct logging directory?"
+ msg = URLEncoder.encode(msg, "UTF-8")
+ app.desc.appUiUrl = notFoundBasePath + s"?msg=$msg&title=$title"
+ false
case e: Exception =>
// Relay exception message to application UI page
val title = s"Application history load error (${app.id})"
@@ -811,7 +846,6 @@ private[spark] class Master(
private[spark] object Master extends Logging {
val systemName = "sparkMaster"
private val actorName = "Master"
- val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r
def main(argStrings: Array[String]) {
SignalLogger.register(log)
@@ -821,14 +855,24 @@ private[spark] object Master extends Logging {
actorSystem.awaitTermination()
}
- /** Returns an `akka.tcp://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */
- def toAkkaUrl(sparkUrl: String): String = {
- sparkUrl match {
- case sparkUrlRegex(host, port) =>
- "akka.tcp://%s@%s:%s/user/%s".format(systemName, host, port, actorName)
- case _ =>
- throw new SparkException("Invalid master URL: " + sparkUrl)
- }
+ /**
+ * Returns an `akka.tcp://...` URL for the Master actor given a sparkUrl `spark://host:port`.
+ *
+ * @throws SparkException if the url is invalid
+ */
+ def toAkkaUrl(sparkUrl: String, protocol: String): String = {
+ val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl)
+ AkkaUtils.address(protocol, systemName, host, port, actorName)
+ }
+
+ /**
+ * Returns an akka `Address` for the Master actor given a sparkUrl `spark://host:port`.
+ *
+ * @throws SparkException if the url is invalid
+ */
+ def toAkkaAddress(sparkUrl: String, protocol: String): Address = {
+ val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl)
+ Address(protocol, systemName, host, port)
}
def startSystemAndActor(
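
The refactored `toAkkaUrl`/`toAkkaAddress` helpers above take an explicit protocol argument and delegate host/port parsing to `Utils.extractHostPortFromSparkUrl` instead of the removed hard-coded `akka.tcp` regex. A minimal sketch of the resulting URL shape, assuming a simple regex-based extraction (`SparkUrlSketch` and its regex stand in for the real `Utils` helper, which is not shown in this hunk; the system and actor names come from the `Master` object above):

```scala
// Sketch only: the regex mirrors the sparkUrlRegex removed above.
object SparkUrlSketch {
  private val SparkUrl = "spark://([^:]+):([0-9]+)".r

  def extractHostPort(sparkUrl: String): (String, Int) = sparkUrl match {
    case SparkUrl(host, port) => (host, port.toInt)
    case _ => throw new IllegalArgumentException("Invalid master URL: " + sparkUrl)
  }

  // Mirrors Master.toAkkaUrl: the protocol (e.g. "akka.tcp") is now passed in by the caller.
  def toAkkaUrl(sparkUrl: String, protocol: String): String = {
    val (host, port) = extractHostPort(sparkUrl)
    s"$protocol://sparkMaster@$host:$port/user/Master"
  }

  def main(args: Array[String]): Unit = {
    // Prints: akka.tcp://sparkMaster@localhost:7077/user/Master
    println(toAkkaUrl("spark://localhost:7077", "akka.tcp"))
  }
}
```
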
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
index e3640ea4f7e64..2e0e1e7036ac8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala
@@ -17,6 +17,10 @@
package org.apache.spark.deploy.master
+import org.apache.spark.annotation.DeveloperApi
+
+import scala.reflect.ClassTag
+
/**
* Allows Master to persist any state that is necessary in order to recover from a failure.
* The following semantics are required:
@@ -25,36 +29,70 @@ package org.apache.spark.deploy.master
* Given these two requirements, we will have all apps and workers persisted, but
* we might not have yet deleted apps or workers that finished (so their liveness must be verified
* during recovery).
+ *
+ * The implementation of this trait defines how name-object pairs are stored or retrieved.
*/
-private[spark] trait PersistenceEngine {
- def addApplication(app: ApplicationInfo)
+@DeveloperApi
+trait PersistenceEngine {
- def removeApplication(app: ApplicationInfo)
+ /**
+ * Defines how the object is serialized and persisted. Implementation will
+ * depend on the store used.
+ */
+ def persist(name: String, obj: Object)
- def addWorker(worker: WorkerInfo)
+ /**
+ * Defines how the object referred by its name is removed from the store.
+ */
+ def unpersist(name: String)
- def removeWorker(worker: WorkerInfo)
+ /**
+ * Returns all objects whose names match the given prefix. This defines how
+ * objects are read/deserialized back.
+ */
+ def read[T: ClassTag](prefix: String): Seq[T]
- def addDriver(driver: DriverInfo)
+ final def addApplication(app: ApplicationInfo): Unit = {
+ persist("app_" + app.id, app)
+ }
- def removeDriver(driver: DriverInfo)
+ final def removeApplication(app: ApplicationInfo): Unit = {
+ unpersist("app_" + app.id)
+ }
+
+ final def addWorker(worker: WorkerInfo): Unit = {
+ persist("worker_" + worker.id, worker)
+ }
+
+ final def removeWorker(worker: WorkerInfo): Unit = {
+ unpersist("worker_" + worker.id)
+ }
+
+ final def addDriver(driver: DriverInfo): Unit = {
+ persist("driver_" + driver.id, driver)
+ }
+
+ final def removeDriver(driver: DriverInfo): Unit = {
+ unpersist("driver_" + driver.id)
+ }
/**
* Returns the persisted data sorted by their respective ids (which implies that they're
* sorted by time of creation).
*/
- def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo])
+ final def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
+ (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_"))
+ }
def close() {}
}
private[spark] class BlackHolePersistenceEngine extends PersistenceEngine {
- override def addApplication(app: ApplicationInfo) {}
- override def removeApplication(app: ApplicationInfo) {}
- override def addWorker(worker: WorkerInfo) {}
- override def removeWorker(worker: WorkerInfo) {}
- override def addDriver(driver: DriverInfo) {}
- override def removeDriver(driver: DriverInfo) {}
-
- override def readPersistedData() = (Nil, Nil, Nil)
+
+ override def persist(name: String, obj: Object): Unit = {}
+
+ override def unpersist(name: String): Unit = {}
+
+ override def read[T: ClassTag](name: String): Seq[T] = Nil
+
}
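
With the trait reduced to `persist`, `unpersist` and `read`, the `add*`/`remove*`/`readPersistedData` helpers are now final and derived from those three methods, so a custom store only has to implement the low-level operations. A purely illustrative in-memory sketch of that contract (the class name is hypothetical and it deliberately does not extend Spark's trait, so it stays self-contained):

```scala
import scala.collection.mutable
import scala.reflect.ClassTag

// Illustration of the persist/unpersist/read contract using a plain in-memory map.
class InMemoryStore {
  private val store = mutable.LinkedHashMap[String, AnyRef]()

  def persist(name: String, obj: AnyRef): Unit = store(name) = obj

  def unpersist(name: String): Unit = store -= name

  // Return every stored object whose name starts with `prefix` and whose type matches T.
  def read[T: ClassTag](prefix: String): Seq[T] =
    store.collect { case (name, obj: T) if name.startsWith(prefix) => obj }.toSeq
}

object InMemoryStoreDemo extends App {
  val engine = new InMemoryStore
  engine.persist("app_app-20141201", "application state")
  engine.persist("worker_worker-1", "worker state")
  println(engine.read[String]("app_"))   // List(application state)
}
```
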
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
new file mode 100644
index 0000000000000..1096eb0368357
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.master
+
+import akka.serialization.Serialization
+
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * ::DeveloperApi::
+ *
+ * Implementations of this class can be plugged in as an alternative recovery mode for
+ * Spark's standalone mode.
+ *
+ */
+@DeveloperApi
+abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serialization) {
+
+ /**
+ * PersistenceEngine defines how persistent data (information about workers, drivers, etc.)
+ * is handled for recovery.
+ *
+ */
+ def createPersistenceEngine(): PersistenceEngine
+
+ /**
+ * Create an instance of LeaderAgent that decides who gets elected as master.
+ */
+ def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent
+}
+
+/**
+ * The LeaderAgent in this case is a no-op: the leader stays leader forever, since the
+ * actual recovery is done by restoring state from the filesystem.
+ */
+private[spark] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serialization)
+ extends StandaloneRecoveryModeFactory(conf, serializer) with Logging {
+ val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "")
+
+ def createPersistenceEngine() = {
+ logInfo("Persisting recovery state to directory: " + RECOVERY_DIR)
+ new FileSystemPersistenceEngine(RECOVERY_DIR, serializer)
+ }
+
+ def createLeaderElectionAgent(master: LeaderElectable) = new MonarchyLeaderAgent(master)
+}
+
+private[spark] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serialization)
+ extends StandaloneRecoveryModeFactory(conf, serializer) {
+ def createPersistenceEngine() = new ZooKeeperPersistenceEngine(conf, serializer)
+
+ def createLeaderElectionAgent(master: LeaderElectable) =
+ new ZooKeeperLeaderElectionAgent(master, conf)
+}
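
The `CUSTOM` branch added to the Master above instantiates whatever class `spark.deploy.recoveryMode.factory` names, so third-party recovery backends can be plugged in without modifying Spark. A hedged sketch of such a plug-in, assuming the classes from this patch are on the classpath; `MyRecoveryModeFactory`, `MyPersistenceEngine` and `ImmediateLeaderAgent` are hypothetical names, and the members of `LeaderElectionAgent` are inferred from how they are used in this diff rather than from its definition:

```scala
package com.example.recovery

import scala.reflect.ClassTag

import akka.serialization.Serialization

import org.apache.spark.SparkConf
import org.apache.spark.deploy.master.{LeaderElectable, LeaderElectionAgent,
  PersistenceEngine, StandaloneRecoveryModeFactory}

// Hypothetical store: keeps master state in a concurrent in-process map.
class MyPersistenceEngine extends PersistenceEngine {
  private val store = scala.collection.concurrent.TrieMap[String, AnyRef]()
  override def persist(name: String, obj: Object): Unit = store(name) = obj
  override def unpersist(name: String): Unit = store -= name
  override def read[T: ClassTag](prefix: String): Seq[T] =
    store.collect { case (k, v: T) if k.startsWith(prefix) => v }.toSeq
}

// Hypothetical agent: with a single master there is nothing to elect, so promote it
// immediately (mirroring what MonarchyLeaderAgent does in this patch). This assumes
// LeaderElectionAgent only needs a masterActor member and a no-op stop().
class ImmediateLeaderAgent(val masterActor: LeaderElectable) extends LeaderElectionAgent {
  masterActor.electedLeader()
}

// The (SparkConf, Serialization) constructor matches what the CUSTOM branch looks up
// reflectively via spark.deploy.recoveryMode.factory.
class MyRecoveryModeFactory(conf: SparkConf, serializer: Serialization)
  extends StandaloneRecoveryModeFactory(conf, serializer) {
  override def createPersistenceEngine(): PersistenceEngine = new MyPersistenceEngine
  override def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent =
    new ImmediateLeaderAgent(master)
}
```

It would then be enabled by setting the recovery mode (RECOVERY_MODE in the Master hunk above) to CUSTOM and `spark.deploy.recoveryMode.factory` to the factory's fully qualified class name.
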
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
index c5fa9cf7d7c2d..e94aae93e4495 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable
import akka.actor.ActorRef
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils
private[spark] class WorkerInfo(
@@ -37,7 +38,7 @@ private[spark] class WorkerInfo(
Utils.checkHost(host, "Expected hostname")
assert (port > 0)
- @transient var executors: mutable.HashMap[String, ExecutorInfo] = _ // executorId => info
+ @transient var executors: mutable.HashMap[String, ExecutorDesc] = _ // executorId => info
@transient var drivers: mutable.HashMap[String, DriverInfo] = _ // driverId => info
@transient var state: WorkerState.Value = _
@transient var coresUsed: Int = _
@@ -50,7 +51,7 @@ private[spark] class WorkerInfo(
def coresFree: Int = cores - coresUsed
def memoryFree: Int = memory - memoryUsed
- private def readObject(in: java.io.ObjectInputStream) : Unit = {
+ private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
in.defaultReadObject()
init()
}
@@ -69,13 +70,13 @@ private[spark] class WorkerInfo(
host + ":" + port
}
- def addExecutor(exec: ExecutorInfo) {
+ def addExecutor(exec: ExecutorDesc) {
executors(exec.fullId) = exec
coresUsed += exec.cores
memoryUsed += exec.memory
}
- def removeExecutor(exec: ExecutorInfo) {
+ def removeExecutor(exec: ExecutorDesc) {
if (executors.contains(exec.fullId)) {
executors -= exec.fullId
coresUsed -= exec.cores
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
index 285f9b014e291..8eaa0ad948519 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala
@@ -24,9 +24,8 @@ import org.apache.spark.deploy.master.MasterMessages._
import org.apache.curator.framework.CuratorFramework
import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch}
-private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef,
- masterUrl: String, conf: SparkConf)
- extends LeaderElectionAgent with LeaderLatchListener with Logging {
+private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElectable,
+ conf: SparkConf) extends LeaderLatchListener with LeaderElectionAgent with Logging {
val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/leader_election"
@@ -34,30 +33,21 @@ private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef,
private var leaderLatch: LeaderLatch = _
private var status = LeadershipStatus.NOT_LEADER
- override def preStart() {
+ start()
+ def start() {
logInfo("Starting ZooKeeper LeaderElection agent")
zk = SparkCuratorUtil.newClient(conf)
leaderLatch = new LeaderLatch(zk, WORKING_DIR)
leaderLatch.addListener(this)
-
leaderLatch.start()
}
- override def preRestart(reason: scala.Throwable, message: scala.Option[scala.Any]) {
- logError("LeaderElectionAgent failed...", reason)
- super.preRestart(reason, message)
- }
-
- override def postStop() {
+ override def stop() {
leaderLatch.close()
zk.close()
}
- override def receive = {
- case _ =>
- }
-
override def isLeader() {
synchronized {
// could have lost leadership by now.
@@ -85,10 +75,10 @@ private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef,
def updateLeadershipStatus(isLeader: Boolean) {
if (isLeader && status == LeadershipStatus.NOT_LEADER) {
status = LeadershipStatus.LEADER
- masterActor ! ElectedLeader
+ masterActor.electedLeader()
} else if (!isLeader && status == LeadershipStatus.LEADER) {
status = LeadershipStatus.NOT_LEADER
- masterActor ! RevokedLeadership
+ masterActor.revokedLeadership()
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
index 834dfedee52ce..e11ac031fb9c6 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -17,15 +17,18 @@
package org.apache.spark.deploy.master
+import akka.serialization.Serialization
+
import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
-import akka.serialization.Serialization
import org.apache.curator.framework.CuratorFramework
import org.apache.zookeeper.CreateMode
import org.apache.spark.{Logging, SparkConf}
-class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf)
+
+private[spark] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization)
extends PersistenceEngine
with Logging
{
@@ -34,52 +37,31 @@ class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf)
SparkCuratorUtil.mkdir(zk, WORKING_DIR)
- override def addApplication(app: ApplicationInfo) {
- serializeIntoFile(WORKING_DIR + "/app_" + app.id, app)
- }
- override def removeApplication(app: ApplicationInfo) {
- zk.delete().forPath(WORKING_DIR + "/app_" + app.id)
+ override def persist(name: String, obj: Object): Unit = {
+ serializeIntoFile(WORKING_DIR + "/" + name, obj)
}
- override def addDriver(driver: DriverInfo) {
- serializeIntoFile(WORKING_DIR + "/driver_" + driver.id, driver)
+ override def unpersist(name: String): Unit = {
+ zk.delete().forPath(WORKING_DIR + "/" + name)
}
- override def removeDriver(driver: DriverInfo) {
- zk.delete().forPath(WORKING_DIR + "/driver_" + driver.id)
- }
-
- override def addWorker(worker: WorkerInfo) {
- serializeIntoFile(WORKING_DIR + "/worker_" + worker.id, worker)
- }
-
- override def removeWorker(worker: WorkerInfo) {
- zk.delete().forPath(WORKING_DIR + "/worker_" + worker.id)
+ override def read[T: ClassTag](prefix: String) = {
+ val file = zk.getChildren.forPath(WORKING_DIR).filter(_.startsWith(prefix))
+ file.map(deserializeFromFile[T]).flatten
}
override def close() {
zk.close()
}
- override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = {
- val sortedFiles = zk.getChildren().forPath(WORKING_DIR).toList.sorted
- val appFiles = sortedFiles.filter(_.startsWith("app_"))
- val apps = appFiles.map(deserializeFromFile[ApplicationInfo]).flatten
- val driverFiles = sortedFiles.filter(_.startsWith("driver_"))
- val drivers = driverFiles.map(deserializeFromFile[DriverInfo]).flatten
- val workerFiles = sortedFiles.filter(_.startsWith("worker_"))
- val workers = workerFiles.map(deserializeFromFile[WorkerInfo]).flatten
- (apps, drivers, workers)
- }
-
private def serializeIntoFile(path: String, value: AnyRef) {
val serializer = serialization.findSerializerFor(value)
val serialized = serializer.toBinary(value)
zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized)
}
- def deserializeFromFile[T](filename: String)(implicit m: Manifest[T]): Option[T] = {
+ def deserializeFromFile[T](filename: String)(implicit m: ClassTag[T]): Option[T] = {
val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename)
val clazz = m.runtimeClass.asInstanceOf[Class[T]]
val serializer = serialization.serializerFor(clazz)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
index 4588c130ef439..3aae2b95d7396 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -27,7 +27,7 @@ import org.json4s.JValue
import org.apache.spark.deploy.{ExecutorState, JsonProtocol}
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
-import org.apache.spark.deploy.master.ExecutorInfo
+import org.apache.spark.deploy.master.ExecutorDesc
import org.apache.spark.ui.{UIUtils, WebUIPage}
import org.apache.spark.util.Utils
@@ -109,7 +109,7 @@ private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app
UIUtils.basicSparkPage(content, "Application: " + app.desc.name)
}
- private def executorRow(executor: ExecutorInfo): Seq[Node] = {
+ private def executorRow(executor: ExecutorDesc): Seq[Node] = {
<td>{executor.id}</td>
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
index d86ec1e03e45c..73400c5affb5d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
@@ -41,8 +41,6 @@ class MasterWebUI(val master: Master, requestedPort: Int)
attachPage(new HistoryNotFoundPage(this))
attachPage(new MasterPage(this))
attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static"))
- master.masterMetricsSystem.getServletHandlers.foreach(attachHandler)
- master.applicationMetricsSystem.getServletHandlers.foreach(attachHandler)
}
/** Attach a reconstructed UI to this Master UI. Only valid after bind(). */
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
index 2e9be2a180c68..28e9662db5da9 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
@@ -20,6 +20,8 @@ package org.apache.spark.deploy.worker
import java.io.{File, FileOutputStream, InputStream, IOException}
import java.lang.System._
+import scala.collection.Map
+
import org.apache.spark.Logging
import org.apache.spark.deploy.Command
import org.apache.spark.util.Utils
@@ -29,7 +31,29 @@ import org.apache.spark.util.Utils
*/
private[spark]
object CommandUtils extends Logging {
- def buildCommandSeq(command: Command, memory: Int, sparkHome: String): Seq[String] = {
+
+ /**
+ * Build a ProcessBuilder based on the given parameters.
+ * The `env` argument is exposed for testing.
+ */
+ def buildProcessBuilder(
+ command: Command,
+ memory: Int,
+ sparkHome: String,
+ substituteArguments: String => String,
+ classPaths: Seq[String] = Seq[String](),
+ env: Map[String, String] = sys.env): ProcessBuilder = {
+ val localCommand = buildLocalCommand(command, substituteArguments, classPaths, env)
+ val commandSeq = buildCommandSeq(localCommand, memory, sparkHome)
+ val builder = new ProcessBuilder(commandSeq: _*)
+ val environment = builder.environment()
+ for ((key, value) <- localCommand.environment) {
+ environment.put(key, value)
+ }
+ builder
+ }
+
+ private def buildCommandSeq(command: Command, memory: Int, sparkHome: String): Seq[String] = {
val runner = sys.env.get("JAVA_HOME").map(_ + "/bin/java").getOrElse("java")
// SPARK-698: do not call the run.cmd script, as process.destroy()
@@ -38,11 +62,41 @@ object CommandUtils extends Logging {
command.arguments
}
+ /**
+ * Build a command based on the given one, taking into account the local environment
+ * in which the command is expected to run, substituting any placeholders and appending
+ * any extra class paths.
+ */
+ private def buildLocalCommand(
+ command: Command,
+ substituteArguments: String => String,
+ classPath: Seq[String] = Seq[String](),
+ env: Map[String, String]): Command = {
+ val libraryPathName = Utils.libraryPathEnvName
+ val libraryPathEntries = command.libraryPathEntries
+ val cmdLibraryPath = command.environment.get(libraryPathName)
+
+ val newEnvironment = if (libraryPathEntries.nonEmpty && libraryPathName.nonEmpty) {
+ val libraryPaths = libraryPathEntries ++ cmdLibraryPath ++ env.get(libraryPathName)
+ command.environment + ((libraryPathName, libraryPaths.mkString(File.pathSeparator)))
+ } else {
+ command.environment
+ }
+
+ Command(
+ command.mainClass,
+ command.arguments.map(substituteArguments),
+ newEnvironment,
+ command.classPathEntries ++ classPath,
+ Seq[String](), // library path already captured in environment variable
+ command.javaOpts)
+ }
+
/**
* Attention: this must always be aligned with the environment variables in the run scripts and
* the way the JAVA_OPTS are assembled there.
*/
- def buildJavaOpts(command: Command, memory: Int, sparkHome: String): Seq[String] = {
+ private def buildJavaOpts(command: Command, memory: Int, sparkHome: String): Seq[String] = {
val memoryOpts = Seq(s"-Xms${memory}M", s"-Xmx${memory}M")
// Exists for backwards compatibility with older Spark versions
@@ -53,14 +107,6 @@ object CommandUtils extends Logging {
logWarning("Set SPARK_LOCAL_DIRS for node-specific storage locations.")
}
- val libraryOpts =
- if (command.libraryPathEntries.size > 0) {
- val joined = command.libraryPathEntries.mkString(File.pathSeparator)
- Seq(s"-Djava.library.path=$joined")
- } else {
- Seq()
- }
-
// Figure out our classpath with the external compute-classpath script
val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh"
val classPath = Utils.executeAndGetOutput(
@@ -71,7 +117,7 @@ object CommandUtils extends Logging {
val javaVersion = System.getProperty("java.version")
val permGenOpt = if (!javaVersion.startsWith("1.8")) Some("-XX:MaxPermSize=128m") else None
Seq("-cp", userClassPath.filterNot(_.isEmpty).mkString(File.pathSeparator)) ++
- permGenOpt ++ libraryOpts ++ workerLocalOpts ++ command.javaOpts ++ memoryOpts
+ permGenOpt ++ workerLocalOpts ++ command.javaOpts ++ memoryOpts
}
/** Spawn a thread that will redirect a given stream to a file */
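
`buildProcessBuilder` now owns the command assembly (placeholder substitution, extra class-path entries, library-path handling and environment merging) that `DriverRunner` and `ExecutorRunner` previously duplicated, as the following hunks show. A hedged usage sketch, assuming access to the `private[spark]` `Command` and `CommandUtils` types; the `Command` field order is taken from `buildLocalCommand` above, and the concrete values and placeholder are illustrative:

```scala
import org.apache.spark.deploy.Command
import org.apache.spark.deploy.worker.CommandUtils

object BuildProcessBuilderSketch {
  def main(args: Array[String]): Unit = {
    // Field order as used in buildLocalCommand above:
    // mainClass, arguments, environment, classPathEntries, libraryPathEntries, javaOpts.
    val cmd = Command(
      "org.apache.spark.examples.SparkPi",
      Seq("{{WORKER_URL}}", "100"),
      Map("SPARK_USER" -> "demo"),
      Seq.empty,
      Seq.empty,
      Seq("-Dspark.ui.enabled=false"))

    // substituteArguments lets the caller expand placeholders such as {{WORKER_URL}}.
    val builder = CommandUtils.buildProcessBuilder(
      cmd,
      memory = 512,
      sparkHome = "/opt/spark",
      substituteArguments = s => s.replace("{{WORKER_URL}}", "spark://worker-host:7078"))

    // The builder already carries the merged environment; pick a working dir and start it.
    builder.directory(new java.io.File("/tmp/driver-work"))
    // val process = builder.start()
    println(builder.command())
  }
}
```
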
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
index 9f9911762505a..28cab36c7b9e2 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConversions._
import scala.collection.Map
import akka.actor.ActorRef
-import com.google.common.base.Charsets
+import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileUtil, Path}
@@ -76,17 +76,9 @@ private[spark] class DriverRunner(
// Make sure user application jar is on the classpath
// TODO: If we add ability to submit multiple jars they should also be added here
- val classPath = driverDesc.command.classPathEntries ++ Seq(s"$localJarFilename")
- val newCommand = Command(
- driverDesc.command.mainClass,
- driverDesc.command.arguments.map(substituteVariables),
- driverDesc.command.environment,
- classPath,
- driverDesc.command.libraryPathEntries,
- driverDesc.command.javaOpts)
- val command = CommandUtils.buildCommandSeq(newCommand, driverDesc.mem,
- sparkHome.getAbsolutePath)
- launchDriver(command, driverDesc.command.environment, driverDir, driverDesc.supervise)
+ val builder = CommandUtils.buildProcessBuilder(driverDesc.command, driverDesc.mem,
+ sparkHome.getAbsolutePath, substituteVariables, Seq(localJarFilename))
+ launchDriver(builder, driverDir, driverDesc.supervise)
}
catch {
case e: Exception => finalException = Some(e)
@@ -165,11 +157,8 @@ private[spark] class DriverRunner(
localJarFilename
}
- private def launchDriver(command: Seq[String], envVars: Map[String, String], baseDir: File,
- supervise: Boolean) {
- val builder = new ProcessBuilder(command: _*).directory(baseDir)
- envVars.map{ case(k,v) => builder.environment().put(k, v) }
-
+ private def launchDriver(builder: ProcessBuilder, baseDir: File, supervise: Boolean) {
+ builder.directory(baseDir)
def initialize(process: Process) = {
// Redirect stdout and stderr to files
val stdout = new File(baseDir, "stdout")
@@ -177,8 +166,8 @@ private[spark] class DriverRunner(
val stderr = new File(baseDir, "stderr")
val header = "Launch Command: %s\n%s\n\n".format(
- command.mkString("\"", "\" \"", "\""), "=" * 40)
- Files.append(header, stderr, Charsets.UTF_8)
+ builder.command.mkString("\"", "\" \"", "\""), "=" * 40)
+ Files.append(header, stderr, UTF_8)
CommandUtils.redirectStream(process.getErrorStream, stderr)
}
runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index 71d7385b08eb9..bc9f78b9e5c77 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -19,12 +19,14 @@ package org.apache.spark.deploy.worker
import java.io._
+import scala.collection.JavaConversions._
+
import akka.actor.ActorRef
-import com.google.common.base.Charsets
+import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files
import org.apache.spark.{SparkConf, Logging}
-import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState}
+import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged
import org.apache.spark.util.logging.FileAppender
@@ -45,6 +47,7 @@ private[spark] class ExecutorRunner(
val executorDir: File,
val workerUrl: String,
val conf: SparkConf,
+ val appLocalDirs: Seq[String],
var state: ExecutorState.Value)
extends Logging {
@@ -75,7 +78,7 @@ private[spark] class ExecutorRunner(
/**
* Kill executor process, wait for exit and notify worker to update resource status.
*
- * @param message the exception message which caused the executor's death
+ * @param message the exception message which caused the executor's death
*/
private def killProcess(message: Option[String]) {
var exitCode: Option[Int] = None
@@ -115,33 +118,22 @@ private[spark] class ExecutorRunner(
case other => other
}
- def getCommandSeq = {
- val command = Command(
- appDesc.command.mainClass,
- appDesc.command.arguments.map(substituteVariables),
- appDesc.command.environment,
- appDesc.command.classPathEntries,
- appDesc.command.libraryPathEntries,
- appDesc.command.javaOpts)
- CommandUtils.buildCommandSeq(command, memory, sparkHome.getAbsolutePath)
- }
-
/**
* Download and run the executor described in our ApplicationDescription
*/
def fetchAndRunExecutor() {
try {
// Launch the process
- val command = getCommandSeq
+ val builder = CommandUtils.buildProcessBuilder(appDesc.command, memory,
+ sparkHome.getAbsolutePath, substituteVariables)
+ val command = builder.command()
logInfo("Launch command: " + command.mkString("\"", "\" \"", "\""))
- val builder = new ProcessBuilder(command: _*).directory(executorDir)
- val env = builder.environment()
- for ((key, value) <- appDesc.command.environment) {
- env.put(key, value)
- }
+
+ builder.directory(executorDir)
+ builder.environment.put("SPARK_LOCAL_DIRS", appLocalDirs.mkString(","))
// In case we are running this from within the Spark Shell, avoid creating a "scala"
// parent process for the executor command
- env.put("SPARK_LAUNCH_WITH_SCALA", "0")
+ builder.environment.put("SPARK_LAUNCH_WITH_SCALA", "0")
process = builder.start()
val header = "Spark Executor Command: %s\n%s\n\n".format(
command.mkString("\"", "\" \"", "\""), "=" * 40)
@@ -151,11 +143,9 @@ private[spark] class ExecutorRunner(
stdoutAppender = FileAppender(process.getInputStream, stdout, conf)
val stderr = new File(executorDir, "stderr")
- Files.write(header, stderr, Charsets.UTF_8)
+ Files.write(header, stderr, UTF_8)
stderrAppender = FileAppender(process.getErrorStream, stderr, conf)
- state = ExecutorState.RUNNING
- worker ! ExecutorStateChanged(appId, execId, state, None, None)
// Wait for it to exit; executor may exit with code 0 (when driver instructs it to shutdown)
// or with nonzero exit code
val exitCode = process.waitFor()
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala
new file mode 100644
index 0000000000000..b9798963bab0a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.worker
+
+import org.apache.spark.{Logging, SparkConf, SecurityManager}
+import org.apache.spark.network.TransportContext
+import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.network.sasl.SaslRpcHandler
+import org.apache.spark.network.server.TransportServer
+import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
+
+/**
+ * Provides a server from which Executors can read shuffle files (rather than reading directly from
+ * each other), so that the files remain available even when executors are turned off or killed.
+ *
+ * Optionally requires SASL authentication in order to read. See [[SecurityManager]].
+ */
+private[worker]
+class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: SecurityManager)
+ extends Logging {
+
+ private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false)
+ private val port = sparkConf.getInt("spark.shuffle.service.port", 7337)
+ private val useSasl: Boolean = securityManager.isAuthenticationEnabled()
+
+ private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0)
+ private val blockHandler = new ExternalShuffleBlockHandler(transportConf)
+ private val transportContext: TransportContext = {
+ val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler
+ new TransportContext(transportConf, handler)
+ }
+
+ private var server: TransportServer = _
+
+ /** Starts the external shuffle service if the user has configured us to. */
+ def startIfEnabled() {
+ if (enabled) {
+ require(server == null, "Shuffle server already started")
+ logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl")
+ server = transportContext.createServer(port)
+ }
+ }
+
+ def stop() {
+ if (enabled && server != null) {
+ server.close()
+ server = null
+ }
+ }
+}
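
The worker-embedded shuffle service is off by default; both configuration keys it reads appear above (`spark.shuffle.service.enabled`, `spark.shuffle.service.port`, default 7337). A minimal sketch of turning it on for a standalone worker:

```scala
import org.apache.spark.SparkConf

object ShuffleServiceConfSketch extends App {
  // Settings read by StandaloneWorkerShuffleService above; with enabled=false (the default)
  // startIfEnabled() is a no-op and no TransportServer is created.
  val conf = new SparkConf(loadDefaults = false)
    .set("spark.shuffle.service.enabled", "true")
    .set("spark.shuffle.service.port", "7337")

  println(conf.getBoolean("spark.shuffle.service.enabled", defaultValue = false)) // true
}
```

With the service enabled, executors fetch shuffle files from this per-worker server rather than from each other, which is what allows executors to be removed or killed without losing access to their shuffle output.
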
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index c4a8ec2e5e7b0..b20f5c0c82895 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -21,10 +21,9 @@ import java.io.File
import java.io.IOException
import java.text.SimpleDateFormat
import java.util.{UUID, Date}
-import java.util.concurrent.TimeUnit
import scala.collection.JavaConversions._
-import scala.collection.mutable.HashMap
+import scala.collection.mutable.{HashMap, HashSet}
import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.Random
@@ -32,8 +31,8 @@ import scala.util.Random
import akka.actor._
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
-import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
-import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.deploy.{Command, ExecutorDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.{DriverState, Master}
import org.apache.spark.deploy.worker.ui.WorkerWebUI
@@ -41,7 +40,7 @@ import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils}
/**
- * @param masterUrls Each url should look like spark://host:port.
+ * @param masterAkkaUrls Each url should be a valid akka url.
*/
private[spark] class Worker(
host: String,
@@ -49,7 +48,7 @@ private[spark] class Worker(
webUiPort: Int,
cores: Int,
memory: Int,
- masterUrls: Array[String],
+ masterAkkaUrls: Array[String],
actorSystemName: String,
actorName: String,
workDirPath: String = null,
@@ -94,7 +93,12 @@ private[spark] class Worker(
var masterAddress: Address = null
var activeMasterUrl: String = ""
var activeMasterWebUiUrl : String = ""
- val akkaUrl = "akka.tcp://%s@%s:%s/user/%s".format(actorSystemName, host, port, actorName)
+ val akkaUrl = AkkaUtils.address(
+ AkkaUtils.protocol(context.system),
+ actorSystemName,
+ host,
+ port,
+ actorName)
@volatile var registered = false
@volatile var connected = false
val workerId = generateWorkerId()
@@ -110,6 +114,11 @@ private[spark] class Worker(
val finishedExecutors = new HashMap[String, ExecutorRunner]
val drivers = new HashMap[String, DriverRunner]
val finishedDrivers = new HashMap[String, DriverRunner]
+ val appDirectories = new HashMap[String, Seq[String]]
+ val finishedApps = new HashSet[String]
+
+ // The shuffle service is not actually started unless configured.
+ val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr)
val publicAddress = {
val envVar = System.getenv("SPARK_PUBLIC_DNS")
@@ -151,52 +160,89 @@ private[spark] class Worker(
assert(!registered)
logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
host, port, cores, Utils.megabytesToString(memory)))
+ logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}")
logInfo("Spark home: " + sparkHome)
createWorkDir()
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+ shuffleService.startIfEnabled()
webUi = new WorkerWebUI(this, workDir, webUiPort)
webUi.bind()
registerWithMaster()
metricsSystem.registerSource(workerSource)
metricsSystem.start()
+ // Attach the worker metrics servlet handler to the web ui after the metrics system is started.
+ metricsSystem.getServletHandlers.foreach(webUi.attachHandler)
}
def changeMaster(url: String, uiUrl: String) {
+ // activeMasterUrl is a valid Spark url since we receive it from the master.
activeMasterUrl = url
activeMasterWebUiUrl = uiUrl
- master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl))
- masterAddress = activeMasterUrl match {
- case Master.sparkUrlRegex(_host, _port) =>
- Address("akka.tcp", Master.systemName, _host, _port.toInt)
- case x =>
- throw new SparkException("Invalid spark URL: " + x)
- }
+ master = context.actorSelection(
+ Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(context.system)))
+ masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(context.system))
connected = true
+ // Cancel any outstanding re-registration attempts because we found a new master
+ registrationRetryTimer.foreach(_.cancel())
+ registrationRetryTimer = None
}
private def tryRegisterAllMasters() {
- for (masterUrl <- masterUrls) {
- logInfo("Connecting to master " + masterUrl + "...")
- val actor = context.actorSelection(Master.toAkkaUrl(masterUrl))
+ for (masterAkkaUrl <- masterAkkaUrls) {
+ logInfo("Connecting to master " + masterAkkaUrl + "...")
+ val actor = context.actorSelection(masterAkkaUrl)
actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort, publicAddress)
}
}
- private def retryConnectToMaster() {
+ /**
+ * Re-register with the master because a network failure or a master failure has occurred.
+ * If the re-registration attempt threshold is exceeded, the worker exits with error.
+ * Note that for thread-safety this should only be called from the actor.
+ */
+ private def reregisterWithMaster(): Unit = {
Utils.tryOrExit {
connectionAttemptCount += 1
- logInfo(s"Attempting to connect to master (attempt # $connectionAttemptCount")
if (registered) {
registrationRetryTimer.foreach(_.cancel())
registrationRetryTimer = None
} else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) {
- tryRegisterAllMasters()
+ logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)")
+ /**
+ * Re-register with the active master this worker has been communicating with. If there
+ * is none, then it means this worker is still bootstrapping and hasn't established a
+ * connection with a master yet, in which case we should re-register with all masters.
+ *
+ * It is important to re-register only with the active master during failures. Otherwise,
+ * if the worker unconditionally attempts to re-register with all masters, the following
+ * race condition may arise and cause a "duplicate worker" error detailed in SPARK-4592:
+ *
+ * (1) Master A fails and Worker attempts to reconnect to all masters
+ * (2) Master B takes over and notifies Worker
+ * (3) Worker responds by registering with Master B
+ * (4) Meanwhile, Worker's previous reconnection attempt reaches Master B,
+ * causing the same Worker to register with Master B twice
+ *
+ * Instead, if we only register with the known active master, we can assume that the
+ * old master must have died because another master has taken over. Note that this is
+ * still not safe if the old master recovers within this interval, but this is a much
+ * less likely scenario.
+ */
+ if (master != null) {
+ master ! RegisterWorker(
+ workerId, host, port, cores, memory, webUi.boundPort, publicAddress)
+ } else {
+ // We are retrying the initial registration
+ tryRegisterAllMasters()
+ }
+ // We have exceeded the initial registration retry threshold
+ // All retries from now on should use a higher interval
if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) {
registrationRetryTimer.foreach(_.cancel())
registrationRetryTimer = Some {
context.system.scheduler.schedule(PROLONGED_REGISTRATION_RETRY_INTERVAL,
- PROLONGED_REGISTRATION_RETRY_INTERVAL)(retryConnectToMaster)
+ PROLONGED_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster)
}
}
} else {
@@ -216,7 +262,7 @@ private[spark] class Worker(
connectionAttemptCount = 0
registrationRetryTimer = Some {
context.system.scheduler.schedule(INITIAL_REGISTRATION_RETRY_INTERVAL,
- INITIAL_REGISTRATION_RETRY_INTERVAL)(retryConnectToMaster)
+ INITIAL_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster)
}
case Some(_) =>
logInfo("Not spawning another attempt to register with the master, since there is an" +
@@ -253,7 +299,7 @@ private[spark] class Worker(
val isAppStillRunning = executors.values.map(_.appId).contains(appIdFromDir)
dir.isDirectory && !isAppStillRunning &&
!Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECS)
- }.foreach { dir =>
+ }.foreach { dir =>
logInfo(s"Removing directory: ${dir.getPath}")
Utils.deleteRecursively(dir)
}
@@ -298,8 +344,29 @@ private[spark] class Worker(
throw new IOException("Failed to create directory " + executorDir)
}
- val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_,
- self, workerId, host, sparkHome, executorDir, akkaUrl, conf, ExecutorState.LOADING)
+ // Create local dirs for the executor. These are passed to the executor via the
+ // SPARK_LOCAL_DIRS environment variable, and deleted by the Worker when the
+ // application finishes.
+ val appLocalDirs = appDirectories.get(appId).getOrElse {
+ Utils.getOrCreateLocalRootDirs(conf).map { dir =>
+ Utils.createDirectory(dir).getAbsolutePath()
+ }.toSeq
+ }
+ appDirectories(appId) = appLocalDirs
+ val manager = new ExecutorRunner(
+ appId,
+ execId,
+ appDesc.copy(command = Worker.maybeUpdateSSLSettings(appDesc.command, conf)),
+ cores_,
+ memory_,
+ self,
+ workerId,
+ host,
+ sparkHome,
+ executorDir,
+ akkaUrl,
+ conf,
+ appLocalDirs, ExecutorState.LOADING)
executors(appId + "/" + execId) = manager
manager.start()
coresUsed += cores_
@@ -336,6 +403,7 @@ private[spark] class Worker(
message.map(" message " + _).getOrElse("") +
exitStatus.map(" exitStatus " + _).getOrElse(""))
}
+ maybeCleanupApplication(appId)
}
case KillExecutor(masterUrl, appId, execId) =>
@@ -354,7 +422,14 @@ private[spark] class Worker(
case LaunchDriver(driverId, driverDesc) => {
logInfo(s"Asked to launch driver $driverId")
- val driver = new DriverRunner(conf, driverId, workDir, sparkHome, driverDesc, self, akkaUrl)
+ val driver = new DriverRunner(
+ conf,
+ driverId,
+ workDir,
+ sparkHome,
+ driverDesc.copy(command = Worker.maybeUpdateSSLSettings(driverDesc.command, conf)),
+ self,
+ akkaUrl)
drivers(driverId) = driver
driver.start()
@@ -396,12 +471,18 @@ private[spark] class Worker(
logInfo(s"$x Disassociated !")
masterDisconnected()
- case RequestWorkerState => {
+ case RequestWorkerState =>
sender ! WorkerStateResponse(host, port, workerId, executors.values.toList,
finishedExecutors.values.toList, drivers.values.toList,
finishedDrivers.values.toList, activeMasterUrl, cores, memory,
coresUsed, memoryUsed, activeMasterWebUiUrl)
- }
+
+ case ReregisterWithMaster =>
+ reregisterWithMaster()
+
+ case ApplicationFinished(id) =>
+ finishedApps += id
+ maybeCleanupApplication(id)
}
private def masterDisconnected() {
@@ -410,6 +491,19 @@ private[spark] class Worker(
registerWithMaster()
}
+ private def maybeCleanupApplication(id: String): Unit = {
+ val shouldCleanup = finishedApps.contains(id) && !executors.values.exists(_.appId == id)
+ if (shouldCleanup) {
+ finishedApps -= id
+ appDirectories.remove(id).foreach { dirList =>
+ logInfo(s"Cleaning up local directories for application $id")
+ dirList.foreach { dir =>
+ Utils.deleteRecursively(new File(dir))
+ }
+ }
+ }
+ }
+
def generateWorkerId(): String = {
"worker-%s-%s-%d".format(createDateFormat.format(new Date), host, port)
}
@@ -419,6 +513,7 @@ private[spark] class Worker(
registrationRetryTimer.foreach(_.cancel())
executors.values.foreach(_.kill())
drivers.values.foreach(_.kill())
+ shuffleService.stop()
webUi.stop()
metricsSystem.stop()
}
@@ -441,7 +536,8 @@ private[spark] object Worker extends Logging {
cores: Int,
memory: Int,
masterUrls: Array[String],
- workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = {
+ workDir: String,
+ workerNumber: Option[Int] = None): (ActorSystem, Int) = {
// The LocalSparkCluster runs multiple local sparkWorkerX actor systems
val conf = new SparkConf
@@ -450,9 +546,32 @@ private[spark] object Worker extends Logging {
val securityMgr = new SecurityManager(conf)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port,
conf = conf, securityManager = securityMgr)
+ val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem)))
actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory,
- masterUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName)
+ masterAkkaUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName)
(actorSystem, boundPort)
}
+ private[spark] def isUseLocalNodeSSLConfig(cmd: Command): Boolean = {
+ val pattern = """\-Dspark\.ssl\.useNodeLocalConf\=(.+)""".r
+ val result = cmd.javaOpts.collectFirst {
+ case pattern(_result) => _result.toBoolean
+ }
+ result.getOrElse(false)
+ }
+
+ private[spark] def maybeUpdateSSLSettings(cmd: Command, conf: SparkConf): Command = {
+ val prefix = "spark.ssl."
+ val useNLC = "spark.ssl.useNodeLocalConf"
+ if (isUseLocalNodeSSLConfig(cmd)) {
+ val newJavaOpts = cmd.javaOpts
+ .filter(opt => !opt.startsWith(s"-D$prefix")) ++
+ conf.getAll.collect { case (key, value) if key.startsWith(prefix) => s"-D$key=$value" } :+
+ s"-D$useNLC=true"
+ cmd.copy(javaOpts = newJavaOpts)
+ } else {
+ cmd
+ }
+ }
+
}
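
`maybeUpdateSSLSettings` and `isUseLocalNodeSSLConfig` above let a driver or executor command ask for the worker's node-local SSL configuration instead of whatever the submitter passed along. A self-contained sketch of that rewrite rule, using plain collections instead of `Command`/`SparkConf` (the `Sketch` names are hypothetical; the detection regex is the one used above):

```scala
object SslOptsSketch {
  // Same detection regex as isUseLocalNodeSSLConfig above.
  private val pattern = """\-Dspark\.ssl\.useNodeLocalConf\=(.+)""".r

  def useNodeLocalConf(javaOpts: Seq[String]): Boolean =
    javaOpts.collectFirst { case pattern(value) => value.toBoolean }.getOrElse(false)

  // If node-local SSL config is requested, drop every -Dspark.ssl.* opt the submitter
  // provided and re-add this node's own spark.ssl.* settings (plus the flag itself).
  def rewrite(javaOpts: Seq[String], nodeSslConf: Map[String, String]): Seq[String] =
    if (useNodeLocalConf(javaOpts)) {
      javaOpts.filterNot(_.startsWith("-Dspark.ssl.")) ++
        nodeSslConf.map { case (key, value) => s"-D$key=$value" } :+
        "-Dspark.ssl.useNodeLocalConf=true"
    } else {
      javaOpts
    }

  def main(args: Array[String]): Unit = {
    val submitted = Seq("-Xmx1g", "-Dspark.ssl.enabled=true", "-Dspark.ssl.useNodeLocalConf=true")
    val local = Map("spark.ssl.enabled" -> "true", "spark.ssl.keyStore" -> "/etc/spark/keystore")
    println(rewrite(submitted, local))
  }
}
```
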
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index b07942a9ca729..7ac81a2d87efd 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -50,7 +50,6 @@ class WorkerWebUI(
attachHandler(createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static"))
attachHandler(createServletHandler("/log",
(request: HttpServletRequest) => logPage.renderLog(request), worker.securityMgr))
- worker.metricsSystem.getServletHandlers.foreach(attachHandler)
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index c40a3e16675ad..bc72c8970319c 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -38,7 +38,7 @@ private[spark] class CoarseGrainedExecutorBackend(
executorId: String,
hostPort: String,
cores: Int,
- sparkProperties: Seq[(String, String)])
+ env: SparkEnv)
extends Actor with ActorLogReceive with ExecutorBackend with Logging {
Utils.checkHostPort(hostPort, "Expected hostport")
@@ -56,9 +56,8 @@ private[spark] class CoarseGrainedExecutorBackend(
override def receiveWithLogging = {
case RegisteredExecutor =>
logInfo("Successfully registered with driver")
- // Make this host instead of hostPort ?
- executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties,
- false)
+ val (hostname, _) = Utils.parseHostPort(hostPort)
+ executor = new Executor(executorId, hostname, env, isLocal = false)
case RegisterExecutorFailed(message) =>
logError("Slave registration failed: " + message)
@@ -69,10 +68,11 @@ private[spark] class CoarseGrainedExecutorBackend(
logError("Received LaunchTask command but executor was null")
System.exit(1)
} else {
- val ser = SparkEnv.get.closureSerializer.newInstance()
+ val ser = env.closureSerializer.newInstance()
val taskDesc = ser.deserialize[TaskDescription](data.value)
logInfo("Got assigned task " + taskDesc.taskId)
- executor.launchTask(this, taskDesc.taskId, taskDesc.name, taskDesc.serializedTask)
+ executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,
+ taskDesc.name, taskDesc.serializedTask)
}
case KillTask(taskId, _, interruptThread) =>
@@ -84,8 +84,12 @@ private[spark] class CoarseGrainedExecutorBackend(
}
case x: DisassociatedEvent =>
- logError(s"Driver $x disassociated! Shutting down.")
- System.exit(1)
+ if (x.remoteAddress == driver.anchorPath.address) {
+ logError(s"Driver $x disassociated! Shutting down.")
+ System.exit(1)
+ } else {
+ logWarning(s"Received irrelevant DisassociatedEvent $x")
+ }
case StopExecutor =>
logInfo("Driver commanded a shutdown")
@@ -119,7 +123,11 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
val executorConf = new SparkConf
val port = executorConf.getInt("spark.executor.port", 0)
val (fetcher, _) = AkkaUtils.createActorSystem(
- "driverPropsFetcher", hostname, port, executorConf, new SecurityManager(executorConf))
+ "driverPropsFetcher",
+ hostname,
+ port,
+ executorConf,
+ new SecurityManager(executorConf))
val driver = fetcher.actorSelection(driverUrl)
val timeout = AkkaUtils.askTimeout(executorConf)
val fut = Patterns.ask(driver, RetrieveSparkProps, timeout)
@@ -127,20 +135,33 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
Seq[(String, String)](("spark.app.id", appId))
fetcher.shutdown()
- // Create a new ActorSystem using driver's Spark properties to run the backend.
- val driverConf = new SparkConf().setAll(props)
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
- "sparkExecutor", hostname, port, driverConf, new SecurityManager(driverConf))
- // set it
+ // Create SparkEnv using properties we fetched from the driver.
+ val driverConf = new SparkConf()
+ for ((key, value) <- props) {
+ // this is required for SSL in standalone mode
+ if (SparkConf.isExecutorStartupConf(key)) {
+ driverConf.setIfMissing(key, value)
+ } else {
+ driverConf.set(key, value)
+ }
+ }
+ val env = SparkEnv.createExecutorEnv(
+ driverConf, executorId, hostname, port, cores, isLocal = false)
+
+ // SparkEnv sets spark.driver.port so it shouldn't be 0 anymore.
+ val boundPort = env.conf.getInt("spark.executor.port", 0)
+ assert(boundPort != 0)
+
+ // Start the CoarseGrainedExecutorBackend actor.
val sparkHostPort = hostname + ":" + boundPort
- actorSystem.actorOf(
+ env.actorSystem.actorOf(
Props(classOf[CoarseGrainedExecutorBackend],
- driverUrl, executorId, sparkHostPort, cores, props),
+ driverUrl, executorId, sparkHostPort, cores, env),
name = "Executor")
workerUrl.foreach { url =>
- actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher")
+ env.actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher")
}
- actorSystem.awaitTermination()
+ env.actorSystem.awaitTermination()
}
}
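
The executor backend now builds its SparkEnv from properties fetched from the driver; keys that matter at process startup (such as the SSL settings handled by the Worker above) are applied with `setIfMissing`, so a value already present on the executor side wins, while everything else is overwritten. A small sketch of that merge rule, where `isStartupConf` stands in for `SparkConf.isExecutorStartupConf` (whose exact predicate is not shown in this hunk):

```scala
object DriverPropsMergeSketch {
  // Startup-related keys keep any value the executor already has (setIfMissing semantics);
  // all other driver-supplied properties overwrite.
  def mergeDriverProps(
      existing: Map[String, String],
      fromDriver: Seq[(String, String)],
      isStartupConf: String => Boolean): Map[String, String] =
    fromDriver.foldLeft(existing) { case (conf, (key, value)) =>
      if (isStartupConf(key) && conf.contains(key)) conf else conf + (key -> value)
    }

  def main(args: Array[String]): Unit = {
    val merged = mergeDriverProps(
      existing = Map("spark.ssl.enabled" -> "true"),
      fromDriver = Seq("spark.ssl.enabled" -> "false", "spark.app.id" -> "app-20141215"),
      isStartupConf = _.startsWith("spark.ssl."))
    println(merged) // Map(spark.ssl.enabled -> true, spark.app.id -> app-20141215)
  }
}
```
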
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 616c7e6a46368..312bb3a1daaa3 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -26,23 +26,29 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.control.NonFatal
+import akka.actor.Props
+
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler._
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
-import org.apache.spark.util.{AkkaUtils, Utils}
+import org.apache.spark.util.{SparkUncaughtExceptionHandler, AkkaUtils, Utils}
/**
* Spark executor used with Mesos, YARN, and the standalone scheduler.
+ * In coarse-grained mode, an existing actor system is provided.
*/
private[spark] class Executor(
executorId: String,
- slaveHostname: String,
- properties: Seq[(String, String)],
+ executorHostname: String,
+ env: SparkEnv,
isLocal: Boolean = false)
extends Logging
{
+
+ logInfo(s"Starting executor ID $executorId on host $executorHostname")
+
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
@@ -50,43 +56,36 @@ private[spark] class Executor(
private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
+ private val conf = env.conf
+
@volatile private var isStopped = false
// No ip or host:port - just hostname
- Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname")
+ Utils.checkHost(executorHostname, "Expected executed slave to be a hostname")
// must not have port specified.
- assert (0 == Utils.parseHostPort(slaveHostname)._2)
+ assert (0 == Utils.parseHostPort(executorHostname)._2)
// Make sure the local hostname we report matches the cluster scheduler's name for this host
- Utils.setCustomHostname(slaveHostname)
-
- // Set spark.* properties from executor arg
- val conf = new SparkConf(true)
- conf.setAll(properties)
+ Utils.setCustomHostname(executorHostname)
if (!isLocal) {
// Setup an uncaught exception handler for non-local mode.
// Make any thread terminations due to uncaught exceptions kill the entire
// executor process to avoid surprising stalls.
- Thread.setDefaultUncaughtExceptionHandler(ExecutorUncaughtExceptionHandler)
+ Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler)
}
val executorSource = new ExecutorSource(this, executorId)
- // Initialize Spark environment (using system properties read above)
- conf.set("spark.executor.id", "executor." + executorId)
- private val env = {
- if (!isLocal) {
- val _env = SparkEnv.create(conf, executorId, slaveHostname, 0,
- isDriver = false, isLocal = false)
- SparkEnv.set(_env)
- _env.metricsSystem.registerSource(executorSource)
- _env
- } else {
- SparkEnv.get
- }
+ if (!isLocal) {
+ env.metricsSystem.registerSource(executorSource)
+ env.blockManager.initialize(conf.getAppId)
}
+ // Create an actor for receiving RPCs from the driver
+ private val executorActor = env.actorSystem.actorOf(
+ Props(new ExecutorActor(executorId)), "ExecutorActor")
+
// Create our ClassLoader
// do this after SparkEnv creation so can access the SecurityManager
private val urlClassLoader = createClassLoader()
@@ -99,6 +98,9 @@ private[spark] class Executor(
// to send the result back.
private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
+ // Limit of bytes for total size of results (default is 1GB)
+ private val maxResultSize = Utils.getMaxResultSize(conf)
+
// Start worker thread pool
val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
@@ -108,8 +110,13 @@ private[spark] class Executor(
startDriverHeartbeater()
def launchTask(
- context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) {
- val tr = new TaskRunner(context, taskId, taskName, serializedTask)
+ context: ExecutorBackend,
+ taskId: Long,
+ attemptNumber: Int,
+ taskName: String,
+ serializedTask: ByteBuffer) {
+ val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
+ serializedTask)
runningTasks.put(taskId, tr)
threadPool.execute(tr)
}
@@ -123,6 +130,7 @@ private[spark] class Executor(
def stop() {
env.metricsSystem.report()
+ env.actorSystem.stop(executorActor)
isStopped = true
threadPool.shutdown()
if (!isLocal) {
@@ -130,13 +138,20 @@ private[spark] class Executor(
}
}
+ private def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
+
class TaskRunner(
- execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer)
+ execBackend: ExecutorBackend,
+ val taskId: Long,
+ val attemptNumber: Int,
+ taskName: String,
+ serializedTask: ByteBuffer)
extends Runnable {
@volatile private var killed = false
@volatile var task: Task[Any] = _
@volatile var attemptedTask: Option[Task[Any]] = None
+ @volatile var startGCTime: Long = _
def kill(interruptThread: Boolean) {
logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
@@ -147,17 +162,15 @@ private[spark] class Executor(
}
override def run() {
- val startTime = System.currentTimeMillis()
+ val deserializeStartTime = System.currentTimeMillis()
Thread.currentThread.setContextClassLoader(replClassLoader)
- val ser = SparkEnv.get.closureSerializer.newInstance()
+ val ser = env.closureSerializer.newInstance()
logInfo(s"Running $taskName (TID $taskId)")
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
var taskStart: Long = 0
- def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
- val startGCTime = gcTime
+ startGCTime = gcTime
try {
- Accumulators.clear()
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
updateDependencies(taskFiles, taskJars)
task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
@@ -178,7 +191,7 @@ private[spark] class Executor(
// Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
- val value = task.run(taskId.toInt)
+ val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
val taskFinish = System.currentTimeMillis()
// If the task has been killed, let's fail it.
@@ -186,16 +199,16 @@ private[spark] class Executor(
throw new TaskKilledException
}
- val resultSer = SparkEnv.get.serializer.newInstance()
+ val resultSer = env.serializer.newInstance()
val beforeSerialization = System.currentTimeMillis()
val valueBytes = resultSer.serialize(value)
val afterSerialization = System.currentTimeMillis()
for (m <- task.metrics) {
- m.executorDeserializeTime = taskStart - startTime
- m.executorRunTime = taskFinish - taskStart
- m.jvmGCTime = gcTime - startGCTime
- m.resultSerializationTime = afterSerialization - beforeSerialization
+ m.setExecutorDeserializeTime(taskStart - deserializeStartTime)
+ m.setExecutorRunTime(taskFinish - taskStart)
+ m.setJvmGCTime(gcTime - startGCTime)
+ m.setResultSerializationTime(afterSerialization - beforeSerialization)
}
val accumUpdates = Accumulators.values
@@ -205,25 +218,27 @@ private[spark] class Executor(
val resultSize = serializedDirectResult.limit
// directSend = sending directly back to the driver
- val (serializedResult, directSend) = {
- if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
+ val serializedResult = {
+ if (maxResultSize > 0 && resultSize > maxResultSize) {
+ logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
+ s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
+ s"dropping it.")
+ ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
+ } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
val blockId = TaskResultBlockId(taskId)
env.blockManager.putBytes(
blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
- (ser.serialize(new IndirectTaskResult[Any](blockId)), false)
+ logInfo(
+            s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager")
+ ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
} else {
- (serializedDirectResult, true)
+ logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
+ serializedDirectResult
}
}
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
- if (directSend) {
- logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
- } else {
- logInfo(
- s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
- }
} catch {
case ffe: FetchFailedException => {
val reason = ffe.toTaskEndReason
@@ -244,16 +259,16 @@ private[spark] class Executor(
val serviceTime = System.currentTimeMillis() - taskStart
val metrics = attemptedTask.flatMap(t => t.metrics)
for (m <- metrics) {
- m.executorRunTime = serviceTime
- m.jvmGCTime = gcTime - startGCTime
+ m.setExecutorRunTime(serviceTime)
+ m.setJvmGCTime(gcTime - startGCTime)
}
- val reason = ExceptionFailure(t.getClass.getName, t.getMessage, t.getStackTrace, metrics)
+ val reason = new ExceptionFailure(t, metrics)
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
// Don't forcibly exit unless the exception was inherently fatal, to avoid
// stopping other tasks unnecessarily.
if (Utils.isFatalError(t)) {
- ExecutorUncaughtExceptionHandler.uncaughtException(t)
+ SparkUncaughtExceptionHandler.uncaughtException(t)
}
}
} finally {
@@ -261,6 +276,8 @@ private[spark] class Executor(
env.shuffleMemoryManager.releaseMemoryForThisThread()
// Release memory used by this thread for unrolling blocks
env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
+ // Release memory used by this thread for accumulators
+ Accumulators.clear()
runningTasks.remove(taskId)
}
}
@@ -317,19 +334,21 @@ private[spark] class Executor(
* SparkContext. Also adds any new JARs we fetched to the class loader.
*/
private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
- val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
+ lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
synchronized {
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager,
- hadoopConf)
+        // Fetch file with useCache mode; the cache is disabled in local mode.
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,
+ env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager,
- hadoopConf)
+        // Fetch file with useCache mode; the cache is disabled in local mode.
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,
+ env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
currentJars(name) = timestamp
// Add it to our class loader
val localName = name.split("/").last
@@ -356,10 +375,15 @@ private[spark] class Executor(
while (!isStopped) {
val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]()
+ val curGCTime = gcTime
+
for (taskRunner <- runningTasks.values()) {
- if (!taskRunner.attemptedTask.isEmpty) {
+ if (taskRunner.attemptedTask.nonEmpty) {
Option(taskRunner.task).flatMap(_.metrics).foreach { metrics =>
- metrics.updateShuffleReadMetrics
+ metrics.updateShuffleReadMetrics()
+ metrics.updateInputMetrics()
+ metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
+
if (isLocal) {
// JobProgressListener will hold an reference of it during
// onExecutorMetricsUpdate(), then JobProgressListener can not see
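
A minimal standalone sketch of the result-routing decision introduced in the hunk above; the thresholds mirror maxResultSize (Utils.getMaxResultSize(conf)) and akkaFrameSize from the patch, while the object and type names are purely illustrative:

    // Illustrative only: condensed view of how a finished task's result is routed.
    object ResultRouting {
      sealed trait Route
      case object DroppedTooLarge extends Route  // only the size reaches the driver
      case object ViaBlockManager extends Route  // stored as a block, fetched lazily
      case object DirectToDriver  extends Route  // serialized bytes sent inline

      def route(resultSize: Long, maxResultSize: Long,
                akkaFrameSize: Long, reservedBytes: Long): Route = {
        if (maxResultSize > 0 && resultSize > maxResultSize) DroppedTooLarge
        else if (resultSize >= akkaFrameSize - reservedBytes) ViaBlockManager
        else DirectToDriver
      }
    }
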
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala
new file mode 100644
index 0000000000000..41925f7e97e84
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import akka.actor.Actor
+import org.apache.spark.Logging
+
+import org.apache.spark.util.{Utils, ActorLogReceive}
+
+/**
+ * Driver -> Executor message to trigger a thread dump.
+ */
+private[spark] case object TriggerThreadDump
+
+/**
+ * Actor that runs inside of executors to enable driver -> executor RPC.
+ */
+private[spark]
+class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging {
+
+ override def receiveWithLogging = {
+ case TriggerThreadDump =>
+ sender ! Utils.getThreadDump()
+ }
+
+}
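
A hedged sketch of how a driver-side caller could query the new ExecutorActor with Akka's ask pattern; the helper, the actor path argument, and the untyped result are assumptions rather than part of the patch (TriggerThreadDump is private[spark], so this would have to compile under the org.apache.spark package tree):

    import scala.concurrent.Await
    import scala.concurrent.duration._
    import akka.actor.ActorSystem
    import akka.pattern.ask
    import akka.util.Timeout

    import org.apache.spark.executor.TriggerThreadDump

    // Sends TriggerThreadDump to an executor's ExecutorActor and blocks for the reply.
    def requestThreadDump(system: ActorSystem, executorActorPath: String): Any = {
      implicit val timeout: Timeout = Timeout(10.seconds)
      val selection = system.actorSelection(executorActorPath)
      Await.result(selection ? TriggerThreadDump, timeout.duration)
    }
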
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala
index 38be2c58b333f..52862ae0ca5e4 100644
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala
@@ -17,6 +17,8 @@
package org.apache.spark.executor
+import org.apache.spark.util.SparkExitCode._
+
/**
* These are exit codes that executors should use to provide the master with information about
* executor failures assuming that cluster management framework can capture the exit codes (but
@@ -27,16 +29,6 @@ package org.apache.spark.executor
*/
private[spark]
object ExecutorExitCode {
- /** The default uncaught exception handler was reached. */
- val UNCAUGHT_EXCEPTION = 50
-
- /** The default uncaught exception handler was called and an exception was encountered while
- logging the exception. */
- val UNCAUGHT_EXCEPTION_TWICE = 51
-
- /** The default uncaught exception handler was reached, and the uncaught exception was an
- OutOfMemoryError. */
- val OOM = 52
/** DiskStore failed to create a local temporary directory after many attempts. */
val DISK_STORE_FAILED_TO_CREATE_DIR = 53
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorUncaughtExceptionHandler.scala
deleted file mode 100644
index b0e984c03964c..0000000000000
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorUncaughtExceptionHandler.scala
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.executor
-
-import org.apache.spark.Logging
-import org.apache.spark.util.Utils
-
-/**
- * The default uncaught exception handler for Executors terminates the whole process, to avoid
- * getting into a bad state indefinitely. Since Executors are relatively lightweight, it's better
- * to fail fast when things go wrong.
- */
-private[spark] object ExecutorUncaughtExceptionHandler
- extends Thread.UncaughtExceptionHandler with Logging {
-
- override def uncaughtException(thread: Thread, exception: Throwable) {
- try {
- logError("Uncaught exception in thread " + thread, exception)
-
- // We may have been called from a shutdown hook. If so, we must not call System.exit().
- // (If we do, we will deadlock.)
- if (!Utils.inShutdown()) {
- if (exception.isInstanceOf[OutOfMemoryError]) {
- System.exit(ExecutorExitCode.OOM)
- } else {
- System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
- }
- }
- } catch {
- case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
- case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
- }
- }
-
- def uncaughtException(exception: Throwable) {
- uncaughtException(Thread.currentThread(), exception)
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
index bca0b152268ad..cfd672e1d8a97 100644
--- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -19,13 +19,16 @@ package org.apache.spark.executor
import java.nio.ByteBuffer
+import scala.collection.JavaConversions._
+
import org.apache.mesos.protobuf.ByteString
-import org.apache.mesos.{Executor => MesosExecutor, ExecutorDriver, MesosExecutorDriver, MesosNativeLibrary}
+import org.apache.mesos.{Executor => MesosExecutor, ExecutorDriver, MesosExecutorDriver}
import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _}
-import org.apache.spark.{Logging, TaskState}
+import org.apache.spark.{Logging, TaskState, SparkConf, SparkEnv}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.scheduler.cluster.mesos.MesosTaskLaunchData
import org.apache.spark.util.{SignalLogger, Utils}
private[spark] class MesosExecutorBackend
@@ -50,22 +53,39 @@ private[spark] class MesosExecutorBackend
executorInfo: ExecutorInfo,
frameworkInfo: FrameworkInfo,
slaveInfo: SlaveInfo) {
- logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue)
+
+ // Get num cores for this task from ExecutorInfo, created in MesosSchedulerBackend.
+ val cpusPerTask = executorInfo.getResourcesList
+ .find(_.getName == "cpus")
+ .map(_.getScalar.getValue.toInt)
+ .getOrElse(0)
+ val executorId = executorInfo.getExecutorId.getValue
+
+ logInfo(s"Registered with Mesos as executor ID $executorId with $cpusPerTask cpus")
this.driver = driver
val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) ++
Seq[(String, String)](("spark.app.id", frameworkInfo.getId.getValue))
+ val conf = new SparkConf(loadDefaults = true).setAll(properties)
+ val port = conf.getInt("spark.executor.port", 0)
+ val env = SparkEnv.createExecutorEnv(
+ conf, executorId, slaveInfo.getHostname, port, cpusPerTask, isLocal = false)
+
executor = new Executor(
- executorInfo.getExecutorId.getValue,
+ executorId,
slaveInfo.getHostname,
- properties)
+ env)
}
override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) {
val taskId = taskInfo.getTaskId.getValue.toLong
+ val taskData = MesosTaskLaunchData.fromByteString(taskInfo.getData)
if (executor == null) {
logError("Received launchTask but executor was null")
} else {
- executor.launchTask(this, taskId, taskInfo.getName, taskInfo.getData.asReadOnlyByteBuffer)
+ SparkHadoopUtil.get.runAsSparkUser { () =>
+ executor.launchTask(this, taskId = taskId, attemptNumber = taskData.attemptNumber,
+ taskInfo.getName, taskData.serializedTask)
+ }
}
}
@@ -97,11 +117,8 @@ private[spark] class MesosExecutorBackend
private[spark] object MesosExecutorBackend extends Logging {
def main(args: Array[String]) {
SignalLogger.register(log)
- SparkHadoopUtil.get.runAsSparkUser { () =>
- MesosNativeLibrary.load()
- // Create a new Executor and start it running
- val runner = new MesosExecutorBackend()
- new MesosExecutorDriver(runner).run()
- }
+ // Create a new Executor and start it running
+ val runner = new MesosExecutorBackend()
+ new MesosExecutorDriver(runner).run()
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 3e49b6235aff3..97912c68c5982 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -17,6 +17,10 @@
package org.apache.spark.executor
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.executor.DataReadMethod.DataReadMethod
+
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.annotation.DeveloperApi
@@ -39,48 +43,84 @@ class TaskMetrics extends Serializable {
/**
* Host's name the task runs on
*/
- var hostname: String = _
-
+ private var _hostname: String = _
+ def hostname = _hostname
+ private[spark] def setHostname(value: String) = _hostname = value
+
/**
* Time taken on the executor to deserialize this task
*/
- var executorDeserializeTime: Long = _
-
+ private var _executorDeserializeTime: Long = _
+ def executorDeserializeTime = _executorDeserializeTime
+ private[spark] def setExecutorDeserializeTime(value: Long) = _executorDeserializeTime = value
+
+
/**
* Time the executor spends actually running the task (including fetching shuffle data)
*/
- var executorRunTime: Long = _
-
+ private var _executorRunTime: Long = _
+ def executorRunTime = _executorRunTime
+ private[spark] def setExecutorRunTime(value: Long) = _executorRunTime = value
+
/**
* The number of bytes this task transmitted back to the driver as the TaskResult
*/
- var resultSize: Long = _
+ private var _resultSize: Long = _
+ def resultSize = _resultSize
+ private[spark] def setResultSize(value: Long) = _resultSize = value
+
/**
* Amount of time the JVM spent in garbage collection while executing this task
*/
- var jvmGCTime: Long = _
+ private var _jvmGCTime: Long = _
+ def jvmGCTime = _jvmGCTime
+ private[spark] def setJvmGCTime(value: Long) = _jvmGCTime = value
/**
* Amount of time spent serializing the task result
*/
- var resultSerializationTime: Long = _
+ private var _resultSerializationTime: Long = _
+ def resultSerializationTime = _resultSerializationTime
+ private[spark] def setResultSerializationTime(value: Long) = _resultSerializationTime = value
/**
* The number of in-memory bytes spilled by this task
*/
- var memoryBytesSpilled: Long = _
+ private var _memoryBytesSpilled: Long = _
+ def memoryBytesSpilled = _memoryBytesSpilled
+ private[spark] def incMemoryBytesSpilled(value: Long) = _memoryBytesSpilled += value
+ private[spark] def decMemoryBytesSpilled(value: Long) = _memoryBytesSpilled -= value
/**
* The number of on-disk bytes spilled by this task
*/
- var diskBytesSpilled: Long = _
+ private var _diskBytesSpilled: Long = _
+ def diskBytesSpilled = _diskBytesSpilled
+ def incDiskBytesSpilled(value: Long) = _diskBytesSpilled += value
+ def decDiskBytesSpilled(value: Long) = _diskBytesSpilled -= value
/**
* If this task reads from a HadoopRDD or from persisted data, metrics on how much data was read
* are stored here.
*/
- var inputMetrics: Option[InputMetrics] = None
+ private var _inputMetrics: Option[InputMetrics] = None
+
+ def inputMetrics = _inputMetrics
+
+ /**
+ * This should only be used when recreating TaskMetrics, not when updating input metrics in
+ * executors
+ */
+ private[spark] def setInputMetrics(inputMetrics: Option[InputMetrics]) {
+ _inputMetrics = inputMetrics
+ }
+
+ /**
+ * If this task writes data externally (e.g. to a distributed filesystem), metrics on how much
+ * data was written are stored here.
+ */
+ var outputMetrics: Option[OutputMetrics] = None
/**
* If this task reads from shuffle output, metrics on getting shuffle data will be collected here.
@@ -127,19 +167,47 @@ class TaskMetrics extends Serializable {
readMetrics
}
+ /**
+ * Returns the input metrics object that the task should use. Currently, if
+ * there exists an input metric with the same readMethod, we return that one
+ * so the caller can accumulate bytes read. If the readMethod is different
+ * than previously seen by this task, we return a new InputMetric but don't
+ * record it.
+ *
+ * Once https://issues.apache.org/jira/browse/SPARK-5225 is addressed,
+ * we can store all the different inputMetrics (one per readMethod).
+ */
+ private[spark] def getInputMetricsForReadMethod(readMethod: DataReadMethod):
+    InputMetrics = synchronized {
+ _inputMetrics match {
+ case None =>
+ val metrics = new InputMetrics(readMethod)
+ _inputMetrics = Some(metrics)
+ metrics
+ case Some(metrics @ InputMetrics(method)) if method == readMethod =>
+ metrics
+ case Some(InputMetrics(method)) =>
+ new InputMetrics(readMethod)
+ }
+ }
+
/**
* Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics.
*/
private[spark] def updateShuffleReadMetrics() = synchronized {
val merged = new ShuffleReadMetrics()
for (depMetrics <- depsShuffleReadMetrics) {
- merged.fetchWaitTime += depMetrics.fetchWaitTime
- merged.localBlocksFetched += depMetrics.localBlocksFetched
- merged.remoteBlocksFetched += depMetrics.remoteBlocksFetched
- merged.remoteBytesRead += depMetrics.remoteBytesRead
+ merged.incFetchWaitTime(depMetrics.fetchWaitTime)
+ merged.incLocalBlocksFetched(depMetrics.localBlocksFetched)
+ merged.incRemoteBlocksFetched(depMetrics.remoteBlocksFetched)
+ merged.incRemoteBytesRead(depMetrics.remoteBytesRead)
}
_shuffleReadMetrics = Some(merged)
}
+
+ private[spark] def updateInputMetrics() = synchronized {
+ inputMetrics.foreach(_.updateBytesRead())
+ }
}
private[spark] object TaskMetrics {
@@ -157,51 +225,115 @@ object DataReadMethod extends Enumeration with Serializable {
val Memory, Disk, Hadoop, Network = Value
}
+/**
+ * :: DeveloperApi ::
+ * Method by which output data was written.
+ */
+@DeveloperApi
+object DataWriteMethod extends Enumeration with Serializable {
+ type DataWriteMethod = Value
+ val Hadoop = Value
+}
+
/**
* :: DeveloperApi ::
* Metrics about reading input data.
*/
@DeveloperApi
case class InputMetrics(readMethod: DataReadMethod.Value) {
+
+ private val _bytesRead: AtomicLong = new AtomicLong()
+
/**
* Total bytes read.
*/
- var bytesRead: Long = 0L
-}
+ def bytesRead: Long = _bytesRead.get()
+ @volatile @transient var bytesReadCallback: Option[() => Long] = None
+ /**
+ * Adds additional bytes read for this read method.
+ */
+ def addBytesRead(bytes: Long) = {
+ _bytesRead.addAndGet(bytes)
+ }
+
+ /**
+ * Invoke the bytesReadCallback and mutate bytesRead.
+ */
+ def updateBytesRead() {
+ bytesReadCallback.foreach { c =>
+ _bytesRead.set(c())
+ }
+ }
+
+ /**
+ * Register a function that can be called to get up-to-date information on how many bytes the task
+ * has read from an input source.
+ */
+ def setBytesReadCallback(f: Option[() => Long]) {
+ bytesReadCallback = f
+ }
+}
/**
* :: DeveloperApi ::
- * Metrics pertaining to shuffle data read in a given task.
+ * Metrics about writing output data.
*/
@DeveloperApi
-class ShuffleReadMetrics extends Serializable {
+case class OutputMetrics(writeMethod: DataWriteMethod.Value) {
/**
- * Number of blocks fetched in this shuffle by this task (remote or local)
+ * Total bytes written
*/
- def totalBlocksFetched: Int = remoteBlocksFetched + localBlocksFetched
+ private var _bytesWritten: Long = _
+ def bytesWritten = _bytesWritten
+ private[spark] def setBytesWritten(value : Long) = _bytesWritten = value
+}
+/**
+ * :: DeveloperApi ::
+ * Metrics pertaining to shuffle data read in a given task.
+ */
+@DeveloperApi
+class ShuffleReadMetrics extends Serializable {
/**
* Number of remote blocks fetched in this shuffle by this task
*/
- var remoteBlocksFetched: Int = _
-
+ private var _remoteBlocksFetched: Int = _
+ def remoteBlocksFetched = _remoteBlocksFetched
+ private[spark] def incRemoteBlocksFetched(value: Int) = _remoteBlocksFetched += value
+  private[spark] def decRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value
+
/**
* Number of local blocks fetched in this shuffle by this task
*/
- var localBlocksFetched: Int = _
+ private var _localBlocksFetched: Int = _
+ def localBlocksFetched = _localBlocksFetched
+ private[spark] def incLocalBlocksFetched(value: Int) = _localBlocksFetched += value
+  private[spark] def decLocalBlocksFetched(value: Int) = _localBlocksFetched -= value
+
/**
* Time the task spent waiting for remote shuffle blocks. This only includes the time
* blocking on shuffle input data. For instance if block B is being fetched while the task is
* still not finished processing block A, it is not considered to be blocking on block B.
*/
- var fetchWaitTime: Long = _
-
+ private var _fetchWaitTime: Long = _
+ def fetchWaitTime = _fetchWaitTime
+ private[spark] def incFetchWaitTime(value: Long) = _fetchWaitTime += value
+ private[spark] def decFetchWaitTime(value: Long) = _fetchWaitTime -= value
+
/**
* Total number of remote bytes read from the shuffle by this task
*/
- var remoteBytesRead: Long = _
+ private var _remoteBytesRead: Long = _
+ def remoteBytesRead = _remoteBytesRead
+ private[spark] def incRemoteBytesRead(value: Long) = _remoteBytesRead += value
+ private[spark] def decRemoteBytesRead(value: Long) = _remoteBytesRead -= value
+
+ /**
+ * Number of blocks fetched in this shuffle by this task (remote or local)
+ */
+ def totalBlocksFetched = _remoteBlocksFetched + _localBlocksFetched
}
/**
@@ -213,10 +345,18 @@ class ShuffleWriteMetrics extends Serializable {
/**
* Number of bytes written for the shuffle by this task
*/
- @volatile var shuffleBytesWritten: Long = _
-
+ @volatile private var _shuffleBytesWritten: Long = _
+ def shuffleBytesWritten = _shuffleBytesWritten
+ private[spark] def incShuffleBytesWritten(value: Long) = _shuffleBytesWritten += value
+ private[spark] def decShuffleBytesWritten(value: Long) = _shuffleBytesWritten -= value
+
/**
* Time the task spent blocking on writes to disk or buffer cache, in nanoseconds
*/
- @volatile var shuffleWriteTime: Long = _
+ @volatile private var _shuffleWriteTime: Long = _
+  def shuffleWriteTime = _shuffleWriteTime
+ private[spark] def incShuffleWriteTime(value: Long) = _shuffleWriteTime += value
+ private[spark] def decShuffleWriteTime(value: Long) = _shuffleWriteTime -= value
+
+
}
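
The getInputMetricsForReadMethod contract documented above (reuse the recorded metrics when the read method matches, hand back an unrecorded instance otherwise) reduces to this small self-contained model with simplified stand-in types:

    object InputMetricsLookup {
      final case class Metrics(readMethod: String)

      // Returns the (possibly updated) recorded metrics plus the instance the caller
      // should accumulate into; only the first read method seen is recorded.
      def forReadMethod(current: Option[Metrics], readMethod: String): (Option[Metrics], Metrics) =
        current match {
          case None =>
            val created = Metrics(readMethod)
            (Some(created), created)
          case Some(existing) if existing.readMethod == readMethod =>
            (current, existing)
          case Some(_) =>
            (current, Metrics(readMethod))
        }
    }
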
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
new file mode 100644
index 0000000000000..c219d21fbefa9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.input
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.{BytesWritable, LongWritable}
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
+import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext}
+import org.apache.spark.deploy.SparkHadoopUtil
+
+/**
+ * Custom Input Format for reading and splitting flat binary files that contain records,
+ * each of which are a fixed size in bytes. The fixed record size is specified through
+ * a parameter recordLength in the Hadoop configuration.
+ */
+private[spark] object FixedLengthBinaryInputFormat {
+ /** Property name to set in Hadoop JobConfs for record length */
+ val RECORD_LENGTH_PROPERTY = "org.apache.spark.input.FixedLengthBinaryInputFormat.recordLength"
+
+ /** Retrieves the record length property from a Hadoop configuration */
+ def getRecordLength(context: JobContext): Int = {
+ SparkHadoopUtil.get.getConfigurationFromJobContext(context).get(RECORD_LENGTH_PROPERTY).toInt
+ }
+}
+
+private[spark] class FixedLengthBinaryInputFormat
+ extends FileInputFormat[LongWritable, BytesWritable] {
+
+ private var recordLength = -1
+
+ /**
+ * Override of isSplitable to ensure initial computation of the record length
+ */
+ override def isSplitable(context: JobContext, filename: Path): Boolean = {
+ if (recordLength == -1) {
+ recordLength = FixedLengthBinaryInputFormat.getRecordLength(context)
+ }
+ if (recordLength <= 0) {
+      println("record length is less than or equal to 0, file cannot be split")
+ false
+ } else {
+ true
+ }
+ }
+
+ /**
+ * This input format overrides computeSplitSize() to make sure that each split
+ * only contains full records. Each InputSplit passed to FixedLengthBinaryRecordReader
+   * will start at the first byte of a record, and the last byte will be the last byte of a record.
+ */
+ override def computeSplitSize(blockSize: Long, minSize: Long, maxSize: Long): Long = {
+ val defaultSize = super.computeSplitSize(blockSize, minSize, maxSize)
+ // If the default size is less than the length of a record, make it equal to it
+    // Otherwise, make sure the split size is as close as possible to the default size,
+ // but still contains a complete set of records, with the first record
+ // starting at the first byte in the split and the last record ending with the last byte
+ if (defaultSize < recordLength) {
+ recordLength.toLong
+ } else {
+ (Math.floor(defaultSize / recordLength) * recordLength).toLong
+ }
+ }
+
+ /**
+ * Create a FixedLengthBinaryRecordReader
+ */
+ override def createRecordReader(split: InputSplit, context: TaskAttemptContext)
+ : RecordReader[LongWritable, BytesWritable] = {
+ new FixedLengthBinaryRecordReader
+ }
+}
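
The rounding rule in computeSplitSize amounts to the following sketch (under the same assumption that recordLength is positive):

    // Never split below one record; otherwise round the default split size down
    // to a whole number of records so no record straddles two splits.
    def splitSizeFor(defaultSize: Long, recordLength: Long): Long =
      if (defaultSize < recordLength) recordLength
      else (defaultSize / recordLength) * recordLength
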
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
new file mode 100644
index 0000000000000..67a96925da019
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.input
+
+import java.io.IOException
+
+import org.apache.hadoop.fs.FSDataInputStream
+import org.apache.hadoop.io.compress.CompressionCodecFactory
+import org.apache.hadoop.io.{BytesWritable, LongWritable}
+import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.input.FileSplit
+import org.apache.spark.deploy.SparkHadoopUtil
+
+/**
+ * FixedLengthBinaryRecordReader is returned by FixedLengthBinaryInputFormat.
+ * It uses the record length set in FixedLengthBinaryInputFormat to
+ * read one record at a time from the given InputSplit.
+ *
+ * Each call to nextKeyValue() updates the LongWritable key and BytesWritable value.
+ *
+ * key = record index (Long)
+ * value = the record itself (BytesWritable)
+ */
+private[spark] class FixedLengthBinaryRecordReader
+ extends RecordReader[LongWritable, BytesWritable] {
+
+ private var splitStart: Long = 0L
+ private var splitEnd: Long = 0L
+ private var currentPosition: Long = 0L
+ private var recordLength: Int = 0
+ private var fileInputStream: FSDataInputStream = null
+ private var recordKey: LongWritable = null
+ private var recordValue: BytesWritable = null
+
+ override def close() {
+ if (fileInputStream != null) {
+ fileInputStream.close()
+ }
+ }
+
+ override def getCurrentKey: LongWritable = {
+ recordKey
+ }
+
+ override def getCurrentValue: BytesWritable = {
+ recordValue
+ }
+
+ override def getProgress: Float = {
+ splitStart match {
+ case x if x == splitEnd => 0.0.toFloat
+ case _ => Math.min(
+        ((currentPosition - splitStart).toDouble / (splitEnd - splitStart)).toFloat, 1.0
+ ).toFloat
+ }
+ }
+
+ override def initialize(inputSplit: InputSplit, context: TaskAttemptContext) {
+ // the file input
+ val fileSplit = inputSplit.asInstanceOf[FileSplit]
+
+ // the byte position this fileSplit starts at
+ splitStart = fileSplit.getStart
+
+ // splitEnd byte marker that the fileSplit ends at
+ splitEnd = splitStart + fileSplit.getLength
+
+ // the actual file we will be reading from
+ val file = fileSplit.getPath
+ // job configuration
+ val job = SparkHadoopUtil.get.getConfigurationFromJobContext(context)
+ // check compression
+ val codec = new CompressionCodecFactory(job).getCodec(file)
+ if (codec != null) {
+ throw new IOException("FixedLengthRecordReader does not support reading compressed files")
+ }
+ // get the record length
+ recordLength = FixedLengthBinaryInputFormat.getRecordLength(context)
+ // get the filesystem
+ val fs = file.getFileSystem(job)
+ // open the File
+ fileInputStream = fs.open(file)
+ // seek to the splitStart position
+ fileInputStream.seek(splitStart)
+ // set our current position
+ currentPosition = splitStart
+ }
+
+ override def nextKeyValue(): Boolean = {
+ if (recordKey == null) {
+ recordKey = new LongWritable()
+ }
+ // the key is a linear index of the record, given by the
+ // position the record starts divided by the record length
+ recordKey.set(currentPosition / recordLength)
+ // the recordValue to place the bytes into
+ if (recordValue == null) {
+ recordValue = new BytesWritable(new Array[Byte](recordLength))
+ }
+ // read a record if the currentPosition is less than the split end
+ if (currentPosition < splitEnd) {
+ // setup a buffer to store the record
+ val buffer = recordValue.getBytes
+ fileInputStream.readFully(buffer)
+ // update our current position
+ currentPosition = currentPosition + recordLength
+ // return true
+ return true
+ }
+ false
+ }
+}
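
For the key/value contract described in the class comment, the key is simply the zero-based record index at the current byte position; a one-line sketch:

    // With recordLength = 512: bytes 0-511 are record 0, bytes 512-1023 record 1, and so on.
    def recordIndex(currentPosition: Long, recordLength: Int): Long =
      currentPosition / recordLength
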
diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
new file mode 100644
index 0000000000000..593a62b3e3b32
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
@@ -0,0 +1,220 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.input
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+
+import scala.collection.JavaConversions._
+
+import com.google.common.io.ByteStreams
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit}
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.deploy.SparkHadoopUtil
+
+/**
+ * A general format for reading whole files in as streams, byte arrays,
+ * or other functions to be added
+ */
+private[spark] abstract class StreamFileInputFormat[T]
+ extends CombineFileInputFormat[String, T]
+{
+ override protected def isSplitable(context: JobContext, file: Path): Boolean = false
+
+ /**
+   * Allow minPartitions set by the end-user in order to keep compatibility with the old
+   * Hadoop API; the value is applied by setting the max split size via setMaxSplitSize
+ */
+ def setMinPartitions(context: JobContext, minPartitions: Int) {
+ val files = listStatus(context)
+ val totalLen = files.map { file =>
+ if (file.isDir) 0L else file.getLen
+ }.sum
+
+ val maxSplitSize = Math.ceil(totalLen * 1.0 / files.length).toLong
+ super.setMaxSplitSize(maxSplitSize)
+ }
+
+ def createRecordReader(split: InputSplit, taContext: TaskAttemptContext): RecordReader[String, T]
+
+}
+
+/**
+ * An abstract class of [[org.apache.hadoop.mapreduce.RecordReader RecordReader]]
+ * for reading files out as streams
+ */
+private[spark] abstract class StreamBasedRecordReader[T](
+ split: CombineFileSplit,
+ context: TaskAttemptContext,
+ index: Integer)
+ extends RecordReader[String, T] {
+
+  // True means the current file has been processed, so it should be skipped.
+ private var processed = false
+
+ private var key = ""
+ private var value: T = null.asInstanceOf[T]
+
+ override def initialize(split: InputSplit, context: TaskAttemptContext) = {}
+ override def close() = {}
+
+ override def getProgress = if (processed) 1.0f else 0.0f
+
+ override def getCurrentKey = key
+
+ override def getCurrentValue = value
+
+ override def nextKeyValue = {
+ if (!processed) {
+ val fileIn = new PortableDataStream(split, context, index)
+ value = parseStream(fileIn)
+ fileIn.close() // if it has not been open yet, close does nothing
+ key = fileIn.getPath
+ processed = true
+ true
+ } else {
+ false
+ }
+ }
+
+ /**
+   * Parse the stream (and close it afterwards) and return the value as type T
+   * @param inStream the stream to be read in
+   * @return the data, formatted as type T
+ */
+ def parseStream(inStream: PortableDataStream): T
+}
+
+/**
+ * Reads the record in directly as a stream for other objects to manipulate and handle
+ */
+private[spark] class StreamRecordReader(
+ split: CombineFileSplit,
+ context: TaskAttemptContext,
+ index: Integer)
+ extends StreamBasedRecordReader[PortableDataStream](split, context, index) {
+
+ def parseStream(inStream: PortableDataStream): PortableDataStream = inStream
+}
+
+/**
+ * The format for the PortableDataStream files
+ */
+private[spark] class StreamInputFormat extends StreamFileInputFormat[PortableDataStream] {
+ override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext) = {
+ new CombineFileRecordReader[String, PortableDataStream](
+ split.asInstanceOf[CombineFileSplit], taContext, classOf[StreamRecordReader])
+ }
+}
+
+/**
+ * A class that allows DataStreams to be serialized and moved around by not creating them
+ * until they need to be read
+ * @note TaskAttemptContext is not serializable resulting in the confBytes construct
+ * @note CombineFileSplit is not serializable resulting in the splitBytes construct
+ */
+@Experimental
+class PortableDataStream(
+ @transient isplit: CombineFileSplit,
+ @transient context: TaskAttemptContext,
+ index: Integer)
+ extends Serializable {
+
+  // transient forces the file to be reopened after serialization
+ // it is also used for non-serializable classes
+
+ @transient private var fileIn: DataInputStream = null
+ @transient private var isOpen = false
+
+ private val confBytes = {
+ val baos = new ByteArrayOutputStream()
+ SparkHadoopUtil.get.getConfigurationFromJobContext(context).
+ write(new DataOutputStream(baos))
+ baos.toByteArray
+ }
+
+ private val splitBytes = {
+ val baos = new ByteArrayOutputStream()
+ isplit.write(new DataOutputStream(baos))
+ baos.toByteArray
+ }
+
+ @transient private lazy val split = {
+ val bais = new ByteArrayInputStream(splitBytes)
+ val nsplit = new CombineFileSplit()
+ nsplit.readFields(new DataInputStream(bais))
+ nsplit
+ }
+
+ @transient private lazy val conf = {
+ val bais = new ByteArrayInputStream(confBytes)
+ val nconf = new Configuration()
+ nconf.readFields(new DataInputStream(bais))
+ nconf
+ }
+ /**
+ * Calculate the path name independently of opening the file
+ */
+ @transient private lazy val path = {
+ val pathp = split.getPath(index)
+ pathp.toString
+ }
+
+ /**
+ * Create a new DataInputStream from the split and context
+ */
+ def open(): DataInputStream = {
+ if (!isOpen) {
+ val pathp = split.getPath(index)
+ val fs = pathp.getFileSystem(conf)
+ fileIn = fs.open(pathp)
+ isOpen = true
+ }
+ fileIn
+ }
+
+ /**
+ * Read the file as a byte array
+ */
+ def toArray(): Array[Byte] = {
+ open()
+ val innerBuffer = ByteStreams.toByteArray(fileIn)
+ close()
+ innerBuffer
+ }
+
+ /**
+ * Close the file (if it is currently open)
+ */
+ def close() = {
+ if (isOpen) {
+ try {
+ fileIn.close()
+ isOpen = false
+ } catch {
+ case ioe: java.io.IOException => // do nothing
+ }
+ }
+ }
+
+ def getPath(): String = path
+}
+
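
The confBytes/splitBytes construct above works because Hadoop's Configuration and CombineFileSplit are Writable rather than java-serializable; a standalone sketch of the round trip for the configuration half:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
    import org.apache.hadoop.conf.Configuration

    // Serialize the Configuration eagerly so it can travel with a Serializable object.
    def confToBytes(conf: Configuration): Array[Byte] = {
      val baos = new ByteArrayOutputStream()
      conf.write(new DataOutputStream(baos))
      baos.toByteArray
    }

    // Rebuild the Configuration lazily on the receiving side.
    def bytesToConf(bytes: Array[Byte]): Configuration = {
      val conf = new Configuration()
      conf.readFields(new DataInputStream(new ByteArrayInputStream(bytes)))
      conf
    }
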
diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
index 4cb450577796a..aaef7c74eea33 100644
--- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
+++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
@@ -25,8 +25,6 @@ import org.apache.hadoop.mapreduce.JobContext
import org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat
import org.apache.hadoop.mapreduce.RecordReader
import org.apache.hadoop.mapreduce.TaskAttemptContext
-import org.apache.hadoop.mapreduce.lib.input.CombineFileRecordReader
-import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit
/**
* A [[org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat CombineFileInputFormat]] for
@@ -34,23 +32,26 @@ import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit
* the value is the entire content of file.
*/
-private[spark] class WholeTextFileInputFormat extends CombineFileInputFormat[String, String] {
+private[spark] class WholeTextFileInputFormat
+ extends CombineFileInputFormat[String, String] with Configurable {
+
override protected def isSplitable(context: JobContext, file: Path): Boolean = false
override def createRecordReader(
split: InputSplit,
context: TaskAttemptContext): RecordReader[String, String] = {
- new CombineFileRecordReader[String, String](
- split.asInstanceOf[CombineFileSplit],
- context,
- classOf[WholeTextFileRecordReader])
+ val reader =
+ new ConfigurableCombineFileRecordReader(split, context, classOf[WholeTextFileRecordReader])
+ reader.setConf(getConf)
+ reader
}
/**
- * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API.
+   * Allow minPartitions set by the end-user in order to keep compatibility with the old
+   * Hadoop API; the value is applied by setting the max split size via setMaxSplitSize
*/
- def setMaxSplitSize(context: JobContext, minPartitions: Int) {
+ def setMinPartitions(context: JobContext, minPartitions: Int) {
val files = listStatus(context)
val totalLen = files.map { file =>
if (file.isDir) 0L else file.getLen
diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
index 3564ab2e2a162..31bde8a78f3c6 100644
--- a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
+++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
@@ -17,13 +17,28 @@
package org.apache.spark.input
+import org.apache.hadoop.conf.{Configuration, Configurable => HConfigurable}
import com.google.common.io.{ByteStreams, Closeables}
import org.apache.hadoop.io.Text
+import org.apache.hadoop.io.compress.CompressionCodecFactory
import org.apache.hadoop.mapreduce.InputSplit
-import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit
+import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, CombineFileRecordReader}
import org.apache.hadoop.mapreduce.RecordReader
import org.apache.hadoop.mapreduce.TaskAttemptContext
+import org.apache.spark.deploy.SparkHadoopUtil
+
+
+/**
+ * A trait to implement [[org.apache.hadoop.conf.Configurable Configurable]] interface.
+ */
+private[spark] trait Configurable extends HConfigurable {
+ private var conf: Configuration = _
+ def setConf(c: Configuration) {
+ conf = c
+ }
+ def getConf: Configuration = conf
+}
/**
* A [[org.apache.hadoop.mapreduce.RecordReader RecordReader]] for reading a single whole text file
@@ -34,10 +49,11 @@ private[spark] class WholeTextFileRecordReader(
split: CombineFileSplit,
context: TaskAttemptContext,
index: Integer)
- extends RecordReader[String, String] {
+ extends RecordReader[String, String] with Configurable {
private[this] val path = split.getPath(index)
- private[this] val fs = path.getFileSystem(context.getConfiguration)
+ private[this] val fs = path.getFileSystem(
+ SparkHadoopUtil.get.getConfigurationFromJobContext(context))
// True means the current file has been processed, then skip it.
private[this] var processed = false
@@ -57,8 +73,16 @@ private[spark] class WholeTextFileRecordReader(
override def nextKeyValue(): Boolean = {
if (!processed) {
+ val conf = new Configuration
+ val factory = new CompressionCodecFactory(conf)
+ val codec = factory.getCodec(path) // infers from file ext.
val fileIn = fs.open(path)
- val innerBuffer = ByteStreams.toByteArray(fileIn)
+ val innerBuffer = if (codec != null) {
+ ByteStreams.toByteArray(codec.createInputStream(fileIn))
+ } else {
+ ByteStreams.toByteArray(fileIn)
+ }
+
value = new Text(innerBuffer).toString
Closeables.close(fileIn, false)
processed = true
@@ -68,3 +92,28 @@ private[spark] class WholeTextFileRecordReader(
}
}
}
+
+
+/**
+ * A [[org.apache.hadoop.mapreduce.lib.input.CombineFileRecordReader CombineFileRecordReader]]
+ * that can pass Hadoop Configuration to [[org.apache.hadoop.conf.Configurable Configurable]]
+ * RecordReaders.
+ */
+private[spark] class ConfigurableCombineFileRecordReader[K, V](
+ split: InputSplit,
+ context: TaskAttemptContext,
+ recordReaderClass: Class[_ <: RecordReader[K, V] with HConfigurable])
+ extends CombineFileRecordReader[K, V](
+ split.asInstanceOf[CombineFileSplit],
+ context,
+ recordReaderClass
+ ) with Configurable {
+
+ override def initNextRecordReader(): Boolean = {
+ val r = super.initNextRecordReader()
+ if (r) {
+ this.curReader.asInstanceOf[HConfigurable].setConf(getConf)
+ }
+ r
+ }
+}
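
The compression support added to WholeTextFileRecordReader follows the usual Hadoop pattern of inferring the codec from the file extension; a self-contained sketch of that pattern:

    import java.io.InputStream
    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path
    import org.apache.hadoop.io.compress.CompressionCodecFactory

    // Wraps the raw stream in a decompressing stream when Hadoop knows a codec
    // for the path (e.g. ".gz"); otherwise returns the raw stream unchanged.
    def openMaybeCompressed(raw: InputStream, path: Path, conf: Configuration): InputStream = {
      val codec = new CompressionCodecFactory(conf).getCodec(path)
      if (codec != null) codec.createInputStream(raw) else raw
    }
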
diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
index 1ac7f4e448eb1..f856890d279f4 100644
--- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
+++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
@@ -21,11 +21,12 @@ import java.io.{InputStream, OutputStream}
import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream}
-import org.xerial.snappy.{SnappyInputStream, SnappyOutputStream}
+import org.xerial.snappy.{Snappy, SnappyInputStream, SnappyOutputStream}
import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils
+import org.apache.spark.Logging
/**
* :: DeveloperApi ::
@@ -44,25 +45,33 @@ trait CompressionCodec {
def compressedInputStream(s: InputStream): InputStream
}
-
private[spark] object CompressionCodec {
+ private val configKey = "spark.io.compression.codec"
private val shortCompressionCodecNames = Map(
"lz4" -> classOf[LZ4CompressionCodec].getName,
"lzf" -> classOf[LZFCompressionCodec].getName,
"snappy" -> classOf[SnappyCompressionCodec].getName)
def createCodec(conf: SparkConf): CompressionCodec = {
- createCodec(conf, conf.get("spark.io.compression.codec", DEFAULT_COMPRESSION_CODEC))
+ createCodec(conf, conf.get(configKey, DEFAULT_COMPRESSION_CODEC))
}
def createCodec(conf: SparkConf, codecName: String): CompressionCodec = {
val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName)
- val ctor = Class.forName(codecClass, true, Utils.getContextOrSparkClassLoader)
- .getConstructor(classOf[SparkConf])
- ctor.newInstance(conf).asInstanceOf[CompressionCodec]
+ val codec = try {
+ val ctor = Class.forName(codecClass, true, Utils.getContextOrSparkClassLoader)
+ .getConstructor(classOf[SparkConf])
+ Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec])
+ } catch {
+ case e: ClassNotFoundException => None
+ case e: IllegalArgumentException => None
+ }
+ codec.getOrElse(throw new IllegalArgumentException(s"Codec [$codecName] is not available. " +
+ s"Consider setting $configKey=$FALLBACK_COMPRESSION_CODEC"))
}
+ val FALLBACK_COMPRESSION_CODEC = "lzf"
val DEFAULT_COMPRESSION_CODEC = "snappy"
val ALL_COMPRESSION_CODECS = shortCompressionCodecNames.values.toSeq
}
@@ -120,6 +129,12 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec {
@DeveloperApi
class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec {
+ try {
+ Snappy.getNativeLibraryVersion
+ } catch {
+ case e: Error => throw new IllegalArgumentException
+ }
+
override def compressedOutputStream(s: OutputStream): OutputStream = {
val blockSize = conf.getInt("spark.io.compression.snappy.block.size", 32768)
new SnappyOutputStream(s, blockSize)
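
When the Snappy native library cannot be loaded, the new error message points users at the lzf fallback; a hedged configuration sketch (the application name is illustrative):

    import org.apache.spark.SparkConf

    // Switch the codec explicitly if snappy is unavailable on this platform.
    val conf = new SparkConf()
      .setAppName("codec-fallback-example")
      .set("spark.io.compression.codec", "lzf")  // FALLBACK_COMPRESSION_CODEC
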
diff --git a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
similarity index 79%
rename from core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
rename to core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
index 0c47afae54c8b..21b782edd2a9e 100644
--- a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
+++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
@@ -15,15 +15,24 @@
* limitations under the License.
*/
-package org.apache.hadoop.mapred
+package org.apache.spark.mapred
-private[apache]
+import java.lang.reflect.Modifier
+
+import org.apache.hadoop.mapred.{TaskAttemptID, JobID, JobConf, JobContext, TaskAttemptContext}
+
+private[spark]
trait SparkHadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = {
val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl",
"org.apache.hadoop.mapred.JobContext")
val ctor = klass.getDeclaredConstructor(classOf[JobConf],
classOf[org.apache.hadoop.mapreduce.JobID])
+ // In Hadoop 1.0.x, JobContext is an interface, and JobContextImpl is package private.
+ // Make it accessible if it's not in order to access it.
+ if (!Modifier.isPublic(ctor.getModifiers)) {
+ ctor.setAccessible(true)
+ }
ctor.newInstance(conf, jobId).asInstanceOf[JobContext]
}
@@ -31,6 +40,10 @@ trait SparkHadoopMapRedUtil {
val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl",
"org.apache.hadoop.mapred.TaskAttemptContext")
val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID])
+ // See above
+ if (!Modifier.isPublic(ctor.getModifiers)) {
+ ctor.setAccessible(true)
+ }
ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
}
diff --git a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala
similarity index 96%
rename from core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
rename to core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala
index 1fca5729c6092..3340673f91156 100644
--- a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
+++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala
@@ -15,13 +15,14 @@
* limitations under the License.
*/
-package org.apache.hadoop.mapreduce
+package org.apache.spark.mapreduce
import java.lang.{Boolean => JBoolean, Integer => JInteger}
import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID}
-private[apache]
+private[spark]
trait SparkHadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = {
val klass = firstAvailableClass(
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
index 5dd67b0cbf683..83e8eb71260eb 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -76,22 +76,36 @@ private[spark] class MetricsSystem private (
private val sources = new mutable.ArrayBuffer[Source]
private val registry = new MetricRegistry()
+ private var running: Boolean = false
+
// Treat MetricsServlet as a special sink as it should be exposed to add handlers to web ui
private var metricsServlet: Option[MetricsServlet] = None
- /** Get any UI handlers used by this metrics system. */
- def getServletHandlers = metricsServlet.map(_.getHandlers).getOrElse(Array())
+ /**
+ * Get any UI handlers used by this metrics system; can only be called after start().
+ */
+ def getServletHandlers = {
+ require(running, "Can only call getServletHandlers on a running MetricsSystem")
+ metricsServlet.map(_.getHandlers).getOrElse(Array())
+ }
metricsConfig.initialize()
def start() {
+ require(!running, "Attempting to start a MetricsSystem that is already running")
+ running = true
registerSources()
registerSinks()
sinks.foreach(_.start)
}
def stop() {
- sinks.foreach(_.stop)
+ if (running) {
+ sinks.foreach(_.stop)
+ } else {
+ logWarning("Stopping a MetricsSystem that is not running")
+ }
+ running = false
}
def report() {
@@ -107,7 +121,7 @@ private[spark] class MetricsSystem private (
* @return An unique metric name for each combination of
* application, executor/driver and metric source.
*/
- def buildRegistryName(source: Source): String = {
+ private[spark] def buildRegistryName(source: Source): String = {
val appId = conf.getOption("spark.app.id")
val executorId = conf.getOption("spark.executor.id")
val defaultName = MetricRegistry.name(source.sourceName)
@@ -116,8 +130,8 @@ private[spark] class MetricsSystem private (
if (appId.isDefined && executorId.isDefined) {
MetricRegistry.name(appId.get, executorId.get, source.sourceName)
} else {
- // Only Driver and Executor are set spark.app.id and spark.executor.id.
- // For instance, Master and Worker are not related to a specific application.
+ // Only Driver and Executor set spark.app.id and spark.executor.id.
+ // Other instance types, e.g. Master and Worker, are not related to a specific application.
val warningMsg = s"Using default name $defaultName for source because %s is not set."
if (appId.isEmpty) { logWarning(warningMsg.format("spark.app.id")) }
if (executorId.isEmpty) { logWarning(warningMsg.format("spark.executor.id")) }
@@ -144,7 +158,7 @@ private[spark] class MetricsSystem private (
})
}
- def registerSources() {
+ private def registerSources() {
val instConfig = metricsConfig.getInstance(instance)
val sourceConfigs = metricsConfig.subProperties(instConfig, MetricsSystem.SOURCE_REGEX)
@@ -160,7 +174,7 @@ private[spark] class MetricsSystem private (
}
}
- def registerSinks() {
+ private def registerSinks() {
val instConfig = metricsConfig.getInstance(instance)
val sinkConfigs = metricsConfig.subProperties(instConfig, MetricsSystem.SINK_REGEX)
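
The running flag gives MetricsSystem an explicit lifecycle; stripped of the metrics details, the guard pattern is just the following sketch:

    // Minimal stand-in for the lifecycle checks: handlers are only available while
    // running, double starts fail fast, and stopping a stopped system only warns.
    class GuardedLifecycle {
      private var running = false
      def start(): Unit = { require(!running, "already running"); running = true }
      def servletHandlers: Seq[String] = {
        require(running, "can only be called on a running system")
        Seq.empty
      }
      def stop(): Unit = {
        if (!running) println("warning: stopping a system that is not running")
        running = false
      }
    }
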
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
index d7b5f5c40efae..2d25ebd66159f 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
@@ -22,7 +22,7 @@ import java.util.Properties
import java.util.concurrent.TimeUnit
import com.codahale.metrics.MetricRegistry
-import com.codahale.metrics.graphite.{Graphite, GraphiteReporter}
+import com.codahale.metrics.graphite.{GraphiteUDP, Graphite, GraphiteReporter}
import org.apache.spark.SecurityManager
import org.apache.spark.metrics.MetricsSystem
@@ -38,6 +38,7 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric
val GRAPHITE_KEY_PERIOD = "period"
val GRAPHITE_KEY_UNIT = "unit"
val GRAPHITE_KEY_PREFIX = "prefix"
+ val GRAPHITE_KEY_PROTOCOL = "protocol"
def propertyToOption(prop: String): Option[String] = Option(property.getProperty(prop))
@@ -66,7 +67,11 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric
MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod)
- val graphite: Graphite = new Graphite(new InetSocketAddress(host, port))
+ val graphite = propertyToOption(GRAPHITE_KEY_PROTOCOL).map(_.toLowerCase) match {
+ case Some("udp") => new GraphiteUDP(new InetSocketAddress(host, port))
+ case Some("tcp") | None => new Graphite(new InetSocketAddress(host, port))
+ case Some(p) => throw new Exception(s"Invalid Graphite protocol: $p")
+ }
val reporter: GraphiteReporter = GraphiteReporter.forRegistry(registry)
.convertDurationsTo(TimeUnit.MILLISECONDS)
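
A small sketch of how the new "protocol" key is resolved. In a metrics.properties file this key would typically sit under the sink's prefix (for example *.sink.graphite.protocol=udp), but the property names and values below are only illustrative.

import java.util.Properties

object GraphiteProtocolSketch {
  def main(args: Array[String]): Unit = {
    val props = new Properties()
    props.setProperty("host", "graphite.example.com")
    props.setProperty("port", "2003")
    props.setProperty("protocol", "udp")  // omit this key, or use "tcp", to keep the TCP reporter

    Option(props.getProperty("protocol")).map(_.toLowerCase) match {
      case Some("udp")        => println("would construct GraphiteUDP")
      case Some("tcp") | None => println("would construct Graphite (TCP)")
      case Some(other)        => println(s"invalid Graphite protocol: $other")
    }
  }
}
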
diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
index e0e91724271c8..1745d52c81923 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
@@ -17,20 +17,20 @@
package org.apache.spark.network
-import org.apache.spark.storage.StorageLevel
-
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.storage.{BlockId, StorageLevel}
+private[spark]
trait BlockDataManager {
/**
- * Interface to get local block data.
- *
- * @return Some(buffer) if the block exists locally, and None if it doesn't.
+ * Interface to get local block data. Throws an exception if the block cannot be found or
+ * cannot be read successfully.
*/
- def getBlockData(blockId: String): Option[ManagedBuffer]
+ def getBlockData(blockId: BlockId): ManagedBuffer
/**
* Put the block locally, using the given storage level.
*/
- def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit
+ def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit
}
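
A toy, in-memory implementation (an assumption for illustration, not part of the patch) showing the contract of the revised trait: getBlockData now throws when a block is missing instead of returning None, and putBlockData copies the incoming buffer.

package org.apache.spark.network

import java.nio.ByteBuffer

import scala.collection.concurrent.TrieMap

import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.storage.{BlockId, StorageLevel}

// Hypothetical example class; it lives in org.apache.spark.network only so it
// can extend the private[spark] trait.
class InMemoryBlockDataManager extends BlockDataManager {
  private val blocks = TrieMap.empty[BlockId, ManagedBuffer]

  override def getBlockData(blockId: BlockId): ManagedBuffer =
    blocks.getOrElse(blockId, throw new IllegalArgumentException(s"Block $blockId not found"))

  override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit = {
    // Copy into a heap buffer so the caller may release its own buffer afterwards.
    val copy = ByteBuffer.allocate(data.size.toInt)
    copy.put(data.nioByteBuffer())
    copy.flip()
    blocks(blockId) = new NioManagedBuffer(copy)
  }
}
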
diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala
deleted file mode 100644
index 34acaa563ca58..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.util.EventListener
-
-
-/**
- * Listener callback interface for [[BlockTransferService.fetchBlocks]].
- */
-trait BlockFetchingListener extends EventListener {
-
- /**
- * Called once per successfully fetched block.
- */
- def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit
-
- /**
- * Called upon failures. For each failure, this is called only once (i.e. not once per block).
- */
- def onBlockFetchFailure(exception: Throwable): Unit
-}
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
index 84d991fa6808c..dcbda5a8515dd 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
@@ -17,13 +17,19 @@
package org.apache.spark.network
-import scala.concurrent.{Await, Future}
-import scala.concurrent.duration.Duration
+import java.io.Closeable
+import java.nio.ByteBuffer
-import org.apache.spark.storage.StorageLevel
+import scala.concurrent.{Promise, Await, Future}
+import scala.concurrent.duration.Duration
+import org.apache.spark.Logging
+import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.shuffle.{ShuffleClient, BlockFetchingListener}
+import org.apache.spark.storage.{BlockManagerId, BlockId, StorageLevel}
-abstract class BlockTransferService {
+private[spark]
+abstract class BlockTransferService extends ShuffleClient with Closeable with Logging {
/**
* Initialize the transfer service by giving it the BlockDataManager that can be used to fetch
@@ -34,7 +40,7 @@ abstract class BlockTransferService {
/**
* Tear down the transfer service.
*/
- def stop(): Unit
+ def close(): Unit
/**
* Port number the service is listening on, available only after [[init]] is invoked.
@@ -50,17 +56,15 @@ abstract class BlockTransferService {
* Fetch a sequence of blocks from a remote node asynchronously,
* available only after [[init]] is invoked.
*
- * Note that [[BlockFetchingListener.onBlockFetchSuccess]] is called once per block,
- * while [[BlockFetchingListener.onBlockFetchFailure]] is called once per failure (not per block).
- *
* Note that this API takes a sequence so the implementation can batch requests, and does not
* return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as
* the data of a block is fetched, rather than waiting for all blocks to be fetched.
*/
- def fetchBlocks(
- hostName: String,
+ override def fetchBlocks(
+ host: String,
port: Int,
- blockIds: Seq[String],
+ execId: String,
+ blockIds: Array[String],
listener: BlockFetchingListener): Unit
/**
@@ -69,7 +73,8 @@ abstract class BlockTransferService {
def uploadBlock(
hostname: String,
port: Int,
- blockId: String,
+ execId: String,
+ blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel): Future[Unit]
@@ -78,40 +83,23 @@ abstract class BlockTransferService {
*
* It is also only available after [[init]] is invoked.
*/
- def fetchBlockSync(hostName: String, port: Int, blockId: String): ManagedBuffer = {
+ def fetchBlockSync(host: String, port: Int, execId: String, blockId: String): ManagedBuffer = {
// A monitor for the thread to wait on.
- val lock = new Object
- @volatile var result: Either[ManagedBuffer, Throwable] = null
- fetchBlocks(hostName, port, Seq(blockId), new BlockFetchingListener {
- override def onBlockFetchFailure(exception: Throwable): Unit = {
- lock.synchronized {
- result = Right(exception)
- lock.notify()
+ val result = Promise[ManagedBuffer]()
+ fetchBlocks(host, port, execId, Array(blockId),
+ new BlockFetchingListener {
+ override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
+ result.failure(exception)
}
- }
- override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
- lock.synchronized {
- result = Left(data)
- lock.notify()
- }
- }
- })
-
- // Sleep until result is no longer null
- lock.synchronized {
- while (result == null) {
- try {
- lock.wait()
- } catch {
- case e: InterruptedException =>
+ override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
+ val ret = ByteBuffer.allocate(data.size.toInt)
+ ret.put(data.nioByteBuffer())
+ ret.flip()
+ result.success(new NioManagedBuffer(ret))
}
- }
- }
+ })
- result match {
- case Left(data) => data
- case Right(e) => throw e
- }
+ Await.result(result.future, Duration.Inf)
}
/**
@@ -123,9 +111,10 @@ abstract class BlockTransferService {
def uploadBlockSync(
hostname: String,
port: Int,
- blockId: String,
+ execId: String,
+ blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel): Unit = {
- Await.result(uploadBlock(hostname, port, blockId, blockData, level), Duration.Inf)
+ Await.result(uploadBlock(hostname, port, execId, blockId, blockData, level), Duration.Inf)
}
}
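
The rewritten fetchBlockSync above replaces manual wait/notify with a Promise that the asynchronous listener completes. Below is a self-contained sketch of that pattern; the fetchAsync callback is a stand-in, not a Spark API.

import scala.concurrent.{Await, Promise}
import scala.concurrent.duration.Duration

object PromiseBlockingSketch {
  // Stand-in for an asynchronous operation that reports back via callbacks.
  def fetchAsync(onSuccess: String => Unit, onFailure: Throwable => Unit): Unit =
    onSuccess("block data")

  def fetchSync(): String = {
    val result = Promise[String]()
    fetchAsync(data => result.success(data), e => result.failure(e))
    // Blocks the calling thread until the callback completes the promise.
    Await.result(result.future, Duration.Inf)
  }

  def main(args: Array[String]): Unit = println(fetchSync())
}
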
diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
deleted file mode 100644
index 4c9ca97a2a6b7..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
+++ /dev/null
@@ -1,160 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.io._
-import java.nio.ByteBuffer
-import java.nio.channels.FileChannel
-import java.nio.channels.FileChannel.MapMode
-
-import scala.util.Try
-
-import com.google.common.io.ByteStreams
-import io.netty.buffer.{ByteBufInputStream, ByteBuf}
-
-import org.apache.spark.util.{ByteBufferInputStream, Utils}
-
-
-/**
- * This interface provides an immutable view for data in the form of bytes. The implementation
- * should specify how the data is provided:
- *
- * - FileSegmentManagedBuffer: data backed by part of a file
- * - NioByteBufferManagedBuffer: data backed by a NIO ByteBuffer
- * - NettyByteBufManagedBuffer: data backed by a Netty ByteBuf
- */
-sealed abstract class ManagedBuffer {
- // Note that all the methods are defined with parenthesis because their implementations can
- // have side effects (io operations).
-
- /** Number of bytes of the data. */
- def size: Long
-
- /**
- * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the
- * returned ByteBuffer should not affect the content of this buffer.
- */
- def nioByteBuffer(): ByteBuffer
-
- /**
- * Exposes this buffer's data as an InputStream. The underlying implementation does not
- * necessarily check for the length of bytes read, so the caller is responsible for making sure
- * it does not go over the limit.
- */
- def inputStream(): InputStream
-}
-
-
-/**
- * A [[ManagedBuffer]] backed by a segment in a file
- */
-final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long)
- extends ManagedBuffer {
-
- /**
- * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889).
- * Avoid unless there's a good reason not to.
- */
- private val MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024;
-
- override def size: Long = length
-
- override def nioByteBuffer(): ByteBuffer = {
- var channel: FileChannel = null
- try {
- channel = new RandomAccessFile(file, "r").getChannel
- // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead.
- if (length < MIN_MEMORY_MAP_BYTES) {
- val buf = ByteBuffer.allocate(length.toInt)
- channel.read(buf, offset)
- buf.flip()
- buf
- } else {
- channel.map(MapMode.READ_ONLY, offset, length)
- }
- } catch {
- case e: IOException =>
- Try(channel.size).toOption match {
- case Some(fileLen) =>
- throw new IOException(s"Error in reading $this (actual file length $fileLen)", e)
- case None =>
- throw new IOException(s"Error in opening $this", e)
- }
- } finally {
- if (channel != null) {
- Utils.tryLog(channel.close())
- }
- }
- }
-
- override def inputStream(): InputStream = {
- var is: FileInputStream = null
- try {
- is = new FileInputStream(file)
- is.skip(offset)
- ByteStreams.limit(is, length)
- } catch {
- case e: IOException =>
- if (is != null) {
- Utils.tryLog(is.close())
- }
- Try(file.length).toOption match {
- case Some(fileLen) =>
- throw new IOException(s"Error in reading $this (actual file length $fileLen)", e)
- case None =>
- throw new IOException(s"Error in opening $this", e)
- }
- case e: Throwable =>
- if (is != null) {
- Utils.tryLog(is.close())
- }
- throw e
- }
- }
-
- override def toString: String = s"${getClass.getName}($file, $offset, $length)"
-}
-
-
-/**
- * A [[ManagedBuffer]] backed by [[java.nio.ByteBuffer]].
- */
-final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer {
-
- override def size: Long = buf.remaining()
-
- override def nioByteBuffer() = buf.duplicate()
-
- override def inputStream() = new ByteBufferInputStream(buf)
-}
-
-
-/**
- * A [[ManagedBuffer]] backed by a Netty [[ByteBuf]].
- */
-final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer {
-
- override def size: Long = buf.readableBytes()
-
- override def nioByteBuffer() = buf.nioBuffer()
-
- override def inputStream() = new ByteBufInputStream(buf)
-
- // TODO(rxin): Promote this to top level ManagedBuffer interface and add documentation for it.
- def release(): Unit = buf.release()
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
new file mode 100644
index 0000000000000..b089da8596e2b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import java.nio.ByteBuffer
+
+import scala.collection.JavaConversions._
+
+import org.apache.spark.Logging
+import org.apache.spark.network.BlockDataManager
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
+import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}
+import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock}
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.storage.{BlockId, StorageLevel}
+
+/**
+ * Serves requests to open blocks by simply registering one chunk per block requested.
+ * Handles opening and uploading arbitrary BlockManager blocks.
+ *
+ * Opened blocks are registered with the "one-for-one" strategy, meaning each Transport-layer Chunk
+ * is equivalent to one Spark-level shuffle block.
+ */
+class NettyBlockRpcServer(
+ serializer: Serializer,
+ blockManager: BlockDataManager)
+ extends RpcHandler with Logging {
+
+ private val streamManager = new OneForOneStreamManager()
+
+ override def receive(
+ client: TransportClient,
+ messageBytes: Array[Byte],
+ responseContext: RpcResponseCallback): Unit = {
+ val message = BlockTransferMessage.Decoder.fromByteArray(messageBytes)
+ logTrace(s"Received request: $message")
+
+ message match {
+ case openBlocks: OpenBlocks =>
+ val blocks: Seq[ManagedBuffer] =
+ openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
+ val streamId = streamManager.registerStream(blocks.iterator)
+ logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
+ responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray)
+
+ case uploadBlock: UploadBlock =>
+ // StorageLevel is serialized as bytes using our JavaSerializer.
+ val level: StorageLevel =
+ serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata))
+ val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
+ blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level)
+ responseContext.onSuccess(new Array[Byte](0))
+ }
+ }
+
+ override def getStreamManager(): StreamManager = streamManager
+}
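
A small round-trip sketch of the wire messages this handler consumes, assuming the network-shuffle protocol classes expose an (appId, execId, blockIds) constructor and the public fields used above; all values are illustrative.

import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks}

object OpenBlocksRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val request = new OpenBlocks("app-1", "exec-2", Array("shuffle_0_0_0", "shuffle_0_1_0"))
    val bytes = request.toByteArray // what a client would send as the RPC payload
    BlockTransferMessage.Decoder.fromByteArray(bytes) match {
      case open: OpenBlocks => println(s"server would register ${open.blockIds.length} chunks")
      case other            => println(s"unexpected message: $other")
    }
  }
}
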
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
new file mode 100644
index 0000000000000..3f0950dae1f24
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import scala.collection.JavaConversions._
+import scala.concurrent.{Future, Promise}
+
+import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.network._
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory}
+import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap}
+import org.apache.spark.network.server._
+import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher}
+import org.apache.spark.network.shuffle.protocol.UploadBlock
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.storage.{BlockId, StorageLevel}
+import org.apache.spark.util.Utils
+
+/**
+ * A BlockTransferService that uses Netty to fetch a set of blocks at a time.
+ */
+class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager, numCores: Int)
+ extends BlockTransferService {
+
+ // TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
+ private val serializer = new JavaSerializer(conf)
+ private val authEnabled = securityManager.isAuthenticationEnabled()
+ private val transportConf = SparkTransportConf.fromSparkConf(conf, numCores)
+
+ private[this] var transportContext: TransportContext = _
+ private[this] var server: TransportServer = _
+ private[this] var clientFactory: TransportClientFactory = _
+ private[this] var appId: String = _
+
+ override def init(blockDataManager: BlockDataManager): Unit = {
+ val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = {
+ val nettyRpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
+ if (!authEnabled) {
+ (nettyRpcHandler, None)
+ } else {
+ (new SaslRpcHandler(nettyRpcHandler, securityManager),
+ Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager)))
+ }
+ }
+ transportContext = new TransportContext(transportConf, rpcHandler)
+ clientFactory = transportContext.createClientFactory(bootstrap.toList)
+ server = transportContext.createServer(conf.getInt("spark.blockManager.port", 0))
+ appId = conf.getAppId
+ logInfo("Server created on " + server.getPort)
+ }
+
+ override def fetchBlocks(
+ host: String,
+ port: Int,
+ execId: String,
+ blockIds: Array[String],
+ listener: BlockFetchingListener): Unit = {
+ logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
+ try {
+ val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
+ override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
+ val client = clientFactory.createClient(host, port)
+ new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start()
+ }
+ }
+
+ val maxRetries = transportConf.maxIORetries()
+ if (maxRetries > 0) {
+ // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
+ // a bug in this code. We should remove the if statement once we're sure of the stability.
+ new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
+ } else {
+ blockFetchStarter.createAndStart(blockIds, listener)
+ }
+ } catch {
+ case e: Exception =>
+ logError("Exception while beginning fetchBlocks", e)
+ blockIds.foreach(listener.onBlockFetchFailure(_, e))
+ }
+ }
+
+ override def hostName: String = Utils.localHostName()
+
+ override def port: Int = server.getPort
+
+ override def uploadBlock(
+ hostname: String,
+ port: Int,
+ execId: String,
+ blockId: BlockId,
+ blockData: ManagedBuffer,
+ level: StorageLevel): Future[Unit] = {
+ val result = Promise[Unit]()
+ val client = clientFactory.createClient(hostname, port)
+
+ // StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded
+ // using our binary protocol.
+ val levelBytes = serializer.newInstance().serialize(level).array()
+
+ // Convert or copy nio buffer into array in order to serialize it.
+ val nioBuffer = blockData.nioByteBuffer()
+ val array = if (nioBuffer.hasArray) {
+ nioBuffer.array()
+ } else {
+ val data = new Array[Byte](nioBuffer.remaining())
+ nioBuffer.get(data)
+ data
+ }
+
+ client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteArray,
+ new RpcResponseCallback {
+ override def onSuccess(response: Array[Byte]): Unit = {
+ logTrace(s"Successfully uploaded block $blockId")
+ result.success()
+ }
+ override def onFailure(e: Throwable): Unit = {
+ logError(s"Error while uploading block $blockId", e)
+ result.failure(e)
+ }
+ })
+
+ result.future
+ }
+
+ override def close(): Unit = {
+ server.close()
+ clientFactory.close()
+ }
+}
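
A hedged end-to-end usage sketch, not taken from the patch: the SparkConf settings are illustrative, InMemoryBlockDataManager is the toy manager sketched after BlockDataManager above, and the fetch simply fails if the requested block was never stored.

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.network.InMemoryBlockDataManager // toy manager from the earlier sketch
import org.apache.spark.network.netty.NettyBlockTransferService

object NettyTransferUsageSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().set("spark.app.id", "app-1")
    val service = new NettyBlockTransferService(conf, new SecurityManager(conf), numCores = 2)
    service.init(new InMemoryBlockDataManager()) // any BlockDataManager implementation works
    try {
      // Synchronous fetch from this very server; throws if the block is absent.
      val buffer = service.fetchBlockSync("localhost", service.port, "exec-1", "rdd_0_0")
      println(s"fetched ${buffer.size} bytes")
    } finally {
      service.close()
    }
  }
}
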
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala
deleted file mode 100644
index b5870152c5a64..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import org.apache.spark.SparkConf
-
-/**
- * A central location that tracks all the settings we exposed to users.
- */
-private[spark]
-class NettyConfig(conf: SparkConf) {
-
- /** Port the server listens on. Default to a random port. */
- private[netty] val serverPort = conf.getInt("spark.shuffle.io.port", 0)
-
- /** IO mode: nio, oio, epoll, or auto (try epoll first and then nio). */
- private[netty] val ioMode = conf.get("spark.shuffle.io.mode", "nio").toLowerCase
-
- /** Connect timeout in secs. Default 60 secs. */
- private[netty] val connectTimeoutMs = conf.getInt("spark.shuffle.io.connectionTimeout", 60) * 1000
-
- /**
- * Percentage of the desired amount of time spent for I/O in the child event loops.
- * Only applicable in nio and epoll.
- */
- private[netty] val ioRatio = conf.getInt("spark.shuffle.io.netty.ioRatio", 80)
-
- /** Requested maximum length of the queue of incoming connections. */
- private[netty] val backLog: Option[Int] = conf.getOption("spark.shuffle.io.backLog").map(_.toInt)
-
- /**
- * Receive buffer size (SO_RCVBUF).
- * Note: the optimal size for receive buffer and send buffer should be
- * latency * network_bandwidth.
- * Assuming latency = 1ms, network_bandwidth = 10Gbps
- * buffer size should be ~ 1.25MB
- */
- private[netty] val receiveBuf: Option[Int] =
- conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt)
-
- /** Send buffer size (SO_SNDBUF). */
- private[netty] val sendBuf: Option[Int] =
- conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt)
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala b/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala
deleted file mode 100644
index 0d7695072a7b1..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala
+++ /dev/null
@@ -1,25 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import org.apache.spark.storage.{BlockId, FileSegment}
-
-trait PathResolver {
- /** Get the file segment in which the given block resides. */
- def getBlockLocation(blockId: BlockId): FileSegment
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala
new file mode 100644
index 0000000000000..cef203006d685
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import org.apache.spark.SparkConf
+import org.apache.spark.network.util.{TransportConf, ConfigProvider}
+
+/**
+ * Provides a utility for transforming from a SparkConf inside a Spark JVM (e.g., Executor,
+ * Driver, or a standalone shuffle service) into a TransportConf with details on our environment
+ * like the number of cores that are allocated to this JVM.
+ */
+object SparkTransportConf {
+ /**
+ * Specifies an upper bound on the number of Netty threads that Spark requires by default.
+ * In practice, only 2-4 cores should be required to transfer roughly 10 Gb/s, and each core
+ * that we use will have an initial overhead of roughly 32 MB of off-heap memory, which comes
+ * at a premium.
+ *
+ * Thus, this value should still retain maximum throughput and reduce wasted off-heap memory
+ * allocation. It can be overridden by setting the number of serverThreads and clientThreads
+ * manually in Spark's configuration.
+ */
+ private val MAX_DEFAULT_NETTY_THREADS = 8
+
+ /**
+ * Utility for creating a [[TransportConf]] from a [[SparkConf]].
+ * @param numUsableCores if nonzero, this will restrict the server and client threads to only
+ * use the given number of cores, rather than all of the machine's cores.
+ * This restriction will only occur if these properties are not already set.
+ */
+ def fromSparkConf(_conf: SparkConf, numUsableCores: Int = 0): TransportConf = {
+ val conf = _conf.clone
+
+ // Specify thread configuration based on our JVM's allocation of cores (rather than necessarily
+ // assuming we have all the machine's cores).
+ // NB: Only set if serverThreads/clientThreads not already set.
+ val numThreads = defaultNumThreads(numUsableCores)
+ conf.set("spark.shuffle.io.serverThreads",
+ conf.get("spark.shuffle.io.serverThreads", numThreads.toString))
+ conf.set("spark.shuffle.io.clientThreads",
+ conf.get("spark.shuffle.io.clientThreads", numThreads.toString))
+
+ new TransportConf(new ConfigProvider {
+ override def get(name: String): String = conf.get(name)
+ })
+ }
+
+ /**
+ * Returns the default number of threads for both the Netty client and server thread pools.
+ * If numUsableCores is 0, we will use the Runtime to get an approximate number of available cores.
+ */
+ private def defaultNumThreads(numUsableCores: Int): Int = {
+ val availableCores =
+ if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
+ math.min(availableCores, MAX_DEFAULT_NETTY_THREADS)
+ }
+}
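
A short usage sketch; the core count is illustrative, and reading the derived thread counts back assumes TransportConf's serverThreads()/clientThreads() accessors.

import org.apache.spark.SparkConf
import org.apache.spark.network.netty.SparkTransportConf

object TransportConfSketch {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf() // no explicit spark.shuffle.io.*Threads set
    // Pretend this JVM was allocated 4 cores: thread pools default to
    // min(4, MAX_DEFAULT_NETTY_THREADS) unless the user overrides them.
    val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 4)
    println(transportConf.serverThreads())
    println(transportConf.clientThreads())
  }
}
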
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala
deleted file mode 100644
index e28219dd7745b..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import java.util.EventListener
-
-
-trait BlockClientListener extends EventListener {
-
- def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit
-
- def onFetchFailure(blockId: String, errorMsg: String): Unit
-
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala
deleted file mode 100644
index 5aea7ba2f3673..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala
+++ /dev/null
@@ -1,132 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import java.util.concurrent.TimeoutException
-
-import io.netty.bootstrap.Bootstrap
-import io.netty.buffer.PooledByteBufAllocator
-import io.netty.channel.socket.SocketChannel
-import io.netty.channel.{ChannelFutureListener, ChannelFuture, ChannelInitializer, ChannelOption}
-import io.netty.handler.codec.LengthFieldBasedFrameDecoder
-import io.netty.handler.codec.string.StringEncoder
-import io.netty.util.CharsetUtil
-
-import org.apache.spark.Logging
-
-/**
- * Client for fetching data blocks from [[org.apache.spark.network.netty.server.BlockServer]].
- * Use [[BlockFetchingClientFactory]] to instantiate this client.
- *
- * The constructor blocks until a connection is successfully established.
- *
- * See [[org.apache.spark.network.netty.server.BlockServer]] for client/server protocol.
- *
- * Concurrency: thread safe and can be called from multiple threads.
- */
-@throws[TimeoutException]
-private[spark]
-class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, port: Int)
- extends Logging {
-
- private val handler = new BlockFetchingClientHandler
-
- /** Netty Bootstrap for creating the TCP connection. */
- private val bootstrap: Bootstrap = {
- val b = new Bootstrap
- b.group(factory.workerGroup)
- .channel(factory.socketChannelClass)
- // Use pooled buffers to reduce temporary buffer allocation
- .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
- // Disable Nagle's Algorithm since we don't want packets to wait
- .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE)
- .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE)
- .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, factory.conf.connectTimeoutMs)
-
- b.handler(new ChannelInitializer[SocketChannel] {
- override def initChannel(ch: SocketChannel): Unit = {
- ch.pipeline
- .addLast("encoder", new StringEncoder(CharsetUtil.UTF_8))
- // maxFrameLength = 2G, lengthFieldOffset = 0, lengthFieldLength = 4
- .addLast("framedLengthDecoder", new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 4))
- .addLast("handler", handler)
- }
- })
- b
- }
-
- /** Netty ChannelFuture for the connection. */
- private val cf: ChannelFuture = bootstrap.connect(hostname, port)
- if (!cf.awaitUninterruptibly(factory.conf.connectTimeoutMs)) {
- throw new TimeoutException(
- s"Connecting to $hostname:$port timed out (${factory.conf.connectTimeoutMs} ms)")
- }
-
- /**
- * Ask the remote server for a sequence of blocks, and execute the callback.
- *
- * Note that this is asynchronous and returns immediately. Upstream caller should throttle the
- * rate of fetching; otherwise we could run out of memory.
- *
- * @param blockIds sequence of block ids to fetch.
- * @param listener callback to fire on fetch success / failure.
- */
- def fetchBlocks(blockIds: Seq[String], listener: BlockClientListener): Unit = {
- // It's best to limit the number of "write" calls since it needs to traverse the whole pipeline.
- // It's also best to limit the number of "flush" calls since it requires system calls.
- // Let's concatenate the string and then call writeAndFlush once.
- // This is also why this implementation might be more efficient than multiple, separate
- // fetch block calls.
- var startTime: Long = 0
- logTrace {
- startTime = System.nanoTime
- s"Sending request $blockIds to $hostname:$port"
- }
-
- blockIds.foreach { blockId =>
- handler.addRequest(blockId, listener)
- }
-
- val writeFuture = cf.channel().writeAndFlush(blockIds.mkString("\n") + "\n")
- writeFuture.addListener(new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture): Unit = {
- if (future.isSuccess) {
- logTrace {
- val timeTaken = (System.nanoTime - startTime).toDouble / 1000000
- s"Sending request $blockIds to $hostname:$port took $timeTaken ms"
- }
- } else {
- // Fail all blocks.
- val errorMsg =
- s"Failed to send request $blockIds to $hostname:$port: ${future.cause.getMessage}"
- logError(errorMsg, future.cause)
- blockIds.foreach { blockId =>
- listener.onFetchFailure(blockId, errorMsg)
- handler.removeRequest(blockId)
- }
- }
- }
- })
- }
-
- def waitForClose(): Unit = {
- cf.channel().closeFuture().sync()
- }
-
- def close(): Unit = cf.channel().close()
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala
deleted file mode 100644
index 2b28402c52b49..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala
+++ /dev/null
@@ -1,99 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import io.netty.channel.epoll.{EpollEventLoopGroup, EpollSocketChannel}
-import io.netty.channel.nio.NioEventLoopGroup
-import io.netty.channel.oio.OioEventLoopGroup
-import io.netty.channel.socket.nio.NioSocketChannel
-import io.netty.channel.socket.oio.OioSocketChannel
-import io.netty.channel.{EventLoopGroup, Channel}
-
-import org.apache.spark.SparkConf
-import org.apache.spark.network.netty.NettyConfig
-import org.apache.spark.util.Utils
-
-/**
- * Factory for creating [[BlockFetchingClient]] by using createClient. This factory reuses
- * the worker thread pool for Netty.
- *
- * Concurrency: createClient is safe to be called from multiple threads concurrently.
- */
-private[spark]
-class BlockFetchingClientFactory(val conf: NettyConfig) {
-
- def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf))
-
- /** A thread factory so the threads are named (for debugging). */
- val threadFactory = Utils.namedThreadFactory("spark-shuffle-client")
-
- /** The following two are instantiated by the [[init]] method, depending ioMode. */
- var socketChannelClass: Class[_ <: Channel] = _
- var workerGroup: EventLoopGroup = _
-
- init()
-
- /** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */
- private def init(): Unit = {
- def initOio(): Unit = {
- socketChannelClass = classOf[OioSocketChannel]
- workerGroup = new OioEventLoopGroup(0, threadFactory)
- }
- def initNio(): Unit = {
- socketChannelClass = classOf[NioSocketChannel]
- workerGroup = new NioEventLoopGroup(0, threadFactory)
- }
- def initEpoll(): Unit = {
- socketChannelClass = classOf[EpollSocketChannel]
- workerGroup = new EpollEventLoopGroup(0, threadFactory)
- }
-
- conf.ioMode match {
- case "nio" => initNio()
- case "oio" => initOio()
- case "epoll" => initEpoll()
- case "auto" =>
- // For auto mode, first try epoll (only available on Linux), then nio.
- try {
- initEpoll()
- } catch {
- // TODO: Should we log the throwable? But that always happen on non-Linux systems.
- // Perhaps the right thing to do is to check whether the system is Linux, and then only
- // call initEpoll on Linux.
- case e: Throwable => initNio()
- }
- }
- }
-
- /**
- * Create a new BlockFetchingClient connecting to the given remote host / port.
- *
- * This blocks until a connection is successfully established.
- *
- * Concurrency: This method is safe to call from multiple threads.
- */
- def createClient(remoteHost: String, remotePort: Int): BlockFetchingClient = {
- new BlockFetchingClient(this, remoteHost, remotePort)
- }
-
- def stop(): Unit = {
- if (workerGroup != null) {
- workerGroup.shutdownGracefully()
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala
deleted file mode 100644
index 83265b164299d..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala
+++ /dev/null
@@ -1,103 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import io.netty.buffer.ByteBuf
-import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
-
-import org.apache.spark.Logging
-
-
-/**
- * Handler that processes server responses. It uses the protocol documented in
- * [[org.apache.spark.network.netty.server.BlockServer]].
- *
- * Concurrency: thread safe and can be called from multiple threads.
- */
-private[client]
-class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] with Logging {
-
- /** Tracks the list of outstanding requests and their listeners on success/failure. */
- private val outstandingRequests = java.util.Collections.synchronizedMap {
- new java.util.HashMap[String, BlockClientListener]
- }
-
- def addRequest(blockId: String, listener: BlockClientListener): Unit = {
- outstandingRequests.put(blockId, listener)
- }
-
- def removeRequest(blockId: String): Unit = {
- outstandingRequests.remove(blockId)
- }
-
- override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
- val errorMsg = s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}"
- logError(errorMsg, cause)
-
- // Fire the failure callback for all outstanding blocks
- outstandingRequests.synchronized {
- val iter = outstandingRequests.entrySet().iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- entry.getValue.onFetchFailure(entry.getKey, errorMsg)
- }
- outstandingRequests.clear()
- }
-
- ctx.close()
- }
-
- override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) {
- val totalLen = in.readInt()
- val blockIdLen = in.readInt()
- val blockIdBytes = new Array[Byte](math.abs(blockIdLen))
- in.readBytes(blockIdBytes)
- val blockId = new String(blockIdBytes)
- val blockSize = totalLen - math.abs(blockIdLen) - 4
-
- def server = ctx.channel.remoteAddress.toString
-
- // blockIdLen is negative when it is an error message.
- if (blockIdLen < 0) {
- val errorMessageBytes = new Array[Byte](blockSize)
- in.readBytes(errorMessageBytes)
- val errorMsg = new String(errorMessageBytes)
- logTrace(s"Received block $blockId ($blockSize B) with error $errorMsg from $server")
-
- val listener = outstandingRequests.get(blockId)
- if (listener == null) {
- // Ignore callback
- logWarning(s"Got a response for block $blockId but it is not in our outstanding requests")
- } else {
- outstandingRequests.remove(blockId)
- listener.onFetchFailure(blockId, errorMsg)
- }
- } else {
- logTrace(s"Received block $blockId ($blockSize B) from $server")
-
- val listener = outstandingRequests.get(blockId)
- if (listener == null) {
- // Ignore callback
- logWarning(s"Got a response for block $blockId but it is not in our outstanding requests")
- } else {
- outstandingRequests.remove(blockId)
- listener.onFetchSuccess(blockId, new ReferenceCountedBuffer(in))
- }
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala b/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala
deleted file mode 100644
index 9740ee64d1f2d..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-/**
- * A simple iterator that lazily initializes the underlying iterator.
- *
- * The use case is that sometimes we might have many iterators open at the same time, and each of
- * the iterator might initialize its own buffer (e.g. decompression buffer, deserialization buffer).
- * This could lead to too many buffers open. If this iterator is used, we lazily initialize those
- * buffers.
- */
-private[spark]
-class LazyInitIterator(createIterator: => Iterator[Any]) extends Iterator[Any] {
-
- lazy val proxy = createIterator
-
- override def hasNext: Boolean = {
- val gotNext = proxy.hasNext
- if (!gotNext) {
- close()
- }
- gotNext
- }
-
- override def next(): Any = proxy.next()
-
- def close(): Unit = Unit
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala b/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala
deleted file mode 100644
index ea1abf5eccc26..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import java.io.InputStream
-import java.nio.ByteBuffer
-
-import io.netty.buffer.{ByteBuf, ByteBufInputStream}
-
-
-/**
- * A buffer abstraction based on Netty's ByteBuf so we don't expose Netty.
- * This is a Scala value class.
- *
- * The buffer's life cycle is NOT managed by the JVM, and thus requiring explicit declaration of
- * reference by the retain method and release method.
- */
-private[spark]
-class ReferenceCountedBuffer(val underlying: ByteBuf) extends AnyVal {
-
- /** Return the nio ByteBuffer view of the underlying buffer. */
- def byteBuffer(): ByteBuffer = underlying.nioBuffer
-
- /** Creates a new input stream that starts from the current position of the buffer. */
- def inputStream(): InputStream = new ByteBufInputStream(underlying)
-
- /** Increment the reference counter by one. */
- def retain(): Unit = underlying.retain()
-
- /** Decrement the reference counter by one and release the buffer if the ref count is 0. */
- def release(): Unit = underlying.release()
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala
deleted file mode 100644
index 162e9cc6828d4..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-/**
- * Header describing a block. This is used only in the server pipeline.
- *
- * [[BlockServerHandler]] creates this, and [[BlockHeaderEncoder]] encodes it.
- *
- * @param blockSize length of the block content, excluding the length itself.
- * If positive, this is the header for a block (not part of the header).
- * If negative, this is the header and content for an error message.
- * @param blockId block id
- * @param error some error message from reading the block
- */
-private[server]
-class BlockHeader(val blockSize: Int, val blockId: String, val error: Option[String] = None)
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala
deleted file mode 100644
index 8e4dda4ef8595..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import io.netty.buffer.ByteBuf
-import io.netty.channel.ChannelHandlerContext
-import io.netty.handler.codec.MessageToByteEncoder
-
-/**
- * A simple encoder for BlockHeader. See [[BlockServer]] for the server to client protocol.
- */
-private[server]
-class BlockHeaderEncoder extends MessageToByteEncoder[BlockHeader] {
- override def encode(ctx: ChannelHandlerContext, msg: BlockHeader, out: ByteBuf): Unit = {
- // message = message length (4 bytes) + block id length (4 bytes) + block id + block data
- // message length = block id length (4 bytes) + size of block id + size of block data
- val blockIdBytes = msg.blockId.getBytes
- msg.error match {
- case Some(errorMsg) =>
- val errorBytes = errorMsg.getBytes
- out.writeInt(4 + blockIdBytes.length + errorBytes.size)
- out.writeInt(-blockIdBytes.length) // use negative block id length to represent errors
- out.writeBytes(blockIdBytes) // next is blockId itself
- out.writeBytes(errorBytes) // error message
- case None =>
- out.writeInt(4 + blockIdBytes.length + msg.blockSize)
- out.writeInt(blockIdBytes.length) // First 4 bytes is blockId length
- out.writeBytes(blockIdBytes) // next is blockId itself
- // msg of size blockSize will be written by ServerHandler
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala
deleted file mode 100644
index 7b2f9a8d4dfd0..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala
+++ /dev/null
@@ -1,162 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import java.net.InetSocketAddress
-
-import io.netty.bootstrap.ServerBootstrap
-import io.netty.buffer.PooledByteBufAllocator
-import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption}
-import io.netty.channel.epoll.{EpollEventLoopGroup, EpollServerSocketChannel}
-import io.netty.channel.nio.NioEventLoopGroup
-import io.netty.channel.oio.OioEventLoopGroup
-import io.netty.channel.socket.SocketChannel
-import io.netty.channel.socket.nio.NioServerSocketChannel
-import io.netty.channel.socket.oio.OioServerSocketChannel
-import io.netty.handler.codec.LineBasedFrameDecoder
-import io.netty.handler.codec.string.StringDecoder
-import io.netty.util.CharsetUtil
-
-import org.apache.spark.{Logging, SparkConf}
-import org.apache.spark.network.netty.NettyConfig
-import org.apache.spark.storage.BlockDataProvider
-import org.apache.spark.util.Utils
-
-
-/**
- * Server for serving Spark data blocks.
- * This should be used together with [[org.apache.spark.network.netty.client.BlockFetchingClient]].
- *
- * Protocol for requesting blocks (client to server):
- * One block id per line, e.g. to request 3 blocks: "block1\nblock2\nblock3\n"
- *
- * Protocol for sending blocks (server to client):
- * frame-length (4 bytes), block-id-length (4 bytes), block-id, block-data.
- *
- * frame-length should not include the length of itself.
- * If block-id-length is negative, then this is an error message rather than block-data. The real
- * length is the absolute value of the frame-length.
- *
- */
-private[spark]
-class BlockServer(conf: NettyConfig, dataProvider: BlockDataProvider) extends Logging {
-
- def this(sparkConf: SparkConf, dataProvider: BlockDataProvider) = {
- this(new NettyConfig(sparkConf), dataProvider)
- }
-
- def port: Int = _port
-
- def hostName: String = _hostName
-
- private var _port: Int = conf.serverPort
- private var _hostName: String = ""
- private var bootstrap: ServerBootstrap = _
- private var channelFuture: ChannelFuture = _
-
- init()
-
- /** Initialize the server. */
- private def init(): Unit = {
- bootstrap = new ServerBootstrap
- val bossThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-boss")
- val workerThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-worker")
-
- // Use only one thread to accept connections, and 2 * num_cores for worker.
- def initNio(): Unit = {
- val bossGroup = new NioEventLoopGroup(1, bossThreadFactory)
- val workerGroup = new NioEventLoopGroup(0, workerThreadFactory)
- workerGroup.setIoRatio(conf.ioRatio)
- bootstrap.group(bossGroup, workerGroup).channel(classOf[NioServerSocketChannel])
- }
- def initOio(): Unit = {
- val bossGroup = new OioEventLoopGroup(1, bossThreadFactory)
- val workerGroup = new OioEventLoopGroup(0, workerThreadFactory)
- bootstrap.group(bossGroup, workerGroup).channel(classOf[OioServerSocketChannel])
- }
- def initEpoll(): Unit = {
- val bossGroup = new EpollEventLoopGroup(1, bossThreadFactory)
- val workerGroup = new EpollEventLoopGroup(0, workerThreadFactory)
- workerGroup.setIoRatio(conf.ioRatio)
- bootstrap.group(bossGroup, workerGroup).channel(classOf[EpollServerSocketChannel])
- }
-
- conf.ioMode match {
- case "nio" => initNio()
- case "oio" => initOio()
- case "epoll" => initEpoll()
- case "auto" =>
- // For auto mode, first try epoll (only available on Linux), then nio.
- try {
- initEpoll()
- } catch {
- // TODO: Should we log the throwable? But that always happens on non-Linux systems.
- // Perhaps the right thing to do is to check whether the system is Linux, and then only
- // call initEpoll on Linux.
- case e: Throwable => initNio()
- }
- }
-
- // Use pooled buffers to reduce temporary buffer allocation
- bootstrap.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
- bootstrap.childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
-
- // Various (advanced) user-configured settings.
- conf.backLog.foreach { backLog =>
- bootstrap.option[java.lang.Integer](ChannelOption.SO_BACKLOG, backLog)
- }
- conf.receiveBuf.foreach { receiveBuf =>
- bootstrap.option[java.lang.Integer](ChannelOption.SO_RCVBUF, receiveBuf)
- }
- conf.sendBuf.foreach { sendBuf =>
- bootstrap.option[java.lang.Integer](ChannelOption.SO_SNDBUF, sendBuf)
- }
-
- bootstrap.childHandler(new ChannelInitializer[SocketChannel] {
- override def initChannel(ch: SocketChannel): Unit = {
- ch.pipeline
- .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024
- .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8))
- .addLast("blockHeaderEncoder", new BlockHeaderEncoder)
- .addLast("handler", new BlockServerHandler(dataProvider))
- }
- })
-
- channelFuture = bootstrap.bind(new InetSocketAddress(_port))
- channelFuture.sync()
-
- val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress]
- _port = addr.getPort
- _hostName = addr.getHostName
- }
-
- /** Shutdown the server. */
- def stop(): Unit = {
- if (channelFuture != null) {
- channelFuture.channel().close().awaitUninterruptibly()
- channelFuture = null
- }
- if (bootstrap != null && bootstrap.group() != null) {
- bootstrap.group().shutdownGracefully()
- }
- if (bootstrap != null && bootstrap.childGroup() != null) {
- bootstrap.childGroup().shutdownGracefully()
- }
- bootstrap = null
- }
-}
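The deleted doc comment above spells out the block-transfer wire format (frame-length, block-id-length, block-id, block-data, with a negative block-id-length marking an error frame). A minimal sketch of that encoding, mirroring the removed BlockHeaderEncoder; the `BlockHeader` shape is reconstructed from the deleted code and is an assumption, not the shipped class:

```scala
import io.netty.buffer.ByteBuf
import io.netty.channel.ChannelHandlerContext
import io.netty.handler.codec.MessageToByteEncoder

// Assumed shape of the removed header message: block size, block id, optional error.
case class BlockHeader(blockSize: Int, blockId: String, error: Option[String] = None)

// Frame layout: frame-length (excluding itself), block-id-length, block-id, payload.
// A negative block-id-length marks an error frame whose payload is the error message.
class BlockHeaderEncoderSketch extends MessageToByteEncoder[BlockHeader] {
  override def encode(ctx: ChannelHandlerContext, msg: BlockHeader, out: ByteBuf): Unit = {
    val blockIdBytes = msg.blockId.getBytes("UTF-8")
    msg.error match {
      case Some(errorMsg) =>
        val errorBytes = errorMsg.getBytes("UTF-8")
        out.writeInt(4 + blockIdBytes.length + errorBytes.length) // frame-length
        out.writeInt(-blockIdBytes.length)                        // negative => error frame
        out.writeBytes(blockIdBytes)
        out.writeBytes(errorBytes)
      case None =>
        out.writeInt(4 + blockIdBytes.length + msg.blockSize)     // frame-length
        out.writeInt(blockIdBytes.length)
        out.writeBytes(blockIdBytes)
        // the blockSize bytes of data are written separately by the server handler
    }
  }
}
```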
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala
deleted file mode 100644
index cc70bd0c5c477..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import io.netty.channel.ChannelInitializer
-import io.netty.channel.socket.SocketChannel
-import io.netty.handler.codec.LineBasedFrameDecoder
-import io.netty.handler.codec.string.StringDecoder
-import io.netty.util.CharsetUtil
-import org.apache.spark.storage.BlockDataProvider
-
-
-/** Channel initializer that sets up the pipeline for the BlockServer. */
-private[netty]
-class BlockServerChannelInitializer(dataProvider: BlockDataProvider)
- extends ChannelInitializer[SocketChannel] {
-
- override def initChannel(ch: SocketChannel): Unit = {
- ch.pipeline
- .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024
- .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8))
- .addLast("blockHeaderEncoder", new BlockHeaderEncoder)
- .addLast("handler", new BlockServerHandler(dataProvider))
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala
deleted file mode 100644
index 40dd5e5d1a2ac..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala
+++ /dev/null
@@ -1,140 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import java.io.FileInputStream
-import java.nio.ByteBuffer
-import java.nio.channels.FileChannel
-
-import io.netty.buffer.Unpooled
-import io.netty.channel._
-
-import org.apache.spark.Logging
-import org.apache.spark.storage.{FileSegment, BlockDataProvider}
-
-
-/**
- * A handler that processes requests from clients and writes block data back.
- *
- * The messages should have been processed by a LineBasedFrameDecoder and a StringDecoder first
- * so channelRead0 is called once per line (i.e. per block id).
- */
-private[server]
-class BlockServerHandler(dataProvider: BlockDataProvider)
- extends SimpleChannelInboundHandler[String] with Logging {
-
- override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
- logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause)
- ctx.close()
- }
-
- override def channelRead0(ctx: ChannelHandlerContext, blockId: String): Unit = {
- def client = ctx.channel.remoteAddress.toString
-
- // A helper function to send error message back to the client.
- def respondWithError(error: String): Unit = {
- ctx.writeAndFlush(new BlockHeader(-1, blockId, Some(error))).addListener(
- new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture) {
- if (!future.isSuccess) {
- // TODO: Maybe log the success case as well.
- logError(s"Error sending error back to $client", future.cause)
- ctx.close()
- }
- }
- }
- )
- }
-
- def writeFileSegment(segment: FileSegment): Unit = {
- // Send error message back if the block is too large. Even though we are capable of sending
- // large (2G+) blocks, the receiving end cannot handle it so let's fail fast.
- // Once we fix the receiving end to be able to process large blocks, this should be removed.
- // Also make sure we update BlockHeaderEncoder to support length > 2G.
-
- // See [[BlockHeaderEncoder]] for the way length is encoded.
- if (segment.length + blockId.length + 4 > Int.MaxValue) {
- respondWithError(s"Block $blockId size ($segment.length) greater than 2G")
- return
- }
-
- var fileChannel: FileChannel = null
- try {
- fileChannel = new FileInputStream(segment.file).getChannel
- } catch {
- case e: Exception =>
- logError(
- s"Error opening channel for $blockId in ${segment.file} for request from $client", e)
- respondWithError(e.getMessage)
- }
-
- // Found the block. Send it back.
- if (fileChannel != null) {
- // Write the header and block data. In the case of failures, the listener on the block data
- // write should close the connection.
- ctx.write(new BlockHeader(segment.length.toInt, blockId))
-
- val region = new DefaultFileRegion(fileChannel, segment.offset, segment.length)
- ctx.writeAndFlush(region).addListener(new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture) {
- if (future.isSuccess) {
- logTrace(s"Sent block $blockId (${segment.length} B) back to $client")
- } else {
- logError(s"Error sending block $blockId to $client; closing connection", future.cause)
- ctx.close()
- }
- }
- })
- }
- }
-
- def writeByteBuffer(buf: ByteBuffer): Unit = {
- ctx.write(new BlockHeader(buf.remaining, blockId))
- ctx.writeAndFlush(Unpooled.wrappedBuffer(buf)).addListener(new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture) {
- if (future.isSuccess) {
- logTrace(s"Sent block $blockId (${buf.remaining} B) back to $client")
- } else {
- logError(s"Error sending block $blockId to $client; closing connection", future.cause)
- ctx.close()
- }
- }
- })
- }
-
- logTrace(s"Received request from $client to fetch block $blockId")
-
- var blockData: Either[FileSegment, ByteBuffer] = null
-
- // First make sure we can find the block. If not, send error back to the user.
- try {
- blockData = dataProvider.getBlockData(blockId)
- } catch {
- case e: Exception =>
- logError(s"Error opening block $blockId for request from $client", e)
- respondWithError(e.getMessage)
- return
- }
-
- blockData match {
- case Left(segment) => writeFileSegment(segment)
- case Right(buf) => writeByteBuffer(buf)
- }
-
- } // end of channelRead0
-}
diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
index 4f6f5e235811d..c2d9578be7ebb 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
@@ -23,12 +23,13 @@ import java.nio.channels._
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.LinkedList
-import org.apache.spark._
-
import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.control.NonFatal
+import org.apache.spark._
+import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer}
+
private[nio]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId,
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
index bda4bf50932c3..ee22c6656e69e 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -18,20 +18,24 @@
package org.apache.spark.network.nio
import java.io.IOException
+import java.lang.ref.WeakReference
import java.net._
import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit}
-import java.util.{Timer, TimerTask}
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue}
import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.language.postfixOps
+import com.google.common.base.Charsets.UTF_8
+import io.netty.util.{Timeout, TimerTask, HashedWheelTimer}
+
import org.apache.spark._
+import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer}
import org.apache.spark.util.Utils
import scala.util.Try
@@ -74,13 +78,27 @@ private[nio] class ConnectionManager(
}
private val selector = SelectorProvider.provider.openSelector()
- private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true)
-
- private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60)
+ private val ackTimeoutMonitor =
+ new HashedWheelTimer(Utils.namedThreadFactory("AckTimeoutMonitor"))
+
+ private val ackTimeout =
+ conf.getInt("spark.core.connection.ack.wait.timeout", conf.getInt("spark.network.timeout", 120))
+
+ // Get the thread counts from the Spark Configuration.
+ //
+ // Even though the ThreadPoolExecutor constructor takes both a minimum and maximum value,
+ // we only query for the minimum value because we are using LinkedBlockingDeque.
+ //
+ // The JavaDoc for ThreadPoolExecutor points out that when using a LinkedBlockingDeque (which is
+ // an unbounded queue) no more than corePoolSize threads will ever be created, so only the "min"
+ // parameter is necessary.
+ private val handlerThreadCount = conf.getInt("spark.core.connection.handler.threads.min", 20)
+ private val ioThreadCount = conf.getInt("spark.core.connection.io.threads.min", 4)
+ private val connectThreadCount = conf.getInt("spark.core.connection.connect.threads.min", 1)
private val handleMessageExecutor = new ThreadPoolExecutor(
- conf.getInt("spark.core.connection.handler.threads.min", 20),
- conf.getInt("spark.core.connection.handler.threads.max", 60),
+ handlerThreadCount,
+ handlerThreadCount,
conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable](),
Utils.namedThreadFactory("handle-message-executor")) {
@@ -91,12 +109,11 @@ private[nio] class ConnectionManager(
logError("Error in handleMessageExecutor is not handled properly", t)
}
}
-
}
private val handleReadWriteExecutor = new ThreadPoolExecutor(
- conf.getInt("spark.core.connection.io.threads.min", 4),
- conf.getInt("spark.core.connection.io.threads.max", 32),
+ ioThreadCount,
+ ioThreadCount,
conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable](),
Utils.namedThreadFactory("handle-read-write-executor")) {
@@ -107,14 +124,13 @@ private[nio] class ConnectionManager(
logError("Error in handleReadWriteExecutor is not handled properly", t)
}
}
-
}
// Use a different, yet smaller, thread pool - infrequently used with very short lived tasks :
// which should be executed asap
private val handleConnectExecutor = new ThreadPoolExecutor(
- conf.getInt("spark.core.connection.connect.threads.min", 1),
- conf.getInt("spark.core.connection.connect.threads.max", 8),
+ connectThreadCount,
+ connectThreadCount,
conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable](),
Utils.namedThreadFactory("handle-connect-executor")) {
@@ -125,7 +141,6 @@ private[nio] class ConnectionManager(
logError("Error in handleConnectExecutor is not handled properly", t)
}
}
-
}
private val serverChannel = ServerSocketChannel.open()
@@ -136,7 +151,10 @@ private[nio] class ConnectionManager(
new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection]
with SynchronizedMap[ConnectionManagerId, SendingConnection]
- private val messageStatuses = new HashMap[Int, MessageStatus]
+ // Tracks sent messages for which we are awaiting acknowledgements. Entries are added to this
+ // map when messages are sent and are removed when acknowledgement messages are received or when
+ // acknowledgement timeouts expire
+ private val messageStatuses = new HashMap[Int, MessageStatus] // [MessageId, MessageStatus]
private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
private val registerRequests = new SynchronizedQueue[SendingConnection]
@@ -156,7 +174,7 @@ private[nio] class ConnectionManager(
serverChannel.socket.bind(new InetSocketAddress(port))
(serverChannel, serverChannel.socket.getLocalPort)
}
- Utils.startServiceOnPort[ServerSocketChannel](port, startService, name)
+ Utils.startServiceOnPort[ServerSocketChannel](port, startService, conf, name)
serverChannel.register(selector, SelectionKey.OP_ACCEPT)
val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
@@ -166,14 +184,16 @@ private[nio] class ConnectionManager(
// to be able to track asynchronous messages
private val idCount: AtomicInteger = new AtomicInteger(1)
+ private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
+ private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
+
private val selectorThread = new Thread("connection-manager-thread") {
override def run() = ConnectionManager.this.run()
}
selectorThread.setDaemon(true)
+ // start this thread last, since it invokes run(), which accesses members above
selectorThread.start()
- private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
-
private def triggerWrite(key: SelectionKey) {
val conn = connectionsByKey.getOrElse(key, null)
if (conn == null) return
@@ -214,7 +234,6 @@ private[nio] class ConnectionManager(
} )
}
- private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
private def triggerRead(key: SelectionKey) {
val conn = connectionsByKey.getOrElse(key, null)
@@ -598,7 +617,7 @@ private[nio] class ConnectionManager(
} else {
var replyToken : Array[Byte] = null
try {
- replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken)
+ replyToken = waitingConn.sparkSaslClient.response(securityMsg.getToken)
if (waitingConn.isSaslComplete()) {
logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId)
connectionsAwaitingSasl -= waitingConn.connectionId
@@ -632,7 +651,7 @@ private[nio] class ConnectionManager(
connection.synchronized {
if (connection.sparkSaslServer == null) {
logDebug("Creating sasl Server")
- connection.sparkSaslServer = new SparkSaslServer(securityManager)
+ connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager)
}
}
replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
@@ -776,7 +795,7 @@ private[nio] class ConnectionManager(
if (!conn.isSaslComplete()) {
conn.synchronized {
if (conn.sparkSaslClient == null) {
- conn.sparkSaslClient = new SparkSaslClient(securityManager)
+ conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager)
var firstResponse: Array[Byte] = null
try {
firstResponse = conn.sparkSaslClient.firstToken()
@@ -896,22 +915,41 @@ private[nio] class ConnectionManager(
: Future[Message] = {
val promise = Promise[Message]()
- val timeoutTask = new TimerTask {
- override def run(): Unit = {
+ // It's important that the TimerTask doesn't capture a reference to `message`, which can cause
+ // memory leaks since cancelled TimerTasks won't necessarily be garbage collected until the time
+ // at which they would originally be scheduled to run. Therefore, extract the message id
+ // from outside of the TimerTask closure (see SPARK-4393 for more context).
+ val messageId = message.id
+ // Keep a weak reference to the promise so that the completed promise may be garbage-collected
+ val promiseReference = new WeakReference(promise)
+ val timeoutTask: TimerTask = new TimerTask {
+ override def run(timeout: Timeout): Unit = {
messageStatuses.synchronized {
- messageStatuses.remove(message.id).foreach ( s => {
+ messageStatuses.remove(messageId).foreach { s =>
val e = new IOException("sendMessageReliably failed because ack " +
s"was not received within $ackTimeout sec")
- if (!promise.tryFailure(e)) {
- logWarning("Ignore error because promise is completed", e)
+ val p = promiseReference.get
+ if (p != null) {
+ // Attempt to fail the promise with a Timeout exception
+ if (!p.tryFailure(e)) {
+ // If we reach here, then someone else has already signalled success or failure
+ // on this promise, so just log the error:
+ logError("Ignore error because promise is completed", e)
+ }
+ } else {
+ // The WeakReference was empty, which should never happen because
+ // sendMessageReliably's caller should have a strong reference to promise.future;
+ logError("Promise was garbage collected; this should never happen!", e)
}
- })
+ }
}
}
}
+ val timeoutTaskHandle = ackTimeoutMonitor.newTimeout(timeoutTask, ackTimeout, TimeUnit.SECONDS)
+
val status = new MessageStatus(message, connectionManagerId, s => {
- timeoutTask.cancel()
+ timeoutTaskHandle.cancel()
s match {
case scala.util.Failure(e) =>
// Indicates a failure where we either never sent or never got ACK'd
@@ -923,7 +961,7 @@ private[nio] class ConnectionManager(
val errorMsgByteBuf = ackMessage.asInstanceOf[BufferMessage].buffers.head
val errorMsgBytes = new Array[Byte](errorMsgByteBuf.limit())
errorMsgByteBuf.get(errorMsgBytes)
- val errorMsg = new String(errorMsgBytes, "utf-8")
+ val errorMsg = new String(errorMsgBytes, UTF_8)
val e = new IOException(
s"sendMessageReliably failed with ACK that signalled a remote error: $errorMsg")
if (!promise.tryFailure(e)) {
@@ -940,7 +978,6 @@ private[nio] class ConnectionManager(
messageStatuses += ((message.id, status))
}
- ackTimeoutMonitor.schedule(timeoutTask, ackTimeout * 1000)
sendMessage(connectionManagerId, message)
promise.future
}
@@ -950,7 +987,7 @@ private[nio] class ConnectionManager(
}
def stop() {
- ackTimeoutMonitor.cancel()
+ ackTimeoutMonitor.stop()
selectorThread.interrupt()
selectorThread.join()
selector.close()
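The rewritten ack-timeout handling above swaps java.util.Timer for Netty's HashedWheelTimer and is careful that the TimerTask closure captures only the message id plus a WeakReference to the promise (SPARK-4393). A condensed sketch of that pattern; the names are illustrative rather than the actual ConnectionManager members:

```scala
import java.io.IOException
import java.lang.ref.WeakReference
import java.util.concurrent.TimeUnit

import scala.concurrent.Promise

import io.netty.util.{HashedWheelTimer, Timeout, TimerTask}

object AckTimeoutSketch {
  val timer = new HashedWheelTimer()

  // Schedule a timeout that can fail `promise` without pinning the (possibly large)
  // message in memory: only the id and a weak reference are captured by the closure.
  def scheduleAckTimeout(messageId: Int, promise: Promise[Unit], timeoutSecs: Long): Timeout = {
    val promiseRef = new WeakReference(promise)
    val task = new TimerTask {
      override def run(timeout: Timeout): Unit = {
        val p = promiseRef.get
        if (p != null) {
          p.tryFailure(new IOException(
            s"no ack received for message $messageId within $timeoutSecs s"))
        }
      }
    }
    timer.newTimeout(task, timeoutSecs, TimeUnit.SECONDS)
  }
}
```

The Timeout handle returned by newTimeout plays the role of timeoutTaskHandle in the diff: the MessageStatus callback cancels it once the acknowledgement arrives.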
diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala
index 3ad04591da658..fb4a979b824c3 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/Message.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/Message.scala
@@ -22,6 +22,8 @@ import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
+import com.google.common.base.Charsets.UTF_8
+
import org.apache.spark.util.Utils
private[nio] abstract class Message(val typ: Long, val id: Int) {
@@ -92,7 +94,7 @@ private[nio] object Message {
*/
def createErrorMessage(exception: Exception, ackId: Int): BufferMessage = {
val exceptionString = Utils.exceptionString(exception)
- val serializedExceptionString = ByteBuffer.wrap(exceptionString.getBytes("utf-8"))
+ val serializedExceptionString = ByteBuffer.wrap(exceptionString.getBytes(UTF_8))
val errorMessage = createBufferMessage(serializedExceptionString, ackId)
errorMessage.hasError = true
errorMessage
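Both this change and the matching read path in ConnectionManager replace the string charset name "utf-8" with Guava's Charsets.UTF_8 constant, so encoding and decoding use the Charset overloads rather than looking the charset up by name on each call. A tiny round-trip illustration:

```scala
import java.nio.ByteBuffer

import com.google.common.base.Charsets.UTF_8

// The Charset-based overloads avoid the by-name lookup (and the checked
// UnsupportedEncodingException of the String-named variants).
val payload: ByteBuffer = ByteBuffer.wrap("remote failure".getBytes(UTF_8))
val decoded: String = new String(payload.array(), UTF_8)
```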
diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
index 5add4fc433fb3..b2aec160635c7 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
@@ -19,12 +19,14 @@ package org.apache.spark.network.nio
import java.nio.ByteBuffer
-import scala.concurrent.Future
-
-import org.apache.spark.{SparkException, Logging, SecurityManager, SparkConf}
import org.apache.spark.network._
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.shuffle.BlockFetchingListener
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
+
+import scala.concurrent.Future
/**
@@ -71,20 +73,21 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
/**
* Tear down the transfer service.
*/
- override def stop(): Unit = {
+ override def close(): Unit = {
if (cm != null) {
cm.stop()
}
}
override def fetchBlocks(
- hostName: String,
+ host: String,
port: Int,
- blockIds: Seq[String],
+ execId: String,
+ blockIds: Array[String],
listener: BlockFetchingListener): Unit = {
checkInit()
- val cmId = new ConnectionManagerId(hostName, port)
+ val cmId = new ConnectionManagerId(host, port)
val blockMessageArray = new BlockMessageArray(blockIds.map { blockId =>
BlockMessage.fromGetBlock(GetBlock(BlockId(blockId)))
})
@@ -96,21 +99,33 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
val bufferMessage = message.asInstanceOf[BufferMessage]
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
- for (blockMessage <- blockMessageArray) {
- if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
- listener.onBlockFetchFailure(
- new SparkException(s"Unexpected message ${blockMessage.getType} received from $cmId"))
- } else {
- val blockId = blockMessage.getId
- val networkSize = blockMessage.getData.limit()
- listener.onBlockFetchSuccess(
- blockId.toString, new NioByteBufferManagedBuffer(blockMessage.getData))
+ // SPARK-4064: In some cases (e.g. the remote block was removed), blockMessageArray may be empty.
+ if (blockMessageArray.isEmpty) {
+ blockIds.foreach { id =>
+ listener.onBlockFetchFailure(id, new SparkException(s"Received empty message from $cmId"))
+ }
+ } else {
+ for (blockMessage: BlockMessage <- blockMessageArray) {
+ val msgType = blockMessage.getType
+ if (msgType != BlockMessage.TYPE_GOT_BLOCK) {
+ if (blockMessage.getId != null) {
+ listener.onBlockFetchFailure(blockMessage.getId.toString,
+ new SparkException(s"Unexpected message $msgType received from $cmId"))
+ }
+ } else {
+ val blockId = blockMessage.getId
+ val networkSize = blockMessage.getData.limit()
+ listener.onBlockFetchSuccess(
+ blockId.toString, new NioManagedBuffer(blockMessage.getData))
+ }
}
}
}(cm.futureExecContext)
future.onFailure { case exception =>
- listener.onBlockFetchFailure(exception)
+ blockIds.foreach { blockId =>
+ listener.onBlockFetchFailure(blockId, exception)
+ }
}(cm.futureExecContext)
}
@@ -122,12 +137,13 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
override def uploadBlock(
hostname: String,
port: Int,
- blockId: String,
+ execId: String,
+ blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel)
: Future[Unit] = {
checkInit()
- val msg = PutBlock(BlockId(blockId), blockData.nioByteBuffer(), level)
+ val msg = PutBlock(blockId, blockData.nioByteBuffer(), level)
val blockMessageArray = new BlockMessageArray(BlockMessage.fromPutBlock(msg))
val remoteCmId = new ConnectionManagerId(hostName, port)
val reply = cm.sendMessageReliably(remoteCmId, blockMessageArray.toBufferMessage)
@@ -149,10 +165,9 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get)
Some(new BlockMessageArray(responseMessages).toBufferMessage)
} catch {
- case e: Exception => {
+ case e: Exception =>
logError("Exception handling buffer message", e)
Some(Message.createErrorMessage(e, msg.id))
- }
}
case otherMessage: Any =>
@@ -167,13 +182,13 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
case BlockMessage.TYPE_PUT_BLOCK =>
val msg = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel)
logDebug("Received [" + msg + "]")
- putBlock(msg.id.toString, msg.data, msg.level)
+ putBlock(msg.id, msg.data, msg.level)
None
case BlockMessage.TYPE_GET_BLOCK =>
val msg = new GetBlock(blockMessage.getId)
logDebug("Received [" + msg + "]")
- val buffer = getBlock(msg.id.toString)
+ val buffer = getBlock(msg.id)
if (buffer == null) {
return None
}
@@ -183,20 +198,20 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
}
}
- private def putBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ private def putBlock(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) {
val startTimeMs = System.currentTimeMillis()
logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes)
- blockDataManager.putBlockData(blockId, new NioByteBufferManagedBuffer(bytes), level)
+ blockDataManager.putBlockData(blockId, new NioManagedBuffer(bytes), level)
logDebug("PutBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs)
+ " with data size: " + bytes.limit)
}
- private def getBlock(blockId: String): ByteBuffer = {
+ private def getBlock(blockId: BlockId): ByteBuffer = {
val startTimeMs = System.currentTimeMillis()
logDebug("GetBlock " + blockId + " started from " + startTimeMs)
- val buffer = blockDataManager.getBlockData(blockId).orNull
+ val buffer = blockDataManager.getBlockData(blockId)
logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs)
+ " and got buffer " + buffer)
- if (buffer == null) null else buffer.nioByteBuffer()
+ buffer.nioByteBuffer()
}
}
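With the interface change above, fetch results are reported per block id through a BlockFetchingListener, which is how the SPARK-4064 branch can fail every requested block when the reply is empty. A minimal listener sketch using the callback shape assumed by this code:

```scala
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.shuffle.BlockFetchingListener

val listener = new BlockFetchingListener {
  override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
    // Called once per block that arrived intact.
    println(s"fetched $blockId (${data.size()} bytes)")
  }
  override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
    // Called once per block that could not be fetched, including the empty-reply case.
    println(s"failed to fetch $blockId: ${exception.getMessage}")
  }
}
```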
diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala
index e2fc9c649925e..b6249b492150a 100644
--- a/core/src/main/scala/org/apache/spark/package.scala
+++ b/core/src/main/scala/org/apache/spark/package.scala
@@ -27,8 +27,7 @@ package org.apache
* contains operations available only on RDDs of Doubles; and
* [[org.apache.spark.rdd.SequenceFileRDDFunctions]] contains operations available on RDDs that can
* be saved as SequenceFiles. These operations are automatically available on any RDD of the right
- * type (e.g. RDD[(Int, Int)] through implicit conversions when you
- * `import org.apache.spark.SparkContext._`.
+ * type (e.g. RDD[(Int, Int)]) through implicit conversions.
*
* Java programmers should reference the [[org.apache.spark.api.java]] package
* for Spark programming APIs in Java.
@@ -44,5 +43,5 @@ package org.apache
package object spark {
// For package docs only
- val SPARK_VERSION = "1.2.0-SNAPSHOT"
+ val SPARK_VERSION = "1.3.0-SNAPSHOT"
}
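The package doc no longer tells users to import org.apache.spark.SparkContext._ because the RDD implicit conversions resolve without it in this release line. A short sketch under that assumption (app name and master are illustrative):

```scala
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setAppName("implicits-demo").setMaster("local[2]"))
// Pair-RDD functions such as reduceByKey resolve without importing SparkContext._
val counts = sc.parallelize(Seq("a" -> 1, "b" -> 2, "a" -> 3)).reduceByKey(_ + _)
counts.collect().foreach(println)
sc.stop()
```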
diff --git a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala
index 3155dfe165664..637492a97551b 100644
--- a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala
+++ b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala
@@ -17,7 +17,7 @@
package org.apache.spark.partial
-import cern.jet.stat.Probability
+import org.apache.commons.math3.distribution.NormalDistribution
/**
* An ApproximateEvaluator for counts.
@@ -46,7 +46,8 @@ private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double)
val mean = (sum + 1 - p) / p
val variance = (sum + 1) * (1 - p) / (p * p)
val stdev = math.sqrt(variance)
- val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2)
+ val confFactor = new NormalDistribution().
+ inverseCumulativeProbability(1 - (1 - confidence) / 2)
val low = mean - confFactor * stdev
val high = mean + confFactor * stdev
new BoundedDouble(mean, confidence, low, high)
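The colt-to-Commons-Math3 migration here (and in the evaluators below) keeps the same quantile: the two-sided confidence factor is the standard normal quantile at 1 - (1 - confidence) / 2. For example, at 95% confidence:

```scala
import org.apache.commons.math3.distribution.NormalDistribution

val confidence = 0.95
// Standard normal (mean 0, sd 1) quantile used as the two-sided confidence factor.
val confFactor = new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2)
// confFactor ≈ 1.96
```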
diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala
index 8bb78123e3c9c..3ef3cc219dec6 100644
--- a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala
+++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala
@@ -24,7 +24,7 @@ import scala.collection.Map
import scala.collection.mutable.HashMap
import scala.reflect.ClassTag
-import cern.jet.stat.Probability
+import org.apache.commons.math3.distribution.NormalDistribution
import org.apache.spark.util.collection.OpenHashMap
@@ -55,7 +55,8 @@ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, conf
new HashMap[T, BoundedDouble]
} else {
val p = outputsMerged.toDouble / totalOutputs
- val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2)
+ val confFactor = new NormalDistribution().
+ inverseCumulativeProbability(1 - (1 - confidence) / 2)
val result = new JHashMap[T, BoundedDouble](sums.size)
sums.foreach { case (key, sum) =>
val mean = (sum + 1 - p) / p
diff --git a/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala
index d24959cba8727..787a21a61fdcf 100644
--- a/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala
+++ b/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala
@@ -17,7 +17,7 @@
package org.apache.spark.partial
-import cern.jet.stat.Probability
+import org.apache.commons.math3.distribution.{NormalDistribution, TDistribution}
import org.apache.spark.util.StatCounter
@@ -45,9 +45,10 @@ private[spark] class MeanEvaluator(totalOutputs: Int, confidence: Double)
val stdev = math.sqrt(counter.sampleVariance / counter.count)
val confFactor = {
if (counter.count > 100) {
- Probability.normalInverse(1 - (1 - confidence) / 2)
+ new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2)
} else {
- Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt)
+ val degreesOfFreedom = (counter.count - 1).toInt
+ new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2)
}
}
val low = mean - confFactor * stdev
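For the small-sample branch, colt's Probability.studentTInverse(1 - confidence, df) is a two-tailed inverse, so the Commons Math3 replacement passes the one-tailed probability 1 - (1 - confidence) / 2 to inverseCumulativeProbability to obtain the same factor. A worked check with illustrative numbers:

```scala
import org.apache.commons.math3.distribution.TDistribution

val confidence = 0.95
val degreesOfFreedom = 9
// Equivalent to colt's Probability.studentTInverse(1 - confidence, degreesOfFreedom):
// the two-sided factor is the (1 - (1 - confidence) / 2) quantile of Student's t.
val confFactor =
  new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2)
// For 9 degrees of freedom this is about 2.26 (vs. 1.96 for the normal approximation).
```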
diff --git a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala
index 92915ee66d29f..828bf96c2c0bd 100644
--- a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala
+++ b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala
@@ -17,7 +17,7 @@
package org.apache.spark.partial
-import cern.jet.stat.Probability
+import org.apache.commons.math3.distribution.{TDistribution, NormalDistribution}
/**
* A utility class for caching Student's T distribution values for a given confidence level
@@ -25,8 +25,10 @@ import cern.jet.stat.Probability
* confidence intervals for many keys.
*/
private[spark] class StudentTCacher(confidence: Double) {
+
val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation
- val normalApprox = Probability.normalInverse(1 - (1 - confidence) / 2)
+
+ val normalApprox = new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2)
val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0)
def get(sampleSize: Long): Double = {
@@ -35,7 +37,8 @@ private[spark] class StudentTCacher(confidence: Double) {
} else {
val size = sampleSize.toInt
if (cache(size) < 0) {
- cache(size) = Probability.studentTInverse(1 - confidence, size - 1)
+ val tDist = new TDistribution(size - 1)
+ cache(size) = tDist.inverseCumulativeProbability(1 - (1 - confidence) / 2)
}
cache(size)
}
diff --git a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
index d5336284571d2..1753c2561b678 100644
--- a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
+++ b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
@@ -17,7 +17,7 @@
package org.apache.spark.partial
-import cern.jet.stat.Probability
+import org.apache.commons.math3.distribution.{TDistribution, NormalDistribution}
import org.apache.spark.util.StatCounter
@@ -55,9 +55,10 @@ private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double)
val sumStdev = math.sqrt(sumVar)
val confFactor = {
if (counter.count > 100) {
- Probability.normalInverse(1 - (1 - confidence) / 2)
+ new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2)
} else {
- Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt)
+ val degreesOfFreedom = (counter.count - 1).toInt
+ new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2)
}
}
val low = sumEstimate - confFactor * sumStdev
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index 9f9f10b7ebc3a..646df283ac069 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -27,7 +27,6 @@ import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
/**
* A set of asynchronous RDD actions available through an implicit conversion.
- * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
*/
class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging {
diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala
new file mode 100644
index 0000000000000..1f755db485812
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.hadoop.conf.{Configurable, Configuration}
+import org.apache.hadoop.io.Writable
+import org.apache.hadoop.mapreduce._
+import org.apache.spark.input.StreamFileInputFormat
+import org.apache.spark.{Partition, SparkContext}
+
+private[spark] class BinaryFileRDD[T](
+ sc: SparkContext,
+ inputFormatClass: Class[_ <: StreamFileInputFormat[T]],
+ keyClass: Class[String],
+ valueClass: Class[T],
+ @transient conf: Configuration,
+ minPartitions: Int)
+ extends NewHadoopRDD[String, T](sc, inputFormatClass, keyClass, valueClass, conf) {
+
+ override def getPartitions: Array[Partition] = {
+ val inputFormat = inputFormatClass.newInstance
+ inputFormat match {
+ case configurable: Configurable =>
+ configurable.setConf(conf)
+ case _ =>
+ }
+ val jobContext = newJobContext(conf, jobId)
+ inputFormat.setMinPartitions(jobContext, minPartitions)
+ val rawSplits = inputFormat.getSplits(jobContext).toArray
+ val result = new Array[Partition](rawSplits.size)
+ for (i <- 0 until rawSplits.size) {
+ result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
+ }
+ result
+ }
+}
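BinaryFileRDD backs the binary-file input path; the usual way to exercise it is through SparkContext.binaryFiles, which yields one (path, PortableDataStream) record per file. A small usage sketch (the path and partition hint are illustrative):

```scala
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setAppName("binary-files-demo").setMaster("local[2]"))
// Each record is (file path, PortableDataStream); minPartitions is the hint that
// BinaryFileRDD forwards to StreamFileInputFormat.setMinPartitions in getPartitions.
val files = sc.binaryFiles("hdfs:///data/blobs", minPartitions = 8)
val sizes = files.mapValues(stream => stream.toArray().length)
sizes.collect().foreach { case (path, bytes) => println(s"$path: $bytes bytes") }
sc.stop()
```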
diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
index 2673ec22509e9..fffa1911f5bc2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -84,5 +84,9 @@ class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds
"Attempted to use %s after its blocks have been removed!".format(toString))
}
}
+
+ protected def getBlockIdLocations(): Map[BlockId, Seq[String]] = {
+ locations_
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
index 4908711d17db7..1cbd684224b7c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
@@ -22,6 +22,7 @@ import java.io.{IOException, ObjectOutputStream}
import scala.reflect.ClassTag
import org.apache.spark._
+import org.apache.spark.util.Utils
private[spark]
class CartesianPartition(
@@ -36,7 +37,7 @@ class CartesianPartition(
override val index: Int = idx
@throws(classOf[IOException])
- private def writeObject(oos: ObjectOutputStream) {
+ private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
// Update the reference to parent split at the time of task serialization
s1 = rdd1.partitions(s1Index)
s2 = rdd2.partitions(s2Index)
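writeObject is now wrapped in Utils.tryOrIOException here and in the other partition classes below, so that failures during task serialization surface as IOException, which Java serialization reports cleanly, instead of escaping as arbitrary throwables. A sketch of what such a helper can look like; the real Utils implementation may differ:

```scala
import java.io.IOException

import scala.util.control.NonFatal

// Rethrow IOExceptions unchanged and wrap any other non-fatal failure in an IOException,
// so ObjectOutputStream.writeObject callers see a serialization error they can handle.
def tryOrIOException(block: => Unit): Unit = {
  try {
    block
  } catch {
    case e: IOException => throw e
    case NonFatal(e) => throw new IOException(e)
  }
}
```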
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index 7ba1182f0ed27..1c13e2c372845 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -95,7 +95,8 @@ private[spark] object CheckpointRDD extends Logging {
val finalOutputName = splitIdToFile(ctx.partitionId)
val finalOutputPath = new Path(outputDir, finalOutputName)
- val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId)
+ val tempOutputPath =
+ new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptNumber)
if (fs.exists(tempOutputPath)) {
throw new IOException("Checkpoint failed: temporary path " +
@@ -119,7 +120,7 @@ private[spark] object CheckpointRDD extends Logging {
logInfo("Deleting tempOutputPath " + tempOutputPath)
fs.delete(tempOutputPath, false)
throw new IOException("Checkpoint failed: failed to save output of task: "
- + ctx.attemptId + " and final output path does not exist")
+ + ctx.attemptNumber + " and final output path does not exist")
} else {
// Some other copy of this task must've finished before us and renamed it
logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it")
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index fabb882cdd4b3..07398a6fa62f6 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -27,6 +27,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer}
+import org.apache.spark.util.Utils
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleHandle
@@ -39,7 +40,7 @@ private[spark] case class NarrowCoGroupSplitDep(
) extends CoGroupSplitDep {
@throws(classOf[IOException])
- private def writeObject(oos: ObjectOutputStream) {
+ private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
// Update the reference to parent split at the time of task serialization
split = rdd.partitions(splitIndex)
oos.defaultWriteObject()
@@ -59,7 +60,7 @@ private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]
* A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a
* tuple with the list of values for that key.
*
- * Note: This is an internal API. We recommend users use RDD.coGroup(...) instead of
+ * Note: This is an internal API. We recommend users use RDD.cogroup(...) instead of
* instantiating this directly.
* @param rdds parent RDDs.
@@ -69,8 +70,8 @@ private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]
class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner)
extends RDD[(K, Array[Iterable[_]])](rdds.head.context, Nil) {
- // For example, `(k, a) cogroup (k, b)` produces k -> Seq(ArrayBuffer as, ArrayBuffer bs).
- // Each ArrayBuffer is represented as a CoGroup, and the resulting Seq as a CoGroupCombiner.
+ // For example, `(k, a) cogroup (k, b)` produces k -> Array(ArrayBuffer as, ArrayBuffer bs).
+ // Each ArrayBuffer is represented as a CoGroup, and the resulting Array as a CoGroupCombiner.
// CoGroupValue is the intermediate state of each value before being merged in compute.
private type CoGroup = CompactBuffer[Any]
private type CoGroupValue = (Any, Int) // Int is dependency number
@@ -158,8 +159,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
for ((it, depNum) <- rddIterators) {
map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum))))
}
- context.taskMetrics.memoryBytesSpilled += map.memoryBytesSpilled
- context.taskMetrics.diskBytesSpilled += map.diskBytesSpilled
+ context.taskMetrics.incMemoryBytesSpilled(map.memoryBytesSpilled)
+ context.taskMetrics.incDiskBytesSpilled(map.diskBytesSpilled)
new InterruptibleIterator(context,
map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]])
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
index 11ebafbf6d457..b073eba8a1574 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -25,6 +25,7 @@ import scala.language.existentials
import scala.reflect.ClassTag
import org.apache.spark._
+import org.apache.spark.util.Utils
/**
* Class that captures a coalesced RDD by essentially keeping track of parent partitions
@@ -34,15 +35,14 @@ import org.apache.spark._
* @param preferredLocation the preferred location for this partition
*/
private[spark] case class CoalescedRDDPartition(
- index: Int,
- @transient rdd: RDD[_],
- parentsIndices: Array[Int],
- @transient preferredLocation: String = ""
- ) extends Partition {
+ index: Int,
+ @transient rdd: RDD[_],
+ parentsIndices: Array[Int],
+ @transient preferredLocation: Option[String] = None) extends Partition {
var parents: Seq[Partition] = parentsIndices.map(rdd.partitions(_))
@throws(classOf[IOException])
- private def writeObject(oos: ObjectOutputStream) {
+ private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
// Update the reference to parent partition at the time of task serialization
parents = parentsIndices.map(rdd.partitions(_))
oos.defaultWriteObject()
@@ -54,9 +54,10 @@ private[spark] case class CoalescedRDDPartition(
* @return locality of this coalesced partition between 0 and 1
*/
def localFraction: Double = {
- val loc = parents.count(p =>
- rdd.context.getPreferredLocs(rdd, p.index).map(tl => tl.host).contains(preferredLocation))
-
+ val loc = parents.count { p =>
+ val parentPreferredLocations = rdd.context.getPreferredLocs(rdd, p.index).map(_.host)
+ preferredLocation.exists(parentPreferredLocations.contains)
+ }
if (parents.size == 0) 0.0 else (loc.toDouble / parents.size.toDouble)
}
}
@@ -72,9 +73,9 @@ private[spark] case class CoalescedRDDPartition(
* @param balanceSlack used to trade-off balance and locality. 1.0 is all locality, 0 is all balance
*/
private[spark] class CoalescedRDD[T: ClassTag](
- @transient var prev: RDD[T],
- maxPartitions: Int,
- balanceSlack: Double = 0.10)
+ @transient var prev: RDD[T],
+ maxPartitions: Int,
+ balanceSlack: Double = 0.10)
extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies
override def getPartitions: Array[Partition] = {
@@ -112,7 +113,7 @@ private[spark] class CoalescedRDD[T: ClassTag](
* @return the machine most preferred by split
*/
override def getPreferredLocations(partition: Partition): Seq[String] = {
- List(partition.asInstanceOf[CoalescedRDDPartition].preferredLocation)
+ partition.asInstanceOf[CoalescedRDDPartition].preferredLocation.toSeq
}
}
@@ -146,7 +147,7 @@ private[spark] class CoalescedRDD[T: ClassTag](
*
*/
-private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: Double) {
+private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: Double) {
def compare(o1: PartitionGroup, o2: PartitionGroup): Boolean = o1.size < o2.size
def compare(o1: Option[PartitionGroup], o2: Option[PartitionGroup]): Boolean =
@@ -340,8 +341,14 @@ private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanc
}
}
-private[spark] case class PartitionGroup(prefLoc: String = "") {
+private case class PartitionGroup(prefLoc: Option[String] = None) {
var arr = mutable.ArrayBuffer[Partition]()
-
def size = arr.size
}
+
+private object PartitionGroup {
+ def apply(prefLoc: String): PartitionGroup = {
+ require(prefLoc != "", "Preferred location must not be empty")
+ PartitionGroup(Some(prefLoc))
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
index e0494ee39657c..e66f83bb34e30 100644
--- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
@@ -27,7 +27,6 @@ import org.apache.spark.util.StatCounter
/**
* Extra functions available on RDDs of Doubles through an implicit conversion.
- * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
*/
class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
/** Add up the elements in this RDD. */
diff --git a/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala b/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala
deleted file mode 100644
index 9e41b3d1e2d4f..0000000000000
--- a/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala
+++ /dev/null
@@ -1,35 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.{Partition, TaskContext}
-
-private[spark] class FilteredRDD[T: ClassTag](
- prev: RDD[T],
- f: T => Boolean)
- extends RDD[T](prev) {
-
- override def getPartitions: Array[Partition] = firstParent[T].partitions
-
- override val partitioner = prev.partitioner // Since filter cannot change a partition's keys
-
- override def compute(split: Partition, context: TaskContext) =
- firstParent[T].iterator(split, context).filter(f)
-}
diff --git a/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala
deleted file mode 100644
index d8f87d4e3690e..0000000000000
--- a/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.{Partition, TaskContext}
-
-private[spark]
-class FlatMappedRDD[U: ClassTag, T: ClassTag](
- prev: RDD[T],
- f: T => TraversableOnce[U])
- extends RDD[U](prev) {
-
- override def getPartitions: Array[Partition] = firstParent[T].partitions
-
- override def compute(split: Partition, context: TaskContext) =
- firstParent[T].iterator(split, context).flatMap(f)
-}
diff --git a/core/src/main/scala/org/apache/spark/rdd/FlatMappedValuesRDD.scala b/core/src/main/scala/org/apache/spark/rdd/FlatMappedValuesRDD.scala
deleted file mode 100644
index 7c9023f62d3b6..0000000000000
--- a/core/src/main/scala/org/apache/spark/rdd/FlatMappedValuesRDD.scala
+++ /dev/null
@@ -1,35 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import org.apache.spark.{Partition, TaskContext}
-
-private[spark]
-class FlatMappedValuesRDD[K, V, U](prev: RDD[_ <: Product2[K, V]], f: V => TraversableOnce[U])
- extends RDD[(K, U)](prev) {
-
- override def getPartitions = firstParent[Product2[K, V]].partitions
-
- override val partitioner = firstParent[Product2[K, V]].partitioner
-
- override def compute(split: Partition, context: TaskContext) = {
- firstParent[Product2[K, V]].iterator(split, context).flatMap { case Product2(k, v) =>
- f(v).map(x => (k, x))
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala
deleted file mode 100644
index f6463fa715a71..0000000000000
--- a/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala
+++ /dev/null
@@ -1,31 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.{Partition, TaskContext}
-
-private[spark] class GlommedRDD[T: ClassTag](prev: RDD[T])
- extends RDD[Array[T]](prev) {
-
- override def getPartitions: Array[Partition] = firstParent[T].partitions
-
- override def compute(split: Partition, context: TaskContext) =
- Array(firstParent[T].iterator(split, context).toArray).iterator
-}
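FilteredRDD, FlatMappedRDD, FlatMappedValuesRDD and GlommedRDD disappear here; each of these narrow transformations can be expressed as a partition-wise map instead of a dedicated RDD subclass. A sketch of filter written that way through the public API (not necessarily the exact replacement used in RDD.scala):

```scala
import scala.reflect.ClassTag

import org.apache.spark.rdd.RDD

// filter as a partition-wise transformation; preservesPartitioning = true because
// filtering cannot change a partition's keys.
def filterViaMapPartitions[T: ClassTag](prev: RDD[T], f: T => Boolean): RDD[T] =
  prev.mapPartitions(iter => iter.filter(f), preservesPartitioning = true)
```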
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 775141775e06c..89adddcf0ac36 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -35,17 +35,18 @@ import org.apache.hadoop.mapred.Reporter
import org.apache.hadoop.mapred.JobID
import org.apache.hadoop.mapred.TaskAttemptID
import org.apache.hadoop.mapred.TaskID
+import org.apache.hadoop.mapred.lib.CombineFileSplit
import org.apache.hadoop.util.ReflectionUtils
import org.apache.spark._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.executor.{DataReadMethod, InputMetrics}
+import org.apache.spark.executor.DataReadMethod
import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD
import org.apache.spark.util.{NextIterator, Utils}
import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation}
-
+import org.apache.spark.storage.StorageLevel
/**
* A Spark split class that wraps around a Hadoop InputSplit.
@@ -132,7 +133,7 @@ class HadoopRDD[K, V](
// used to build JobTracker ID
private val createTime = new Date()
- private val shouldCloneJobConf = sc.conf.get("spark.hadoop.cloneConf", "false").toBoolean
+ private val shouldCloneJobConf = sc.conf.getBoolean("spark.hadoop.cloneConf", false)
// Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads.
protected def getJobConf(): JobConf = {
@@ -212,11 +213,26 @@ class HadoopRDD[K, V](
val split = theSplit.asInstanceOf[HadoopPartition]
logInfo("Input split: " + split.inputSplit)
- var reader: RecordReader[K, V] = null
val jobConf = getJobConf()
+
+ val inputMetrics = context.taskMetrics
+ .getInputMetricsForReadMethod(DataReadMethod.Hadoop)
+
+ // Find a function that will return the FileSystem bytes read by this thread. Do this before
+ // creating RecordReader, because RecordReader's constructor might read some bytes
+ val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
+ split.inputSplit.value match {
+ case _: FileSplit | _: CombineFileSplit =>
+ SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
+ case _ => None
+ }
+ }
+ inputMetrics.setBytesReadCallback(bytesReadCallback)
+
+ var reader: RecordReader[K, V] = null
val inputFormat = getInputFormat(jobConf)
HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
- context.stageId, theSplit.index, context.attemptId.toInt, jobConf)
+ context.stageId, theSplit.index, context.attemptNumber, jobConf)
reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
// Register an on-task-completion callback to close the input stream.
@@ -224,19 +240,6 @@ class HadoopRDD[K, V](
val key: K = reader.createKey()
val value: V = reader.createValue()
- // Set the task input metrics.
- val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
- try {
- /* bytesRead may not exactly equal the bytes read by a task: split boundaries aren't
- * always at record boundaries, so tasks may need to read into other splits to complete
- * a record. */
- inputMetrics.bytesRead = split.inputSplit.value.getLength()
- } catch {
- case e: java.io.IOException =>
- logWarning("Unable to get input size to set InputMetrics for task", e)
- }
- context.taskMetrics.inputMetrics = Some(inputMetrics)
-
override def getNext() = {
try {
finished = !reader.next(key, value)
@@ -244,12 +247,26 @@ class HadoopRDD[K, V](
case eof: EOFException =>
finished = true
}
+
(key, value)
}
override def close() {
try {
reader.close()
+ if (bytesReadCallback.isDefined) {
+ inputMetrics.updateBytesRead()
+ } else if (split.inputSplit.value.isInstanceOf[FileSplit] ||
+ split.inputSplit.value.isInstanceOf[CombineFileSplit]) {
+ // If we can't get the bytes read from the FS stats, fall back to the split size,
+ // which may be inaccurate.
+ try {
+ inputMetrics.addBytesRead(split.inputSplit.value.getLength)
+ } catch {
+ case e: java.io.IOException =>
+ logWarning("Unable to get input size to set InputMetrics for task", e)
+ }
+ }
} catch {
case e: Exception => {
if (!Utils.inShutdown()) {
@@ -292,6 +309,15 @@ class HadoopRDD[K, V](
// Do nothing. Hadoop RDD should not be checkpointed.
}
+ override def persist(storageLevel: StorageLevel): this.type = {
+ if (storageLevel.deserialized) {
+      logWarning("Caching HadoopRDDs as deserialized objects usually leads to undesired" +
+ " behavior because Hadoop's RecordReader reuses the same Writable object for all records." +
+ " Use a map transformation to make copies of the records.")
+ }
+ super.persist(storageLevel)
+ }
+
def getConf: Configuration = getJobConf()
}
@@ -302,6 +328,9 @@ private[spark] object HadoopRDD extends Logging {
*/
val CONFIGURATION_INSTANTIATION_LOCK = new Object()
+ /** Update the input bytes read metric each time this number of records has been read */
+ val RECORDS_BETWEEN_BYTES_READ_METRIC_UPDATES = 256
+
/**
* The three methods below are helpers for accessing the local map, a property of the SparkEnv of
* the local process.
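
The persist override added to HadoopRDD above exists because Hadoop's RecordReader reuses a single Writable instance for every record, so caching the RDD in deserialized form can leave many references to one mutated object. A minimal sketch of the suggested workaround, copying each record into immutable values with a map before caching; the input path and key/value types are illustrative, not part of this patch:

    import org.apache.hadoop.io.{LongWritable, Text}
    import org.apache.hadoop.mapred.TextInputFormat
    import org.apache.spark.SparkContext

    val sc = new SparkContext("local[2]", "hadoop-rdd-copy")          // assumed local setup
    // Backed by a HadoopRDD; the same LongWritable/Text objects are reused for every record.
    val raw = sc.hadoopFile[LongWritable, Text, TextInputFormat]("hdfs:///tmp/input")
    // Copy the mutable Writables before caching so each cached record is independent.
    val safe = raw.map { case (offset, line) => (offset.get(), line.toString) }
    safe.cache()
    println(safe.count())
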
diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
index 0e38f224ac81d..642a12c1edf6c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -21,8 +21,11 @@ import java.sql.{Connection, ResultSet}
import scala.reflect.ClassTag
-import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.api.java.function.{Function => JFunction}
+import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.util.NextIterator
+import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition {
override def index = idx
@@ -125,5 +128,82 @@ object JdbcRDD {
def resultSetToObjectArray(rs: ResultSet): Array[Object] = {
Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
}
-}
+ trait ConnectionFactory extends Serializable {
+ @throws[Exception]
+ def getConnection: Connection
+ }
+
+ /**
+ * Create an RDD that executes an SQL query on a JDBC connection and reads results.
+ * For usage example, see test case JavaAPISuite.testJavaJdbcRDD.
+ *
+ * @param connectionFactory a factory that returns an open Connection.
+ * The RDD takes care of closing the connection.
+ * @param sql the text of the query.
+ * The query must contain two ? placeholders for parameters used to partition the results.
+ * E.g. "select title, author from books where ? <= id and id <= ?"
+ * @param lowerBound the minimum value of the first placeholder
+ * @param upperBound the maximum value of the second placeholder
+ * The lower and upper bounds are inclusive.
+ * @param numPartitions the number of partitions.
+ * Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
+ * the query would be executed twice, once with (1, 10) and once with (11, 20)
+ * @param mapRow a function from a ResultSet to a single row of the desired result type(s).
+ * This should only call getInt, getString, etc; the RDD takes care of calling next.
+ * The default maps a ResultSet to an array of Object.
+ */
+ def create[T](
+ sc: JavaSparkContext,
+ connectionFactory: ConnectionFactory,
+ sql: String,
+ lowerBound: Long,
+ upperBound: Long,
+ numPartitions: Int,
+ mapRow: JFunction[ResultSet, T]): JavaRDD[T] = {
+
+ val jdbcRDD = new JdbcRDD[T](
+ sc.sc,
+ () => connectionFactory.getConnection,
+ sql,
+ lowerBound,
+ upperBound,
+ numPartitions,
+ (resultSet: ResultSet) => mapRow.call(resultSet))(fakeClassTag)
+
+ new JavaRDD[T](jdbcRDD)(fakeClassTag)
+ }
+
+ /**
+ * Create an RDD that executes an SQL query on a JDBC connection and reads results. Each row is
+   * converted into an `Object` array. For usage example, see test case JavaAPISuite.testJavaJdbcRDD.
+ *
+ * @param connectionFactory a factory that returns an open Connection.
+ * The RDD takes care of closing the connection.
+ * @param sql the text of the query.
+ * The query must contain two ? placeholders for parameters used to partition the results.
+ * E.g. "select title, author from books where ? <= id and id <= ?"
+ * @param lowerBound the minimum value of the first placeholder
+ * @param upperBound the maximum value of the second placeholder
+ * The lower and upper bounds are inclusive.
+ * @param numPartitions the number of partitions.
+ * Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
+ * the query would be executed twice, once with (1, 10) and once with (11, 20)
+ */
+ def create(
+ sc: JavaSparkContext,
+ connectionFactory: ConnectionFactory,
+ sql: String,
+ lowerBound: Long,
+ upperBound: Long,
+ numPartitions: Int): JavaRDD[Array[Object]] = {
+
+ val mapRow = new JFunction[ResultSet, Array[Object]] {
+ override def call(resultSet: ResultSet): Array[Object] = {
+ resultSetToObjectArray(resultSet)
+ }
+ }
+
+ create(sc, connectionFactory, sql, lowerBound, upperBound, numPartitions, mapRow)
+ }
+}
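
The Java-friendly JdbcRDD.create overloads above wrap the existing Scala constructor, whose placeholder and partitioning semantics they share. A hedged Scala sketch of those semantics; the in-memory H2 URL and PEOPLE table are hypothetical:

    import java.sql.DriverManager
    import org.apache.spark.SparkContext
    import org.apache.spark.rdd.JdbcRDD

    val sc = new SparkContext("local[2]", "jdbc-rdd-example")          // assumed setup
    // The two ? placeholders receive each partition's (lower, upper) bounds:
    // with bounds 1..20 and 2 partitions, the query runs with (1, 10) and (11, 20).
    val rows = new JdbcRDD(
      sc,
      () => DriverManager.getConnection("jdbc:h2:mem:testdb"),         // hypothetical connection
      "SELECT ID, NAME FROM PEOPLE WHERE ? <= ID AND ID <= ?",
      lowerBound = 1, upperBound = 20, numPartitions = 2,
      mapRow = rs => (rs.getLong(1), rs.getString(2)))
    println(rows.count())
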
diff --git a/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala
deleted file mode 100644
index 8d7c288593665..0000000000000
--- a/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.{Partition, TaskContext}
-
-private[spark]
-class MappedRDD[U: ClassTag, T: ClassTag](prev: RDD[T], f: T => U)
- extends RDD[U](prev) {
-
- override def getPartitions: Array[Partition] = firstParent[T].partitions
-
- override def compute(split: Partition, context: TaskContext) =
- firstParent[T].iterator(split, context).map(f)
-}
diff --git a/core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala
deleted file mode 100644
index a60952eee5901..0000000000000
--- a/core/src/main/scala/org/apache/spark/rdd/MappedValuesRDD.scala
+++ /dev/null
@@ -1,33 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import org.apache.spark.{Partition, TaskContext}
-
-private[spark]
-class MappedValuesRDD[K, V, U](prev: RDD[_ <: Product2[K, V]], f: V => U)
- extends RDD[(K, U)](prev) {
-
- override def getPartitions = firstParent[Product2[K, U]].partitions
-
- override val partitioner = firstParent[Product2[K, U]].partitioner
-
- override def compute(split: Partition, context: TaskContext): Iterator[(K, U)] = {
- firstParent[Product2[K, V]].iterator(split, context).map { pair => (pair._1, f(pair._2)) }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 0cccdefc5ee09..44b9ffd2a53fd 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -25,17 +25,17 @@ import scala.reflect.ClassTag
import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.input.WholeTextFileInputFormat
-import org.apache.spark.InterruptibleIterator
-import org.apache.spark.Logging
-import org.apache.spark.Partition
-import org.apache.spark.SerializableWritable
-import org.apache.spark.{SparkContext, TaskContext}
-import org.apache.spark.executor.{DataReadMethod, InputMetrics}
+import org.apache.spark._
+import org.apache.spark.executor.DataReadMethod
+import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD
import org.apache.spark.util.Utils
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.storage.StorageLevel
private[spark] class NewHadoopPartition(
rddId: Int,
@@ -105,6 +105,21 @@ class NewHadoopRDD[K, V](
val split = theSplit.asInstanceOf[NewHadoopPartition]
logInfo("Input split: " + split.serializableHadoopSplit)
val conf = confBroadcast.value.value
+
+ val inputMetrics = context.taskMetrics
+ .getInputMetricsForReadMethod(DataReadMethod.Hadoop)
+
+ // Find a function that will return the FileSystem bytes read by this thread. Do this before
+ // creating RecordReader, because RecordReader's constructor might read some bytes
+ val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
+ split.serializableHadoopSplit.value match {
+ case _: FileSplit | _: CombineFileSplit =>
+ SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
+ case _ => None
+ }
+ }
+ inputMetrics.setBytesReadCallback(bytesReadCallback)
+
val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0)
val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
val format = inputFormatClass.newInstance
@@ -117,22 +132,11 @@ class NewHadoopRDD[K, V](
split.serializableHadoopSplit.value, hadoopAttemptContext)
reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
- val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
- try {
- /* bytesRead may not exactly equal the bytes read by a task: split boundaries aren't
- * always at record boundaries, so tasks may need to read into other splits to complete
- * a record. */
- inputMetrics.bytesRead = split.serializableHadoopSplit.value.getLength()
- } catch {
- case e: Exception =>
- logWarning("Unable to get input split size in order to set task input bytes", e)
- }
- context.taskMetrics.inputMetrics = Some(inputMetrics)
-
// Register an on-task-completion callback to close the input stream.
context.addTaskCompletionListener(context => close())
var havePair = false
var finished = false
+ var recordsSinceMetricsUpdate = 0
override def hasNext: Boolean = {
if (!finished && !havePair) {
@@ -147,12 +151,26 @@ class NewHadoopRDD[K, V](
throw new java.util.NoSuchElementException("End of stream")
}
havePair = false
+
(reader.getCurrentKey, reader.getCurrentValue)
}
private def close() {
try {
reader.close()
+ if (bytesReadCallback.isDefined) {
+ inputMetrics.updateBytesRead()
+ } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] ||
+ split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) {
+ // If we can't get the bytes read from the FS stats, fall back to the split size,
+ // which may be inaccurate.
+ try {
+ inputMetrics.addBytesRead(split.serializableHadoopSplit.value.getLength)
+ } catch {
+ case e: java.io.IOException =>
+ logWarning("Unable to get input size to set InputMetrics for task", e)
+ }
+ }
} catch {
case e: Exception => {
if (!Utils.inShutdown()) {
@@ -190,6 +208,16 @@ class NewHadoopRDD[K, V](
locs.getOrElse(split.getLocations.filter(_ != "localhost"))
}
+ override def persist(storageLevel: StorageLevel): this.type = {
+ if (storageLevel.deserialized) {
+ logWarning("Caching NewHadoopRDDs as deserialized objects usually leads to undesired" +
+ " behavior because Hadoop's RecordReader reuses the same Writable object for all records." +
+ " Use a map transformation to make copies of the records.")
+ }
+ super.persist(storageLevel)
+ }
+
+
def getConf: Configuration = confBroadcast.value.value
}
@@ -233,7 +261,7 @@ private[spark] class WholeTextFileRDD(
case _ =>
}
val jobContext = newJobContext(conf, jobId)
- inputFormat.setMaxSplitSize(jobContext, minPartitions)
+ inputFormat.setMinPartitions(jobContext, minPartitions)
val rawSplits = inputFormat.getSplits(jobContext).toArray
val result = new Array[Partition](rawSplits.size)
for (i <- 0 until rawSplits.size) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
index d0dbfef35d03c..144f679a59460 100644
--- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -24,10 +24,9 @@ import org.apache.spark.annotation.DeveloperApi
/**
* Extra functions available on RDDs of (key, value) pairs where the key is sortable through
- * an implicit conversion. Import `org.apache.spark.SparkContext._` at the top of your program to
- * use these functions. They will work with any key type `K` that has an implicit `Ordering[K]` in
- * scope. Ordering objects already exist for all of the standard primitive types. Users can also
- * define their own orderings for custom types, or to override the default ordering. The implicit
+ * an implicit conversion. They will work with any key type `K` that has an implicit `Ordering[K]`
+ * in scope. Ordering objects already exist for all of the standard primitive types. Users can also
+ * define their own orderings for custom types, or to override the default ordering. The implicit
* ordering that is in the closest scope will be used.
*
* {{{
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index da89f634abaea..49b88a90ab5af 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -25,21 +25,23 @@ import scala.collection.{Map, mutable}
import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
+import scala.util.DynamicVariable
import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
import org.apache.hadoop.conf.{Configurable, Configuration}
-import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat,
-RecordWriter => NewRecordWriter, SparkHadoopMapReduceUtil}
+RecordWriter => NewRecordWriter}
import org.apache.spark._
import org.apache.spark.Partitioner.defaultPartitioner
-import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.executor.{DataWriteMethod, OutputMetrics}
+import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.serializer.Serializer
import org.apache.spark.util.Utils
@@ -48,7 +50,6 @@ import org.apache.spark.util.random.StratifiedSamplingUtils
/**
* Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
- * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
*/
class PairRDDFunctions[K, V](self: RDD[(K, V)])
(implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null)
@@ -84,7 +85,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
throw new SparkException("Default partitioner cannot partition array keys.")
}
}
- val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
+ val aggregator = new Aggregator[K, V, C](
+ self.context.clean(createCombiner),
+ self.context.clean(mergeValue),
+ self.context.clean(mergeCombiners))
if (self.partitioner == Some(partitioner)) {
self.mapPartitions(iter => {
val context = TaskContext.get()
@@ -120,11 +124,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
def aggregateByKey[U: ClassTag](zeroValue: U, partitioner: Partitioner)(seqOp: (U, V) => U,
combOp: (U, U) => U): RDD[(K, U)] = {
// Serialize the zero value to a byte array so that we can get a new clone of it on each key
- val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue)
+ val zeroBuffer = SparkEnv.get.serializer.newInstance().serialize(zeroValue)
val zeroArray = new Array[Byte](zeroBuffer.limit)
zeroBuffer.get(zeroArray)
- lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance()
+ lazy val cachedSerializer = SparkEnv.get.serializer.newInstance()
val createZero = () => cachedSerializer.deserialize[U](ByteBuffer.wrap(zeroArray))
combineByKey[U]((v: V) => seqOp(createZero(), v), seqOp, combOp, partitioner)
@@ -165,12 +169,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
*/
def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = {
// Serialize the zero value to a byte array so that we can get a new clone of it on each key
- val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue)
+ val zeroBuffer = SparkEnv.get.serializer.newInstance().serialize(zeroValue)
val zeroArray = new Array[Byte](zeroBuffer.limit)
zeroBuffer.get(zeroArray)
// When deserializing, use a lazy val to create just one instance of the serializer per task
- lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance()
+ lazy val cachedSerializer = SparkEnv.get.serializer.newInstance()
val createZero = () => cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray))
combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner)
@@ -433,6 +437,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
* Note: This operation may be very expensive. If you are grouping in order to perform an
* aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]]
* or [[PairRDDFunctions.reduceByKey]] will provide much better performance.
+ *
+ * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any
+ * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]].
*/
def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])] = {
// groupByKey shouldn't use map side combine because map side combine does not
@@ -454,6 +461,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
* Note: This operation may be very expensive. If you are grouping in order to perform an
* aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]]
* or [[PairRDDFunctions.reduceByKey]] will provide much better performance.
+ *
+ * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any
+ * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]].
*/
def groupByKey(numPartitions: Int): RDD[(K, Iterable[V])] = {
groupByKey(new HashPartitioner(numPartitions))
@@ -480,7 +490,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
*/
def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = {
this.cogroup(other, partitioner).flatMapValues( pair =>
- for (v <- pair._1; w <- pair._2) yield (v, w)
+ for (v <- pair._1.iterator; w <- pair._2.iterator) yield (v, w)
)
}
@@ -493,9 +503,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
def leftOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, Option[W]))] = {
this.cogroup(other, partitioner).flatMapValues { pair =>
if (pair._2.isEmpty) {
- pair._1.map(v => (v, None))
+ pair._1.iterator.map(v => (v, None))
} else {
- for (v <- pair._1; w <- pair._2) yield (v, Some(w))
+ for (v <- pair._1.iterator; w <- pair._2.iterator) yield (v, Some(w))
}
}
}
@@ -510,9 +520,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
: RDD[(K, (Option[V], W))] = {
this.cogroup(other, partitioner).flatMapValues { pair =>
if (pair._1.isEmpty) {
- pair._2.map(w => (None, w))
+ pair._2.iterator.map(w => (None, w))
} else {
- for (v <- pair._1; w <- pair._2) yield (Some(v), w)
+ for (v <- pair._1.iterator; w <- pair._2.iterator) yield (Some(v), w)
}
}
}
@@ -528,9 +538,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
def fullOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner)
: RDD[(K, (Option[V], Option[W]))] = {
this.cogroup(other, partitioner).flatMapValues {
- case (vs, Seq()) => vs.map(v => (Some(v), None))
- case (Seq(), ws) => ws.map(w => (None, Some(w)))
- case (vs, ws) => for (v <- vs; w <- ws) yield (Some(v), Some(w))
+ case (vs, Seq()) => vs.iterator.map(v => (Some(v), None))
+ case (Seq(), ws) => ws.iterator.map(w => (None, Some(w)))
+ case (vs, ws) => for (v <- vs.iterator; w <- ws.iterator) yield (Some(v), Some(w))
}
}
@@ -660,7 +670,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
*/
def mapValues[U](f: V => U): RDD[(K, U)] = {
val cleanF = self.context.clean(f)
- new MappedValuesRDD(self, cleanF)
+ new MapPartitionsRDD[(K, U), (K, V)](self,
+ (context, pid, iter) => iter.map { case (k, v) => (k, cleanF(v)) },
+ preservesPartitioning = true)
}
/**
@@ -669,7 +681,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
*/
def flatMapValues[U](f: V => TraversableOnce[U]): RDD[(K, U)] = {
val cleanF = self.context.clean(f)
- new FlatMappedValuesRDD(self, cleanF)
+ new MapPartitionsRDD[(K, U), (K, V)](self,
+ (context, pid, iter) => iter.flatMap { case (k, v) =>
+ cleanF(v).map(x => (k, x))
+ },
+ preservesPartitioning = true)
}
/**
@@ -955,36 +971,43 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val outfmt = job.getOutputFormatClass
val jobFormat = outfmt.newInstance
- if (self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true)) {
+ if (isOutputSpecValidationEnabled) {
// FileOutputFormat ignores the filesystem parameter
jobFormat.checkOutputSpecs(job)
}
val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => {
- // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
- // around by taking a mod. We expect that no task will be attempted 2 billion times.
- val attemptNumber = (context.attemptId % Int.MaxValue).toInt
+ val config = wrappedConf.value
/* "reduce task" */
val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
- attemptNumber)
- val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
+ context.attemptNumber)
+ val hadoopContext = newTaskAttemptContext(config, attemptId)
val format = outfmt.newInstance
format match {
- case c: Configurable => c.setConf(wrappedConf.value)
+ case c: Configurable => c.setConf(config)
case _ => ()
}
val committer = format.getOutputCommitter(hadoopContext)
committer.setupTask(hadoopContext)
+
+ val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context)
+
val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]]
try {
+ var recordsWritten = 0L
while (iter.hasNext) {
val pair = iter.next()
writer.write(pair._1, pair._2)
+
+ // Update bytes written metric every few records
+ maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten)
+ recordsWritten += 1
}
} finally {
writer.close(hadoopContext)
}
committer.commitTask(hadoopContext)
+ bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) }
1
} : Int
@@ -1005,6 +1028,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
def saveAsHadoopDataset(conf: JobConf) {
// Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038).
val hadoopConf = conf
+ val wrappedConf = new SerializableWritable(hadoopConf)
val outputFormatInstance = hadoopConf.getOutputFormat
val keyClass = hadoopConf.getOutputKeyClass
val valueClass = hadoopConf.getOutputValueClass
@@ -1022,7 +1046,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " +
valueClass.getSimpleName + ")")
- if (self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true)) {
+ if (isOutputSpecValidationEnabled) {
// FileOutputFormat ignores the filesystem parameter
val ignoredFs = FileSystem.get(hadoopConf)
hadoopConf.getOutputFormat.checkOutputSpecs(ignoredFs, hadoopConf)
@@ -1032,27 +1056,53 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
writer.preSetup()
val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => {
+ val config = wrappedConf.value
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
- val attemptNumber = (context.attemptId % Int.MaxValue).toInt
+ val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt
+
+ val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context)
- writer.setup(context.stageId, context.partitionId, attemptNumber)
+ writer.setup(context.stageId, context.partitionId, taskAttemptId)
writer.open()
try {
+ var recordsWritten = 0L
while (iter.hasNext) {
val record = iter.next()
writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
+
+ // Update bytes written metric every few records
+ maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten)
+ recordsWritten += 1
}
} finally {
writer.close()
}
writer.commit()
+ bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) }
}
self.context.runJob(self, writeToFile)
writer.commitJob()
}
+ private def initHadoopOutputMetrics(context: TaskContext): (OutputMetrics, Option[() => Long]) = {
+ val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback()
+ val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop)
+ if (bytesWrittenCallback.isDefined) {
+ context.taskMetrics.outputMetrics = Some(outputMetrics)
+ }
+ (outputMetrics, bytesWrittenCallback)
+ }
+
+ private def maybeUpdateOutputMetrics(bytesWrittenCallback: Option[() => Long],
+ outputMetrics: OutputMetrics, recordsWritten: Long): Unit = {
+ if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0
+ && bytesWrittenCallback.isDefined) {
+ bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) }
+ }
+ }
+
/**
* Return an RDD with the keys of each tuple.
*/
@@ -1068,4 +1118,22 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
private[spark] def valueClass: Class[_] = vt.runtimeClass
private[spark] def keyOrdering: Option[Ordering[K]] = Option(ord)
+
+ // Note: this needs to be a function instead of a 'val' so that the disableOutputSpecValidation
+ // setting can take effect:
+ private def isOutputSpecValidationEnabled: Boolean = {
+ val validationDisabled = PairRDDFunctions.disableOutputSpecValidation.value
+ val enabledInConf = self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true)
+ enabledInConf && !validationDisabled
+ }
+}
+
+private[spark] object PairRDDFunctions {
+ val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256
+
+ /**
+ * Allows for the `spark.hadoop.validateOutputSpecs` checks to be disabled on a case-by-case
+ * basis; see SPARK-4835 for more details.
+ */
+ val disableOutputSpecValidation: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false)
}
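
disableOutputSpecValidation above is a scala.util.DynamicVariable, so the output-spec check can be skipped for a single save without changing spark.hadoop.validateOutputSpecs globally; it is private[spark], so only internal callers (the Spark Streaming case in SPARK-4835) can scope it. A hedged sketch of that scoping pattern, with a placeholder RDD and path:

    // Inside the org.apache.spark package, e.g. from streaming output code:
    PairRDDFunctions.disableOutputSpecValidation.withValue(true) {
      // checkOutputSpecs is skipped here even if the target directory already exists.
      rdd.saveAsTextFile("hdfs:///tmp/output")   // hypothetical RDD and path
    }
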
diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
index 66c71bf7e8bb5..f12d0cffaba34 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
@@ -48,7 +48,7 @@ private[spark] class ParallelCollectionPartition[T: ClassTag](
override def index: Int = slice
@throws(classOf[IOException])
- private def writeObject(out: ObjectOutputStream): Unit = {
+ private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
val sfactory = SparkEnv.get.serializer
@@ -67,7 +67,7 @@ private[spark] class ParallelCollectionPartition[T: ClassTag](
}
@throws(classOf[IOException])
- private def readObject(in: ObjectInputStream): Unit = {
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
val sfactory = SparkEnv.get.serializer
sfactory match {
@@ -111,7 +111,8 @@ private object ParallelCollectionRDD {
/**
* Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range
* collections specially, encoding the slices as other Ranges to minimize memory cost. This makes
- * it efficient to run Spark over RDDs representing large sets of numbers.
+ * it efficient to run Spark over RDDs representing large sets of numbers. If the collection
+ * is an inclusive Range, we use an inclusive range for the last slice.
*/
def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
if (numSlices < 1) {
@@ -127,19 +128,15 @@ private object ParallelCollectionRDD {
})
}
seq match {
- case r: Range.Inclusive => {
- val sign = if (r.step < 0) {
- -1
- } else {
- 1
- }
- slice(new Range(
- r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices)
- }
case r: Range => {
- positions(r.length, numSlices).map({
- case (start, end) =>
+ positions(r.length, numSlices).zipWithIndex.map({ case ((start, end), index) =>
+ // If the range is inclusive, use inclusive range for the last slice
+ if (r.isInclusive && index == numSlices - 1) {
+ new Range.Inclusive(r.start + start * r.step, r.end, r.step)
+ }
+ else {
new Range(r.start + start * r.step, r.start + end * r.step, r.step)
+ }
}).toSeq.asInstanceOf[Seq[Seq[T]]]
}
case nr: NumericRange[_] => {
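
A small worked example of what the rewritten Range case above produces: slicing 1 to 10 (an inclusive Range) into 3 slices keeps the last slice inclusive, so the endpoint 10 is not dropped. A hedged, standalone re-implementation for illustration only; it mirrors positions and the new Range branch rather than calling the private slice method:

    def sliceInclusive(r: Range, numSlices: Int): Seq[Range] = {
      // Same start/end index computation as ParallelCollectionRDD.positions.
      def positions(length: Long, n: Int): Iterator[(Int, Int)] =
        (0 until n).iterator.map { i =>
          (((i * length) / n).toInt, (((i + 1) * length) / n).toInt)
        }
      positions(r.length, numSlices).zipWithIndex.map { case ((start, end), index) =>
        if (r.isInclusive && index == numSlices - 1) {
          new Range.Inclusive(r.start + start * r.step, r.end, r.step)
        } else {
          new Range(r.start + start * r.step, r.start + end * r.step, r.step)
        }
      }.toSeq
    }

    // sliceInclusive(1 to 10, 3) == Seq(1 until 4, 4 until 7, 7 to 10): all 10 elements survive.
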
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
index 0c2cd7a24783b..92b0641d0fb6e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
@@ -22,6 +22,7 @@ import java.io.{IOException, ObjectOutputStream}
import scala.reflect.ClassTag
import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
+import org.apache.spark.util.Utils
/**
* Class representing partitions of PartitionerAwareUnionRDD, which maintains the list of
@@ -38,7 +39,7 @@ class PartitionerAwareUnionRDDPartition(
override def hashCode(): Int = idx
@throws(classOf[IOException])
- private def writeObject(oos: ObjectOutputStream) {
+ private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
// Update the reference to parent partition at the time of task serialization
parents = rdds.map(_.partitions(index)).toArray
oos.defaultWriteObject()
diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index 56ac7a69be0d3..ed79032893d33 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -63,7 +63,7 @@ private[spark] class PipedRDD[T: ClassTag](
/**
* A FilenameFilter that accepts anything that isn't equal to the name passed in.
- * @param name of file or directory to leave out
+  * @param filterName name of the file or directory to leave out
*/
class NotEqualsFileNameFilter(filterName: String) extends FilenameFilter {
def accept(dir: File, name: String): Boolean = {
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index b7f125d01dfaf..fe55a5124f3b6 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -17,33 +17,31 @@
package org.apache.spark.rdd
-import java.util.{Properties, Random}
+import java.util.Random
import scala.collection.{mutable, Map}
import scala.collection.mutable.ArrayBuffer
+import scala.language.implicitConversions
import scala.reflect.{classTag, ClassTag}
import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
-import org.apache.hadoop.io.BytesWritable
+import org.apache.hadoop.io.{Writable, BytesWritable, NullWritable, Text}
import org.apache.hadoop.io.compress.CompressionCodec
-import org.apache.hadoop.io.NullWritable
-import org.apache.hadoop.io.Text
import org.apache.hadoop.mapred.TextOutputFormat
import org.apache.spark._
import org.apache.spark.Partitioner._
-import org.apache.spark.SparkContext._
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.broadcast.Broadcast
import org.apache.spark.partial.BoundedDouble
import org.apache.spark.partial.CountEvaluator
import org.apache.spark.partial.GroupedCountEvaluator
import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{BoundedPriorityQueue, Utils, CallSite}
+import org.apache.spark.util.{BoundedPriorityQueue, Utils}
import org.apache.spark.util.collection.OpenHashMap
-import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}
+import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, BernoulliCellSampler,
+ SamplingUtils}
/**
* A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -55,8 +53,8 @@ import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingU
* Doubles; and
* [[org.apache.spark.rdd.SequenceFileRDDFunctions]] contains operations available on RDDs that
* can be saved as SequenceFiles.
- * These operations are automatically available on any RDD of the right type (e.g. RDD[(Int, Int)]
- * through implicit conversions when you `import org.apache.spark.SparkContext._`.
+ * All operations are automatically available on any RDD of the right type (e.g. RDD[(Int, Int)])
+ * through implicit conversions.
*
* Internally, each RDD is characterized by five main properties:
*
@@ -74,10 +72,27 @@ import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingU
* on RDD internals.
*/
abstract class RDD[T: ClassTag](
- @transient private var sc: SparkContext,
+ @transient private var _sc: SparkContext,
@transient private var deps: Seq[Dependency[_]]
) extends Serializable with Logging {
+ if (classOf[RDD[_]].isAssignableFrom(elementClassTag.runtimeClass)) {
+ // This is a warning instead of an exception in order to avoid breaking user programs that
+ // might have defined nested RDDs without running jobs with them.
+ logWarning("Spark does not support nested RDDs (see SPARK-5063)")
+ }
+
+ private def sc: SparkContext = {
+ if (_sc == null) {
+ throw new SparkException(
+ "RDD transformations and actions can only be invoked by the driver, not inside of other " +
+ "transformations; for example, rdd1.map(x => rdd2.values.count() * x) is invalid because " +
+ "the values transformation and count action cannot be performed inside of the rdd1.map " +
+ "transformation. For more information, see SPARK-5063.")
+ }
+ _sc
+ }
+
/** Construct an RDD with just a one-to-one dependency on one parent */
def this(@transient oneParent: RDD[_]) =
this(oneParent.context , List(new OneToOneDependency(oneParent)))
@@ -267,19 +282,30 @@ abstract class RDD[T: ClassTag](
/**
* Return a new RDD by applying a function to all elements of this RDD.
*/
- def map[U: ClassTag](f: T => U): RDD[U] = new MappedRDD(this, sc.clean(f))
+ def map[U: ClassTag](f: T => U): RDD[U] = {
+ val cleanF = sc.clean(f)
+ new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.map(cleanF))
+ }
/**
* Return a new RDD by first applying a function to all elements of this
* RDD, and then flattening the results.
*/
- def flatMap[U: ClassTag](f: T => TraversableOnce[U]): RDD[U] =
- new FlatMappedRDD(this, sc.clean(f))
+ def flatMap[U: ClassTag](f: T => TraversableOnce[U]): RDD[U] = {
+ val cleanF = sc.clean(f)
+ new MapPartitionsRDD[U, T](this, (context, pid, iter) => iter.flatMap(cleanF))
+ }
/**
* Return a new RDD containing only the elements that satisfy a predicate.
*/
- def filter(f: T => Boolean): RDD[T] = new FilteredRDD(this, sc.clean(f))
+ def filter(f: T => Boolean): RDD[T] = {
+ val cleanF = sc.clean(f)
+ new MapPartitionsRDD[T, T](
+ this,
+ (context, pid, iter) => iter.filter(cleanF),
+ preservesPartitioning = true)
+ }
/**
* Return a new RDD containing the distinct elements in this RDD.
@@ -375,7 +401,8 @@ abstract class RDD[T: ClassTag](
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
- new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), true, seed)
+ new PartitionwiseSampledRDD[T, T](
+ this, new BernoulliCellSampler[T](x(0), x(1)), true, seed)
}.toArray
}
@@ -499,7 +526,9 @@ abstract class RDD[T: ClassTag](
/**
* Return an RDD created by coalescing all elements within each partition into an array.
*/
- def glom(): RDD[Array[T]] = new GlommedRDD(this)
+ def glom(): RDD[Array[T]] = {
+ new MapPartitionsRDD[Array[T], T](this, (context, pid, iter) => Iterator(iter.toArray))
+ }
/**
* Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of
@@ -571,8 +600,8 @@ abstract class RDD[T: ClassTag](
* print line function (like out.println()) as the 2nd parameter.
* An example of pipe the RDD data of groupBy() in a streaming way,
* instead of constructing a huge String to concat all the elements:
- * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) =
- * for (e <- record._2){f(e)}
+ * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) =
+ * for (e <- record._2){f(e)}
* @param separateWorkingDir Use separate working directories for each task.
* @return the result RDD
*/
@@ -808,7 +837,7 @@ abstract class RDD[T: ClassTag](
* Return an RDD with the elements from `this` that are not in `other`.
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
- * RDD will be <= us.
+ * RDD will be <= us.
*/
def subtract(other: RDD[T]): RDD[T] =
subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size)))
@@ -867,6 +896,38 @@ abstract class RDD[T: ClassTag](
jobResult.getOrElse(throw new UnsupportedOperationException("empty collection"))
}
+ /**
+ * Reduces the elements of this RDD in a multi-level tree pattern.
+ *
+ * @param depth suggested depth of the tree (default: 2)
+ * @see [[org.apache.spark.rdd.RDD#reduce]]
+ */
+ def treeReduce(f: (T, T) => T, depth: Int = 2): T = {
+ require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
+ val cleanF = context.clean(f)
+ val reducePartition: Iterator[T] => Option[T] = iter => {
+ if (iter.hasNext) {
+ Some(iter.reduceLeft(cleanF))
+ } else {
+ None
+ }
+ }
+ val partiallyReduced = mapPartitions(it => Iterator(reducePartition(it)))
+ val op: (Option[T], Option[T]) => Option[T] = (c, x) => {
+ if (c.isDefined && x.isDefined) {
+ Some(cleanF(c.get, x.get))
+ } else if (c.isDefined) {
+ c
+ } else if (x.isDefined) {
+ x
+ } else {
+ None
+ }
+ }
+ partiallyReduced.treeAggregate(Option.empty[T])(op, op, depth)
+ .getOrElse(throw new UnsupportedOperationException("empty collection"))
+ }
+
/**
* Aggregate the elements of each partition, and then the results for all the partitions, using a
* given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
@@ -902,6 +963,37 @@ abstract class RDD[T: ClassTag](
jobResult
}
+ /**
+ * Aggregates the elements of this RDD in a multi-level tree pattern.
+ *
+ * @param depth suggested depth of the tree (default: 2)
+ * @see [[org.apache.spark.rdd.RDD#aggregate]]
+ */
+ def treeAggregate[U: ClassTag](zeroValue: U)(
+ seqOp: (U, T) => U,
+ combOp: (U, U) => U,
+ depth: Int = 2): U = {
+ require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
+ if (partitions.size == 0) {
+ return Utils.clone(zeroValue, context.env.closureSerializer.newInstance())
+ }
+ val cleanSeqOp = context.clean(seqOp)
+ val cleanCombOp = context.clean(combOp)
+ val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
+ var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it)))
+ var numPartitions = partiallyAggregated.partitions.size
+ val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
+ // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation.
+ while (numPartitions > scale + numPartitions / scale) {
+ numPartitions /= scale
+ val curNumPartitions = numPartitions
+ partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) =>
+ iter.map((i % curNumPartitions, _))
+ }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
+ }
+ partiallyAggregated.reduce(cleanCombOp)
+ }
+
/**
* Return the number of elements in the RDD.
*/
@@ -931,7 +1023,7 @@ abstract class RDD[T: ClassTag](
*
* Note that this method should only be used if the resulting map is expected to be small, as
* the whole thing is loaded into the driver's memory.
- * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which
+ * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which
* returns an RDD[T, Long] instead of a map.
*/
def countByValue()(implicit ord: Ordering[T] = null): Map[T, Long] = {
@@ -969,7 +1061,7 @@ abstract class RDD[T: ClassTag](
* Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available
* here.
*
- * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero `sp > p`
+ * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero `sp > p`
* would trigger sparse representation of registers, which may reduce the memory consumption
* and increase accuracy when the cardinality is small.
*
@@ -1094,7 +1186,7 @@ abstract class RDD[T: ClassTag](
}
/**
- * Returns the top K (largest) elements from this RDD as defined by the specified
+ * Returns the top k (largest) elements from this RDD as defined by the specified
* implicit Ordering[T]. This does the opposite of [[takeOrdered]]. For example:
* {{{
* sc.parallelize(Seq(10, 4, 2, 12, 3)).top(1)
@@ -1104,14 +1196,14 @@ abstract class RDD[T: ClassTag](
* // returns Array(6, 5)
* }}}
*
- * @param num the number of top elements to return
+ * @param num k, the number of top elements to return
* @param ord the implicit ordering for T
* @return an array of top elements
*/
def top(num: Int)(implicit ord: Ordering[T]): Array[T] = takeOrdered(num)(ord.reverse)
/**
- * Returns the first K (smallest) elements from this RDD as defined by the specified
+ * Returns the first k (smallest) elements from this RDD as defined by the specified
* implicit Ordering[T] and maintains the ordering. This does the opposite of [[top]].
* For example:
* {{{
@@ -1122,7 +1214,7 @@ abstract class RDD[T: ClassTag](
* // returns Array(2, 3)
* }}}
*
- * @param num the number of top elements to return
+ * @param num k, the number of elements to return
* @param ord the implicit ordering for T
* @return an array of top elements
*/
@@ -1130,15 +1222,20 @@ abstract class RDD[T: ClassTag](
if (num == 0) {
Array.empty
} else {
- mapPartitions { items =>
+ val mapRDDs = mapPartitions { items =>
// Priority keeps the largest elements, so let's reverse the ordering.
val queue = new BoundedPriorityQueue[T](num)(ord.reverse)
queue ++= util.collection.Utils.takeOrdered(items, num)(ord)
Iterator.single(queue)
- }.reduce { (queue1, queue2) =>
- queue1 ++= queue2
- queue1
- }.toArray.sorted(ord)
+ }
+ if (mapRDDs.partitions.size == 0) {
+ Array.empty
+ } else {
+ mapRDDs.reduce { (queue1, queue2) =>
+ queue1 ++= queue2
+ queue1
+ }.toArray.sorted(ord)
+ }
}
}
@@ -1154,11 +1251,36 @@ abstract class RDD[T: ClassTag](
* */
def min()(implicit ord: Ordering[T]): T = this.reduce(ord.min)
+ /**
+ * @return true if and only if the RDD contains no elements at all. Note that an RDD
+ * may be empty even when it has at least 1 partition.
+ */
+ def isEmpty(): Boolean = partitions.length == 0 || take(1).length == 0
+
/**
* Save this RDD as a text file, using string representations of elements.
*/
def saveAsTextFile(path: String) {
- this.map(x => (NullWritable.get(), new Text(x.toString)))
+ // https://issues.apache.org/jira/browse/SPARK-2075
+ //
+ // NullWritable is a `Comparable` in Hadoop 1.+, so the compiler cannot find an implicit
+ // Ordering for it and will use the default `null`. However, it's a `Comparable[NullWritable]`
+ // in Hadoop 2.+, so the compiler will call the implicit `Ordering.ordered` method to create an
+ // Ordering for `NullWritable`. That's why the compiler will generate different anonymous
+ // classes for `saveAsTextFile` in Hadoop 1.+ and Hadoop 2.+.
+ //
+ // Therefore, here we provide an explicit Ordering `null` to make sure the compiler generate
+ // same bytecodes for `saveAsTextFile`.
+ val nullWritableClassTag = implicitly[ClassTag[NullWritable]]
+ val textClassTag = implicitly[ClassTag[Text]]
+ val r = this.mapPartitions { iter =>
+ val text = new Text()
+ iter.map { x =>
+ text.set(x.toString)
+ (NullWritable.get(), text)
+ }
+ }
+ RDD.rddToPairRDDFunctions(r)(nullWritableClassTag, textClassTag, null)
.saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path)
}
@@ -1166,7 +1288,17 @@ abstract class RDD[T: ClassTag](
* Save this RDD as a compressed text file, using string representations of elements.
*/
def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) {
- this.map(x => (NullWritable.get(), new Text(x.toString)))
+ // https://issues.apache.org/jira/browse/SPARK-2075
+ val nullWritableClassTag = implicitly[ClassTag[NullWritable]]
+ val textClassTag = implicitly[ClassTag[Text]]
+ val r = this.mapPartitions { iter =>
+ val text = new Text()
+ iter.map { x =>
+ text.set(x.toString)
+ (NullWritable.get(), text)
+ }
+ }
+ RDD.rddToPairRDDFunctions(r)(nullWritableClassTag, textClassTag, null)
.saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path, codec)
}
@@ -1200,7 +1332,7 @@ abstract class RDD[T: ClassTag](
*/
def checkpoint() {
if (context.checkpointDir.isEmpty) {
- throw new Exception("Checkpoint directory has not been set in the SparkContext")
+ throw new SparkException("Checkpoint directory has not been set in the SparkContext")
} else if (checkpointData.isEmpty) {
checkpointData = Some(new RDDCheckpointData(this))
checkpointData.get.markForCheckpoint()
@@ -1247,7 +1379,7 @@ abstract class RDD[T: ClassTag](
/**
* Private API for changing an RDD's ClassTag.
- * Used for internal Java <-> Scala API compatibility.
+ * Used for internal Java-Scala API compatibility.
*/
private[spark] def retag(cls: Class[T]): RDD[T] = {
val classTag: ClassTag[T] = ClassTag.apply(cls)
@@ -1256,7 +1388,7 @@ abstract class RDD[T: ClassTag](
/**
* Private API for changing an RDD's ClassTag.
- * Used for internal Java <-> Scala API compatibility.
+ * Used for internal Java-Scala API compatibility.
*/
private[spark] def retag(implicit classTag: ClassTag[T]): RDD[T] = {
this.mapPartitions(identity, preservesPartitioning = true)(classTag)
@@ -1307,7 +1439,7 @@ abstract class RDD[T: ClassTag](
def debugSelf (rdd: RDD[_]): Seq[String] = {
import Utils.bytesToString
- val persistence = storageLevel.description
+ val persistence = if (storageLevel != StorageLevel.NONE) storageLevel.description else ""
val storageInfo = rdd.context.getRDDStorageInfo.filter(_.id == rdd.id).map(info =>
" CachedPartitions: %d; MemorySize: %s; TachyonSize: %s; DiskSize: %s".format(
info.numCachedPartitions, bytesToString(info.memSize),
@@ -1381,3 +1513,52 @@ abstract class RDD[T: ClassTag](
new JavaRDD(this)(elementClassTag)
}
}
+
+
+/**
+ * Defines implicit functions that provide extra functionalities on RDDs of specific types.
+ *
+ * For example, [[RDD.rddToPairRDDFunctions]] converts an RDD into a [[PairRDDFunctions]] for
+ * key-value-pair RDDs, enabling extra functionality such as [[PairRDDFunctions.reduceByKey]].
+ */
+object RDD {
+
+ // The following implicit functions were in SparkContext before 1.3 and users had to
+ // `import SparkContext._` to enable them. Now we move them here to make the compiler find
+ // them automatically. However, we still keep the old functions in SparkContext for backward
+ // compatibility and forward to the following functions directly.
+
+ implicit def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)])
+ (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null): PairRDDFunctions[K, V] = {
+ new PairRDDFunctions(rdd)
+ }
+
+ implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]): AsyncRDDActions[T] = {
+ new AsyncRDDActions(rdd)
+ }
+
+ implicit def rddToSequenceFileRDDFunctions[K, V](rdd: RDD[(K, V)])
+ (implicit kt: ClassTag[K], vt: ClassTag[V],
+ keyWritableFactory: WritableFactory[K],
+ valueWritableFactory: WritableFactory[V])
+ : SequenceFileRDDFunctions[K, V] = {
+ implicit val keyConverter = keyWritableFactory.convert
+ implicit val valueConverter = valueWritableFactory.convert
+ new SequenceFileRDDFunctions(rdd,
+ keyWritableFactory.writableClass(kt), valueWritableFactory.writableClass(vt))
+ }
+
+ implicit def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag](rdd: RDD[(K, V)])
+ : OrderedRDDFunctions[K, V, (K, V)] = {
+ new OrderedRDDFunctions[K, V, (K, V)](rdd)
+ }
+
+ implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]): DoubleRDDFunctions = {
+ new DoubleRDDFunctions(rdd)
+ }
+
+ implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T])
+ : DoubleRDDFunctions = {
+ new DoubleRDDFunctions(rdd.map(x => num.toDouble(x)))
+ }
+}
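
The new treeReduce and treeAggregate above combine partition results in a multi-level pattern, so the driver merges only a handful of partially combined values instead of one per partition. A hedged usage sketch in local mode with toy data; depth 2 is the default:

    import org.apache.spark.SparkContext

    val sc = new SparkContext("local[4]", "tree-aggregate-example")    // assumed local setup
    val data = sc.parallelize(1 to 1000, numSlices = 100)

    // Plain reduce would ship 100 partial sums straight to the driver; treeReduce first
    // merges them on executors in roughly sqrt(100)-sized groups before the final step.
    val total = data.treeReduce(_ + _, depth = 2)

    // treeAggregate with a (sum, count) accumulator, combined in the same tree pattern.
    val (sum, count) = data.treeAggregate((0L, 0L))(
      seqOp = { case ((s, c), x) => (s + x, c + 1) },
      combOp = { case ((s1, c1), (s2, c2)) => (s1 + s2, c1 + c2) },
      depth = 2)
    println(s"total=$total mean=${sum.toDouble / count}")
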
diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
index b097c30f8c231..9e8cee5331cf8 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
@@ -21,8 +21,7 @@ import java.util.Random
import scala.reflect.ClassTag
-import cern.jet.random.Poisson
-import cern.jet.random.engine.DRand
+import org.apache.commons.math3.distribution.PoissonDistribution
import org.apache.spark.{Partition, TaskContext}
@@ -53,9 +52,11 @@ private[spark] class SampledRDD[T: ClassTag](
if (withReplacement) {
// For large datasets, the expected number of occurrences of each element in a sample with
// replacement is Poisson(frac). We use that to get a count for each element.
- val poisson = new Poisson(frac, new DRand(split.seed))
+ val poisson = new PoissonDistribution(frac)
+ poisson.reseedRandomGenerator(split.seed)
+
firstParent[T].iterator(split.prev, context).flatMap { element =>
- val count = poisson.nextInt()
+ val count = poisson.sample()
if (count == 0) {
Iterator.empty // Avoid object allocation when we return 0 items, which is quite often
} else {
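
The flatMap above now draws each element's replication count from commons-math3 instead of the removed colt library. A hedged standalone sketch of the replacement call pattern, with an arbitrary fraction and seed:

    import org.apache.commons.math3.distribution.PoissonDistribution

    val frac = 0.5                                 // expected occurrences per element
    val poisson = new PoissonDistribution(frac)
    poisson.reseedRandomGenerator(42L)             // per-partition seed for reproducibility
    // sample() plays the role of colt's Poisson.nextInt(): how many times to emit an element.
    val counts = Seq.fill(5)(poisson.sample())
    println(counts)
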
diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala
index 9a1efc83cbe6a..059f8963691f0 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala
@@ -24,20 +24,41 @@ import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.SequenceFileOutputFormat
import org.apache.spark.Logging
-import org.apache.spark.SparkContext._
/**
* Extra functions available on RDDs of (key, value) pairs to create a Hadoop SequenceFile,
* through an implicit conversion. Note that this can't be part of PairRDDFunctions because
* we need more implicit parameters to convert our keys and values to Writable.
*
- * Import `org.apache.spark.SparkContext._` at the top of their program to use these functions.
*/
class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag](
- self: RDD[(K, V)])
+ self: RDD[(K, V)],
+ _keyWritableClass: Class[_ <: Writable],
+ _valueWritableClass: Class[_ <: Writable])
extends Logging
with Serializable {
+ @deprecated("It's used to provide backward compatibility for pre 1.3.0.", "1.3.0")
+ def this(self: RDD[(K, V)]) {
+ this(self, null, null)
+ }
+
+ private val keyWritableClass =
+ if (_keyWritableClass == null) {
+ // pre 1.3.0, we need to use Reflection to get the Writable class
+ getWritableClass[K]()
+ } else {
+ _keyWritableClass
+ }
+
+ private val valueWritableClass =
+ if (_valueWritableClass == null) {
+ // pre 1.3.0, we need to use Reflection to get the Writable class
+ getWritableClass[V]()
+ } else {
+ _valueWritableClass
+ }
+
private def getWritableClass[T <% Writable: ClassTag](): Class[_ <: Writable] = {
val c = {
if (classOf[Writable].isAssignableFrom(classTag[T].runtimeClass)) {
@@ -56,6 +77,7 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag
c.asInstanceOf[Class[_ <: Writable]]
}
+
/**
* Output the RDD as a Hadoop SequenceFile using the Writable types we infer from the RDD's key
* and value types. If the key or value are Writable, then we use their classes directly;
@@ -66,26 +88,28 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag
def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) {
def anyToWritable[U <% Writable](u: U): Writable = u
- val keyClass = getWritableClass[K]
- val valueClass = getWritableClass[V]
- val convertKey = !classOf[Writable].isAssignableFrom(self.keyClass)
- val convertValue = !classOf[Writable].isAssignableFrom(self.valueClass)
+ // TODO We cannot force the return type of `anyToWritable` to be the same as keyWritableClass
+ // and valueWritableClass at compile time. To implement that, we would need to add type
+ // parameters to SequenceFileRDDFunctions; however, SequenceFileRDDFunctions is a public class,
+ // so that would be a breaking change.
+ val convertKey = self.keyClass != keyWritableClass
+ val convertValue = self.valueClass != valueWritableClass
- logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," +
- valueClass.getSimpleName + ")" )
+ logInfo("Saving as sequence file of type (" + keyWritableClass.getSimpleName + "," +
+ valueWritableClass.getSimpleName + ")" )
val format = classOf[SequenceFileOutputFormat[Writable, Writable]]
val jobConf = new JobConf(self.context.hadoopConfiguration)
if (!convertKey && !convertValue) {
- self.saveAsHadoopFile(path, keyClass, valueClass, format, jobConf, codec)
+ self.saveAsHadoopFile(path, keyWritableClass, valueWritableClass, format, jobConf, codec)
} else if (!convertKey && convertValue) {
self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile(
- path, keyClass, valueClass, format, jobConf, codec)
+ path, keyWritableClass, valueWritableClass, format, jobConf, codec)
} else if (convertKey && !convertValue) {
self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile(
- path, keyClass, valueClass, format, jobConf, codec)
+ path, keyWritableClass, valueWritableClass, format, jobConf, codec)
} else if (convertKey && convertValue) {
self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile(
- path, keyClass, valueClass, format, jobConf, codec)
+ path, keyWritableClass, valueWritableClass, format, jobConf, codec)
}
}
}
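
A hedged usage sketch: with the Writable classes now passed explicitly through the new constructor, the public saveAsSequenceFile call stays unchanged for callers. The SparkContext setup and output path below are illustrative.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.SparkContext._

    object SequenceFileSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("seqfile-sketch"))
        // Int and String are not Writable, so both key and value are converted
        // (to IntWritable and Text) before the Hadoop SequenceFile is written.
        val pairs = sc.parallelize(Seq((1, "a"), (2, "b"), (3, "c")))
        pairs.saveAsSequenceFile("/tmp/seqfile-sketch")   // illustrative output path
        sc.stop()
      }
    }
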
diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
index 0c97eb0aaa51f..aece683ff3199 100644
--- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
@@ -24,6 +24,7 @@ import scala.reflect.ClassTag
import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.Utils
/**
* Partition for UnionRDD.
@@ -48,7 +49,7 @@ private[spark] class UnionPartition[T: ClassTag](
override val index: Int = idx
@throws(classOf[IOException])
- private def writeObject(oos: ObjectOutputStream) {
+ private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
// Update the reference to parent split at the time of task serialization
parentPartition = rdd.partitions(parentRddPartitionIndex)
oos.defaultWriteObject()
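
A minimal sketch of what a helper like Utils.tryOrIOException is assumed to do here: run the serialization block and rethrow any non-fatal failure as an IOException, so that writeObject honours its declared @throws contract. This illustrates the idea only, not the actual Utils implementation.

    import java.io.IOException
    import scala.util.control.NonFatal

    object TryOrIOExceptionSketch {
      // Assumed behaviour of the wrapper used around writeObject above.
      def tryOrIOException(block: => Unit): Unit = {
        try {
          block
        } catch {
          case e: IOException => throw e                  // already the declared exception type
          case NonFatal(e) => throw new IOException(e)    // wrap anything else as IOException
        }
      }
    }
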
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
index f3d30f6c9b32f..95b2dd954e9f4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
@@ -22,6 +22,7 @@ import java.io.{IOException, ObjectOutputStream}
import scala.reflect.ClassTag
import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
+import org.apache.spark.util.Utils
private[spark] class ZippedPartitionsPartition(
idx: Int,
@@ -34,7 +35,7 @@ private[spark] class ZippedPartitionsPartition(
def partitions = partitionValues
@throws(classOf[IOException])
- private def writeObject(oos: ObjectOutputStream) {
+ private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
// Update the reference to parent split at the time of task serialization
partitionValues = rdds.map(rdd => rdd.partitions(idx))
oos.defaultWriteObject()
@@ -76,7 +77,7 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag](
private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag](
sc: SparkContext,
- f: (Iterator[A], Iterator[B]) => Iterator[V],
+ var f: (Iterator[A], Iterator[B]) => Iterator[V],
var rdd1: RDD[A],
var rdd2: RDD[B],
preservesPartitioning: Boolean = false)
@@ -91,13 +92,14 @@ private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag]
super.clearDependencies()
rdd1 = null
rdd2 = null
+ f = null
}
}
private[spark] class ZippedPartitionsRDD3
[A: ClassTag, B: ClassTag, C: ClassTag, V: ClassTag](
sc: SparkContext,
- f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V],
+ var f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V],
var rdd1: RDD[A],
var rdd2: RDD[B],
var rdd3: RDD[C],
@@ -116,13 +118,14 @@ private[spark] class ZippedPartitionsRDD3
rdd1 = null
rdd2 = null
rdd3 = null
+ f = null
}
}
private[spark] class ZippedPartitionsRDD4
[A: ClassTag, B: ClassTag, C: ClassTag, D:ClassTag, V: ClassTag](
sc: SparkContext,
- f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
+ var f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
var rdd1: RDD[A],
var rdd2: RDD[B],
var rdd3: RDD[C],
@@ -144,5 +147,6 @@ private[spark] class ZippedPartitionsRDD4
rdd2 = null
rdd3 = null
rdd4 = null
+ f = null
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
index e2c301603b4a5..8c43a559409f2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
@@ -39,21 +39,24 @@ class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long)
private[spark]
class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, Long)](prev) {
- override def getPartitions: Array[Partition] = {
+ /** The start index of each partition. */
+ @transient private val startIndices: Array[Long] = {
val n = prev.partitions.size
- val startIndices: Array[Long] =
- if (n == 0) {
- Array[Long]()
- } else if (n == 1) {
- Array(0L)
- } else {
- prev.context.runJob(
- prev,
- Utils.getIteratorSize _,
- 0 until n - 1, // do not need to count the last partition
- false
- ).scanLeft(0L)(_ + _)
- }
+ if (n == 0) {
+ Array[Long]()
+ } else if (n == 1) {
+ Array(0L)
+ } else {
+ prev.context.runJob(
+ prev,
+ Utils.getIteratorSize _,
+ 0 until n - 1, // do not need to count the last partition
+ allowLocal = false
+ ).scanLeft(0L)(_ + _)
+ }
+ }
+
+ override def getPartitions: Array[Partition] = {
firstParent[T].partitions.map(x => new ZippedWithIndexRDDPartition(x, startIndices(x.index)))
}
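
A worked illustration of the start-index computation above, with made-up partition sizes: counting every partition except the last and scanning left gives the global offset at which each partition's zipped indices begin.

    object StartIndicesSketch {
      def main(args: Array[String]): Unit = {
        // Sizes of partitions 0 and 1; the last partition's size is never needed.
        val countsExceptLast = Seq(3L, 5L)
        val startIndices = countsExceptLast.scanLeft(0L)(_ + _)   // 0, 3, 8
        println(startIndices)   // partition 0 starts at 0, partition 1 at 3, partition 2 at 8
      }
    }
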
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index f81fa6d8089fc..1cfe98673773a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
import java.io.NotSerializableException
import java.util.Properties
+import java.util.concurrent.{TimeUnit, Executors}
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack}
@@ -28,8 +29,6 @@ import scala.language.postfixOps
import scala.reflect.ClassTag
import scala.util.control.NonFatal
-import akka.actor._
-import akka.actor.SupervisorStrategy.Stop
import akka.pattern.ask
import akka.util.Timeout
@@ -39,7 +38,7 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage._
-import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils}
+import org.apache.spark.util.{CallSite, EventLoop, SystemClock, Clock, Utils}
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
/**
@@ -67,8 +66,6 @@ class DAGScheduler(
clock: Clock = SystemClock)
extends Logging {
- import DAGScheduler._
-
def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
this(
sc,
@@ -112,38 +109,31 @@ class DAGScheduler(
// stray messages to detect.
private val failedEpoch = new HashMap[String, Long]
- private val dagSchedulerActorSupervisor =
- env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this)))
-
// A closure serializer that we reuse.
// This is only safe because DAGScheduler runs in a single thread.
private val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
- private[scheduler] var eventProcessActor: ActorRef = _
/** If enabled, we may run certain actions like take() and first() locally. */
private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false)
- private def initializeEventProcessActor() {
- // blocking the thread until supervisor is started, which ensures eventProcessActor is
- // not null before any job is submitted
- implicit val timeout = Timeout(30 seconds)
- val initEventActorReply =
- dagSchedulerActorSupervisor ? Props(new DAGSchedulerEventProcessActor(this))
- eventProcessActor = Await.result(initEventActorReply, timeout.duration).
- asInstanceOf[ActorRef]
- }
+ /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */
+ private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false)
- initializeEventProcessActor()
+ private val messageScheduler =
+ Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("dag-scheduler-message"))
+
+ private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
+ taskScheduler.setDAGScheduler(this)
// Called by TaskScheduler to report task's starting.
def taskStarted(task: Task[_], taskInfo: TaskInfo) {
- eventProcessActor ! BeginEvent(task, taskInfo)
+ eventProcessLoop.post(BeginEvent(task, taskInfo))
}
// Called to report that a task has completed and results are being fetched remotely.
def taskGettingResult(taskInfo: TaskInfo) {
- eventProcessActor ! GettingResultEvent(taskInfo)
+ eventProcessLoop.post(GettingResultEvent(taskInfo))
}
// Called by TaskScheduler to report task completions or failures.
@@ -154,7 +144,8 @@ class DAGScheduler(
accumUpdates: Map[Long, Any],
taskInfo: TaskInfo,
taskMetrics: TaskMetrics) {
- eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)
+ eventProcessLoop.post(
+ CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics))
}
/**
@@ -176,18 +167,18 @@ class DAGScheduler(
// Called by TaskScheduler when an executor fails.
def executorLost(execId: String) {
- eventProcessActor ! ExecutorLost(execId)
+ eventProcessLoop.post(ExecutorLost(execId))
}
// Called by TaskScheduler when a host is added
def executorAdded(execId: String, host: String) {
- eventProcessActor ! ExecutorAdded(execId, host)
+ eventProcessLoop.post(ExecutorAdded(execId, host))
}
// Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
// cancellation of the job itself.
def taskSetFailed(taskSet: TaskSet, reason: String) {
- eventProcessActor ! TaskSetFailed(taskSet, reason)
+ eventProcessLoop.post(TaskSetFailed(taskSet, reason))
}
private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
@@ -446,7 +437,6 @@ class DAGScheduler(
}
// data structures based on StageId
stageIdToStage -= stageId
-
logDebug("After removal of stage %d, remaining stages = %d"
.format(stageId, stageIdToStage.size))
}
@@ -493,8 +483,8 @@ class DAGScheduler(
assert(partitions.size > 0)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
- eventProcessActor ! JobSubmitted(
- jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
+ eventProcessLoop.post(JobSubmitted(
+ jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties))
waiter
}
@@ -534,8 +524,8 @@ class DAGScheduler(
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.partitions.size).toArray
val jobId = nextJobId.getAndIncrement()
- eventProcessActor ! JobSubmitted(
- jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)
+ eventProcessLoop.post(JobSubmitted(
+ jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties))
listener.awaitResult() // Will throw an exception if the job fails
}
@@ -544,19 +534,19 @@ class DAGScheduler(
*/
def cancelJob(jobId: Int) {
logInfo("Asked to cancel job " + jobId)
- eventProcessActor ! JobCancelled(jobId)
+ eventProcessLoop.post(JobCancelled(jobId))
}
def cancelJobGroup(groupId: String) {
logInfo("Asked to cancel job group " + groupId)
- eventProcessActor ! JobGroupCancelled(groupId)
+ eventProcessLoop.post(JobGroupCancelled(groupId))
}
/**
* Cancel all jobs that are running or waiting in the queue.
*/
def cancelAllJobs() {
- eventProcessActor ! AllJobsCancelled
+ eventProcessLoop.post(AllJobsCancelled)
}
private[scheduler] def doCancelAllJobs() {
@@ -572,7 +562,7 @@ class DAGScheduler(
* Cancel all jobs associated with a running or scheduled stage.
*/
def cancelStage(stageId: Int) {
- eventProcessActor ! StageCancelled(stageId)
+ eventProcessLoop.post(StageCancelled(stageId))
}
/**
@@ -632,8 +622,8 @@ class DAGScheduler(
try {
val rdd = job.finalStage.rdd
val split = rdd.partitions(job.partitions(0))
- val taskContext =
- new TaskContextImpl(job.finalStage.id, job.partitions(0), 0, true)
+ val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0,
+ attemptNumber = 0, runningLocally = true)
TaskContextHelper.setTaskContext(taskContext)
try {
val result = job.func(taskContext, rdd.iterator(split, taskContext))
@@ -658,7 +648,7 @@ class DAGScheduler(
// completion events or stage abort
stageIdToStage -= s.id
jobIdToStageIds -= job.jobId
- listenerBus.post(SparkListenerJobEnd(job.jobId, jobResult))
+ listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), jobResult))
}
}
@@ -707,7 +697,7 @@ class DAGScheduler(
stage.latestInfo.stageFailed(stageFailedMessage)
listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
}
- listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
+ listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), JobFailed(error)))
}
}
@@ -746,16 +736,20 @@ class DAGScheduler(
logInfo("Missing parents: " + getMissingParentStages(finalStage))
val shouldRunLocally =
localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1
+ val jobSubmissionTime = clock.getTime()
if (shouldRunLocally) {
// Compute very short actions like first() or take() with no parent stages locally.
- listenerBus.post(SparkListenerJobStart(job.jobId, Array[Int](), properties))
+ listenerBus.post(
+ SparkListenerJobStart(job.jobId, jobSubmissionTime, Seq.empty, properties))
runLocally(job)
} else {
jobIdToActiveJob(jobId) = job
activeJobs += job
finalStage.resultOfJob = Some(job)
- listenerBus.post(SparkListenerJobStart(job.jobId, jobIdToStageIds(jobId).toArray,
- properties))
+ val stageIds = jobIdToStageIds(jobId).toArray
+ val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
+ listenerBus.post(
+ SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties))
submitStage(finalStage)
}
}
@@ -862,26 +856,6 @@ class DAGScheduler(
}
if (tasks.size > 0) {
- // Preemptively serialize a task to make sure it can be serialized. We are catching this
- // exception here because it would be fairly hard to catch the non-serializable exception
- // down the road, where we have several different implementations for local scheduler and
- // cluster schedulers.
- //
- // We've already serialized RDDs and closures in taskBinary, but here we check for all other
- // objects such as Partition.
- try {
- closureSerializer.serialize(tasks.head)
- } catch {
- case e: NotSerializableException =>
- abortStage(stage, "Task not serializable: " + e.toString)
- runningStages -= stage
- return
- case NonFatal(e) => // Other exceptions, such as IllegalArgumentException from Kryo.
- abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}")
- runningStages -= stage
- return
- }
-
logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
stage.pendingTasks ++= tasks
logDebug("New pending tasks: " + stage.pendingTasks)
@@ -898,6 +872,34 @@ class DAGScheduler(
}
}
+ /** Merge updates from a task to our local accumulator values */
+ private def updateAccumulators(event: CompletionEvent): Unit = {
+ val task = event.task
+ val stage = stageIdToStage(task.stageId)
+ if (event.accumUpdates != null) {
+ try {
+ Accumulators.add(event.accumUpdates)
+ event.accumUpdates.foreach { case (id, partialValue) =>
+ val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]]
+ // To avoid UI cruft, ignore cases where value wasn't updated
+ if (acc.name.isDefined && partialValue != acc.zero) {
+ val name = acc.name.get
+ val stringPartialValue = Accumulators.stringifyPartialValue(partialValue)
+ val stringValue = Accumulators.stringifyValue(acc.value)
+ stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue)
+ event.taskInfo.accumulables +=
+ AccumulableInfo(id, name, Some(stringPartialValue), stringValue)
+ }
+ }
+ } catch {
+ // If we see an exception during accumulator update, just log the
+ // error and move on.
+ case e: Exception =>
+ logError(s"Failed to update accumulators for $task", e)
+ }
+ }
+ }
+
/**
* Responds to a task finishing. This is called inside the event loop so it assumes that it can
* modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
@@ -938,27 +940,6 @@ class DAGScheduler(
}
event.reason match {
case Success =>
- if (event.accumUpdates != null) {
- try {
- Accumulators.add(event.accumUpdates)
- event.accumUpdates.foreach { case (id, partialValue) =>
- val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]]
- // To avoid UI cruft, ignore cases where value wasn't updated
- if (acc.name.isDefined && partialValue != acc.zero) {
- val name = acc.name.get
- val stringPartialValue = Accumulators.stringifyPartialValue(partialValue)
- val stringValue = Accumulators.stringifyValue(acc.value)
- stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue)
- event.taskInfo.accumulables +=
- AccumulableInfo(id, name, Some(stringPartialValue), stringValue)
- }
- }
- } catch {
- // If we see an exception during accumulator update, just log the error and move on.
- case e: Exception =>
- logError(s"Failed to update accumulators for $task", e)
- }
- }
listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType,
event.reason, event.taskInfo, event.taskMetrics))
stage.pendingTasks -= task
@@ -967,13 +948,15 @@ class DAGScheduler(
stage.resultOfJob match {
case Some(job) =>
if (!job.finished(rt.outputId)) {
+ updateAccumulators(event)
job.finished(rt.outputId) = true
job.numFinished += 1
// If the whole job has finished, remove it
if (job.numFinished == job.numPartitions) {
markStageAsFinished(stage)
cleanupStateForJobAndIndependentStages(job)
- listenerBus.post(SparkListenerJobEnd(job.jobId, JobSucceeded))
+ listenerBus.post(
+ SparkListenerJobEnd(job.jobId, clock.getTime(), JobSucceeded))
}
// taskSucceeded runs some user code that might throw an exception. Make sure
@@ -991,6 +974,7 @@ class DAGScheduler(
}
case smt: ShuffleMapTask =>
+ updateAccumulators(event)
val status = event.result.asInstanceOf[MapStatus]
val execId = status.location.executorId
logDebug("ShuffleMapTask finished on " + execId)
@@ -1050,7 +1034,7 @@ class DAGScheduler(
logInfo("Resubmitted " + task + ", so marking it as still running")
stage.pendingTasks += task
- case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) =>
val failedStage = stageIdToStage(task.stageId)
val mapStage = shuffleToMapStage(shuffleId)
@@ -1060,24 +1044,24 @@ class DAGScheduler(
if (runningStages.contains(failedStage)) {
logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
s"due to a fetch failure from $mapStage (${mapStage.name})")
- markStageAsFinished(failedStage, Some("Fetch failure"))
+ markStageAsFinished(failedStage, Some(failureMessage))
runningStages -= failedStage
}
- if (failedStages.isEmpty && eventProcessActor != null) {
+ if (disallowStageRetryForTest) {
+ abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
+ } else if (failedStages.isEmpty) {
// Don't schedule an event to resubmit failed stages if failed isn't empty, because
- // in that case the event will already have been scheduled. eventProcessActor may be
- // null during unit tests.
+ // in that case the event will already have been scheduled.
// TODO: Cancel running tasks in the stage
- import env.actorSystem.dispatcher
logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
s"$failedStage (${failedStage.name}) due to fetch failure")
- env.actorSystem.scheduler.scheduleOnce(
- RESUBMIT_TIMEOUT, eventProcessActor, ResubmitFailedStages)
+ messageScheduler.schedule(new Runnable {
+ override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+ }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
}
failedStages += failedStage
failedStages += mapStage
-
// Mark the map whose fetch failed as broken in the map stage
if (mapId != -1) {
mapStage.removeOutputLoc(mapId, bmAddress)
@@ -1086,10 +1070,10 @@ class DAGScheduler(
// TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
- handleExecutorLost(bmAddress.executorId, Some(task.epoch))
+ handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
}
- case ExceptionFailure(className, description, stackTrace, metrics) =>
+ case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) =>
// Do nothing here, left up to the TaskScheduler to decide how to handle user failures
case TaskResultLost =>
@@ -1106,25 +1090,35 @@ class DAGScheduler(
* Responds to an executor being lost. This is called inside the event loop, so it assumes it can
* modify the scheduler's internal state. Use executorLost() to post a loss event from outside.
*
+ * We will also assume that we've lost all shuffle blocks associated with the executor if the
+ * executor serves its own blocks (i.e., we're not using external shuffle) OR a FetchFailed
+ * occurred, in which case we presume all shuffle data related to this executor has been lost.
+ *
* Optionally the epoch during which the failure was caught can be passed to avoid allowing
* stray fetch failures from possibly retriggering the detection of a node as lost.
*/
- private[scheduler] def handleExecutorLost(execId: String, maybeEpoch: Option[Long] = None) {
+ private[scheduler] def handleExecutorLost(
+ execId: String,
+ fetchFailed: Boolean,
+ maybeEpoch: Option[Long] = None) {
val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch)
if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) {
failedEpoch(execId) = currentEpoch
logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch))
blockManagerMaster.removeExecutor(execId)
- // TODO: This will be really slow if we keep accumulating shuffle map stages
- for ((shuffleId, stage) <- shuffleToMapStage) {
- stage.removeOutputsOnExecutor(execId)
- val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
- mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true)
- }
- if (shuffleToMapStage.isEmpty) {
- mapOutputTracker.incrementEpoch()
+
+ if (!env.blockManager.externalShuffleServiceEnabled || fetchFailed) {
+ // TODO: This will be really slow if we keep accumulating shuffle map stages
+ for ((shuffleId, stage) <- shuffleToMapStage) {
+ stage.removeOutputsOnExecutor(execId)
+ val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
+ mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true)
+ }
+ if (shuffleToMapStage.isEmpty) {
+ mapOutputTracker.incrementEpoch()
+ }
+ clearCacheLocs()
}
- clearCacheLocs()
} else {
logDebug("Additional executor lost message for " + execId +
"(epoch " + currentEpoch + ")")
@@ -1230,7 +1224,7 @@ class DAGScheduler(
if (ableToCancelStages) {
job.listener.jobFailed(error)
cleanupStateForJobAndIndependentStages(job)
- listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
+ listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), JobFailed(error)))
}
}
@@ -1322,46 +1316,21 @@ class DAGScheduler(
def stop() {
logInfo("Stopping DAGScheduler")
- dagSchedulerActorSupervisor ! PoisonPill
+ eventProcessLoop.stop()
taskScheduler.stop()
}
-}
-
-private[scheduler] class DAGSchedulerActorSupervisor(dagScheduler: DAGScheduler)
- extends Actor with Logging {
-
- override val supervisorStrategy =
- OneForOneStrategy() {
- case x: Exception =>
- logError("eventProcesserActor failed; shutting down SparkContext", x)
- try {
- dagScheduler.doCancelAllJobs()
- } catch {
- case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t)
- }
- dagScheduler.sc.stop()
- Stop
- }
- def receive = {
- case p: Props => sender ! context.actorOf(p)
- case _ => logWarning("received unknown message in DAGSchedulerActorSupervisor")
- }
+ // Start the event thread at the end of the constructor
+ eventProcessLoop.start()
}
-private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGScheduler)
- extends Actor with Logging {
-
- override def preStart() {
- // set DAGScheduler for taskScheduler to ensure eventProcessActor is always
- // valid when the messages arrive
- dagScheduler.taskScheduler.setDAGScheduler(dagScheduler)
- }
+private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler)
+ extends EventLoop[DAGSchedulerEvent]("dag-scheduler-event-loop") with Logging {
/**
* The main event loop of the DAG scheduler.
*/
- def receive = {
+ override def onReceive(event: DAGSchedulerEvent): Unit = event match {
case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite,
listener, properties)
@@ -1382,7 +1351,7 @@ private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGSchedule
dagScheduler.handleExecutorAdded(execId, host)
case ExecutorLost(execId) =>
- dagScheduler.handleExecutorLost(execId)
+ dagScheduler.handleExecutorLost(execId, fetchFailed = false)
case BeginEvent(task, taskInfo) =>
dagScheduler.handleBeginEvent(task, taskInfo)
@@ -1400,7 +1369,17 @@ private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGSchedule
dagScheduler.resubmitFailedStages()
}
- override def postStop() {
+ override def onError(e: Throwable): Unit = {
+ logError("DAGSchedulerEventProcessLoop failed; shutting down SparkContext", e)
+ try {
+ dagScheduler.doCancelAllJobs()
+ } catch {
+ case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t)
+ }
+ dagScheduler.sc.stop()
+ }
+
+ override def onStop() {
// Cancel any active jobs in postStop hook
dagScheduler.cleanUpAfterSchedulerStop()
}
@@ -1410,9 +1389,5 @@ private[spark] object DAGScheduler {
// The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
// this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
// as more failure events come in
- val RESUBMIT_TIMEOUT = 200.milliseconds
-
- // The time, in millis, to wake up between polls of the completion queue in order to potentially
- // resubmit failed stages
- val POLL_TIMEOUT = 10L
+ val RESUBMIT_TIMEOUT = 200
}
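
For orientation, a minimal, self-contained sketch of the single-threaded event-loop pattern that replaces the Akka actors in this file. This is not the actual org.apache.spark.util.EventLoop, only an illustration of the post()/onReceive()/onError() contract it is assumed to provide.

    import java.util.concurrent.LinkedBlockingQueue

    abstract class SimpleEventLoop[E](name: String) {
      private val eventQueue = new LinkedBlockingQueue[E]()
      @volatile private var stopped = false

      private val eventThread = new Thread(name) {
        setDaemon(true)
        override def run(): Unit = {
          while (!stopped) {
            try {
              val event = eventQueue.take()               // block until something is posted
              try {
                onReceive(event)
              } catch {
                case t: Throwable => onError(t)           // a failing handler must not kill the loop
              }
            } catch {
              case _: InterruptedException =>             // stop() interrupts the blocking take()
            }
          }
        }
      }

      def start(): Unit = eventThread.start()
      def stop(): Unit = { stopped = true; eventThread.interrupt() }
      def post(event: E): Unit = eventQueue.put(event)    // enqueue without blocking the caller

      protected def onReceive(event: E): Unit
      protected def onError(e: Throwable): Unit
    }
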
diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
index 100c9ba9b7809..30075c172bdb1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
@@ -17,20 +17,23 @@
package org.apache.spark.scheduler
+import java.io._
+import java.net.URI
+
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
+import com.google.common.base.Charsets
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.{FileSystem, FSDataOutputStream, Path}
import org.apache.hadoop.fs.permission.FsPermission
import org.json4s.JsonAST.JValue
import org.json4s.jackson.JsonMethods._
-import org.apache.spark.{Logging, SparkConf, SparkContext}
+import org.apache.spark.{Logging, SparkConf, SPARK_VERSION}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.io.CompressionCodec
-import org.apache.spark.SPARK_VERSION
-import org.apache.spark.util.{FileLogger, JsonProtocol, Utils}
+import org.apache.spark.util.{JsonProtocol, Utils}
/**
* A SparkListener that logs events to persistent storage.
@@ -58,36 +61,78 @@ private[spark] class EventLoggingListener(
private val shouldOverwrite = sparkConf.getBoolean("spark.eventLog.overwrite", false)
private val testing = sparkConf.getBoolean("spark.eventLog.testing", false)
private val outputBufferSize = sparkConf.getInt("spark.eventLog.buffer.kb", 100) * 1024
- val logDir = EventLoggingListener.getLogDirPath(logBaseDir, appId)
- val logDirName: String = logDir.split("/").last
- protected val logger = new FileLogger(logDir, sparkConf, hadoopConf, outputBufferSize,
- shouldCompress, shouldOverwrite, Some(LOG_FILE_PERMISSIONS))
+ private val fileSystem = Utils.getHadoopFileSystem(new URI(logBaseDir), hadoopConf)
+
+ // Only defined if the file system scheme is not local
+ private var hadoopDataStream: Option[FSDataOutputStream] = None
+
+ // The Hadoop APIs have changed over time, so we use reflection to figure out
+ // the correct method to use to flush a Hadoop data stream. See SPARK-1518
+ // for details.
+ private val hadoopFlushMethod = {
+ val cls = classOf[FSDataOutputStream]
+ scala.util.Try(cls.getMethod("hflush")).getOrElse(cls.getMethod("sync"))
+ }
+
+ private var writer: Option[PrintWriter] = None
// For testing. Keep track of all JSON serialized events that have been logged.
private[scheduler] val loggedEvents = new ArrayBuffer[JValue]
+ // Visible for tests only.
+ private[scheduler] val logPath = getLogPath(logBaseDir, appId)
+
/**
- * Begin logging events.
- * If compression is used, log a file that indicates which compression library is used.
+ * Creates the log file in the configured log directory.
*/
def start() {
- logger.start()
- logInfo("Logging events to %s".format(logDir))
- if (shouldCompress) {
- val codec =
- sparkConf.get("spark.io.compression.codec", CompressionCodec.DEFAULT_COMPRESSION_CODEC)
- logger.newFile(COMPRESSION_CODEC_PREFIX + codec)
+ if (!fileSystem.isDirectory(new Path(logBaseDir))) {
+ throw new IllegalArgumentException(s"Log directory $logBaseDir does not exist.")
}
- logger.newFile(SPARK_VERSION_PREFIX + SPARK_VERSION)
- logger.newFile(LOG_PREFIX + logger.fileIndex)
+
+ val workingPath = logPath + IN_PROGRESS
+ val uri = new URI(workingPath)
+ val path = new Path(workingPath)
+ val defaultFs = FileSystem.getDefaultUri(hadoopConf).getScheme
+ val isDefaultLocal = defaultFs == null || defaultFs == "file"
+
+ if (shouldOverwrite && fileSystem.exists(path)) {
+ logWarning(s"Event log $path already exists. Overwriting...")
+ fileSystem.delete(path, true)
+ }
+
+ /* The Hadoop LocalFileSystem (r1.0.4) has known issues with syncing (HADOOP-7844).
+ * Therefore, for local files, use FileOutputStream instead. */
+ val dstream =
+ if ((isDefaultLocal && uri.getScheme == null) || uri.getScheme == "file") {
+ new FileOutputStream(uri.getPath)
+ } else {
+ hadoopDataStream = Some(fileSystem.create(path))
+ hadoopDataStream.get
+ }
+
+ val compressionCodec =
+ if (shouldCompress) {
+ Some(CompressionCodec.createCodec(sparkConf))
+ } else {
+ None
+ }
+
+ fileSystem.setPermission(path, LOG_FILE_PERMISSIONS)
+ val logStream = initEventLog(new BufferedOutputStream(dstream, outputBufferSize),
+ compressionCodec)
+ writer = Some(new PrintWriter(logStream))
+
+ logInfo("Logging events to %s".format(logPath))
}
/** Log the event as JSON. */
private def logEvent(event: SparkListenerEvent, flushLogger: Boolean = false) {
val eventJson = JsonProtocol.sparkEventToJson(event)
- logger.logLine(compact(render(eventJson)))
+ writer.foreach(_.println(compact(render(eventJson))))
if (flushLogger) {
- logger.flush()
+ writer.foreach(_.flush())
+ hadoopDataStream.foreach(hadoopFlushMethod.invoke(_))
}
if (testing) {
loggedEvents += eventJson
@@ -123,130 +168,168 @@ private[spark] class EventLoggingListener(
logEvent(event, flushLogger = true)
override def onApplicationEnd(event: SparkListenerApplicationEnd) =
logEvent(event, flushLogger = true)
+ override def onExecutorAdded(event: SparkListenerExecutorAdded) =
+ logEvent(event, flushLogger = true)
+ override def onExecutorRemoved(event: SparkListenerExecutorRemoved) =
+ logEvent(event, flushLogger = true)
+
// No-op because logging every update would be overkill
override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate) { }
/**
- * Stop logging events.
- * In addition, create an empty special file to indicate application completion.
+ * Stop logging events. The event log file will be renamed so that it loses the
+ * ".inprogress" suffix.
*/
def stop() = {
- logger.newFile(APPLICATION_COMPLETE)
- logger.stop()
+ writer.foreach(_.close())
+
+ val target = new Path(logPath)
+ if (fileSystem.exists(target)) {
+ if (shouldOverwrite) {
+ logWarning(s"Event log $target already exists. Overwriting...")
+ fileSystem.delete(target, true)
+ } else {
+ throw new IOException("Target log file already exists (%s)".format(logPath))
+ }
+ }
+ fileSystem.rename(new Path(logPath + IN_PROGRESS), target)
}
+
}
private[spark] object EventLoggingListener extends Logging {
+ // Suffix applied to the names of files still being written by applications.
+ val IN_PROGRESS = ".inprogress"
val DEFAULT_LOG_DIR = "/tmp/spark-events"
- val LOG_PREFIX = "EVENT_LOG_"
- val SPARK_VERSION_PREFIX = "SPARK_VERSION_"
- val COMPRESSION_CODEC_PREFIX = "COMPRESSION_CODEC_"
- val APPLICATION_COMPLETE = "APPLICATION_COMPLETE"
- val LOG_FILE_PERMISSIONS = FsPermission.createImmutable(Integer.parseInt("770", 8).toShort)
- // A cache for compression codecs to avoid creating the same codec many times
- private val codecMap = new mutable.HashMap[String, CompressionCodec]
+ private val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort)
- def isEventLogFile(fileName: String): Boolean = {
- fileName.startsWith(LOG_PREFIX)
- }
+ // Marker for the end of header data in a log file. After this marker, log data, potentially
+ // compressed, will be found.
+ private val HEADER_END_MARKER = "=== LOG_HEADER_END ==="
- def isSparkVersionFile(fileName: String): Boolean = {
- fileName.startsWith(SPARK_VERSION_PREFIX)
- }
+ // To avoid corrupted files causing the heap to fill up. Value is arbitrary.
+ private val MAX_HEADER_LINE_LENGTH = 4096
- def isCompressionCodecFile(fileName: String): Boolean = {
- fileName.startsWith(COMPRESSION_CODEC_PREFIX)
- }
+ // A cache for compression codecs to avoid creating the same codec many times
+ private val codecMap = new mutable.HashMap[String, CompressionCodec]
- def isApplicationCompleteFile(fileName: String): Boolean = {
- fileName == APPLICATION_COMPLETE
- }
+ /**
+ * Write metadata about the event log to the given stream.
+ *
+ * The header is a serialized version of a map, except it does not use Java serialization to
+ * avoid incompatibilities between different JDKs. It writes one map entry per line, in
+ * "key=value" format.
+ *
+ * The very last entry in the header is the `HEADER_END_MARKER`, so that the parsing code
+ * knows when to stop.
+ *
+ * The format needs to be kept in sync with the openEventLog() method below. Also, it cannot
+ * change in new Spark versions without some other way of detecting the change (like some
+ * metadata encoded in the file name).
+ *
+ * @param logStream Raw output stream to the event log file.
+ * @param compressionCodec Optional compression codec to use.
+ * @return A stream to which event log data should be written. This may be a wrapper around
+ * the original stream (for example, when compression is enabled).
+ */
+ def initEventLog(
+ logStream: OutputStream,
+ compressionCodec: Option[CompressionCodec]): OutputStream = {
+ val meta = mutable.HashMap(("version" -> SPARK_VERSION))
+ compressionCodec.foreach { codec =>
+ meta += ("compressionCodec" -> codec.getClass().getName())
+ }
- def parseSparkVersion(fileName: String): String = {
- if (isSparkVersionFile(fileName)) {
- fileName.replaceAll(SPARK_VERSION_PREFIX, "")
- } else ""
- }
+ def write(entry: String) = {
+ val bytes = entry.getBytes(Charsets.UTF_8)
+ if (bytes.length > MAX_HEADER_LINE_LENGTH) {
+ throw new IOException(s"Header entry too long: ${entry}")
+ }
+ logStream.write(bytes, 0, bytes.length)
+ }
- def parseCompressionCodec(fileName: String): String = {
- if (isCompressionCodecFile(fileName)) {
- fileName.replaceAll(COMPRESSION_CODEC_PREFIX, "")
- } else ""
+ meta.foreach { case (k, v) => write(s"$k=$v\n") }
+ write(s"$HEADER_END_MARKER\n")
+ compressionCodec.map(_.compressedOutputStream(logStream)).getOrElse(logStream)
}
/**
- * Return a file-system-safe path to the log directory for the given application.
+ * Return a file-system-safe path to the log file for the given application.
*
- * @param logBaseDir A base directory for the path to the log directory for given application.
+ * @param logBaseDir Directory where the log file will be written.
* @param appId A unique app ID.
* @return A path which consists of file-system-safe characters.
*/
- def getLogDirPath(logBaseDir: String, appId: String): String = {
+ def getLogPath(logBaseDir: String, appId: String): String = {
val name = appId.replaceAll("[ :/]", "-").replaceAll("[${}'\"]", "_").toLowerCase
Utils.resolveURI(logBaseDir) + "/" + name.stripSuffix("/")
}
/**
- * Parse the event logging information associated with the logs in the given directory.
+ * Opens an event log file and returns an input stream to the event data.
*
- * Specifically, this looks for event log files, the Spark version file, the compression
- * codec file (if event logs are compressed), and the application completion file (if the
- * application has run to completion).
+ * @return 2-tuple (event input stream, Spark version of event data)
*/
- def parseLoggingInfo(logDir: Path, fileSystem: FileSystem): EventLoggingInfo = {
+ def openEventLog(log: Path, fs: FileSystem): (InputStream, String) = {
+ // It's not clear whether FileSystem.open() throws FileNotFoundException or just plain
+ // IOException when a file does not exist, so try our best to throw a proper exception.
+ if (!fs.exists(log)) {
+ throw new FileNotFoundException(s"File $log does not exist.")
+ }
+
+ val in = new BufferedInputStream(fs.open(log))
+ // Read a single line from the input stream without buffering.
+ // We cannot use BufferedReader because we must avoid reading
+ // beyond the end of the header, after which the content of the
+ // file may be compressed.
+ def readLine(): String = {
+ val bytes = new ByteArrayOutputStream()
+ var next = in.read()
+ var count = 0
+ while (next != '\n') {
+ if (next == -1) {
+ throw new IOException("Unexpected end of file.")
+ }
+ bytes.write(next)
+ count = count + 1
+ if (count > MAX_HEADER_LINE_LENGTH) {
+ throw new IOException("Maximum header line length exceeded.")
+ }
+ next = in.read()
+ }
+ new String(bytes.toByteArray(), Charsets.UTF_8)
+ }
+
+ // Parse the header metadata in the form of k=v pairs
+ // This assumes that every line before the header end marker follows this format
try {
- val fileStatuses = fileSystem.listStatus(logDir)
- val filePaths =
- if (fileStatuses != null) {
- fileStatuses.filter(!_.isDir).map(_.getPath).toSeq
- } else {
- Seq[Path]()
+ val meta = new mutable.HashMap[String, String]()
+ var foundEndMarker = false
+ while (!foundEndMarker) {
+ readLine() match {
+ case HEADER_END_MARKER =>
+ foundEndMarker = true
+ case entry =>
+ val prop = entry.split("=", 2)
+ if (prop.length != 2) {
+ throw new IllegalArgumentException("Invalid metadata in log file.")
+ }
+ meta += (prop(0) -> prop(1))
}
- if (filePaths.isEmpty) {
- logWarning("No files found in logging directory %s".format(logDir))
}
- EventLoggingInfo(
- logPaths = filePaths.filter { path => isEventLogFile(path.getName) },
- sparkVersion = filePaths
- .find { path => isSparkVersionFile(path.getName) }
- .map { path => parseSparkVersion(path.getName) }
- .getOrElse(""),
- compressionCodec = filePaths
- .find { path => isCompressionCodecFile(path.getName) }
- .map { path =>
- val codec = EventLoggingListener.parseCompressionCodec(path.getName)
- val conf = new SparkConf
- conf.set("spark.io.compression.codec", codec)
- codecMap.getOrElseUpdate(codec, CompressionCodec.createCodec(conf))
- },
- applicationComplete = filePaths.exists { path => isApplicationCompleteFile(path.getName) }
- )
+
+ val sparkVersion = meta.get("version").getOrElse(
+ throw new IllegalArgumentException("Missing Spark version in log metadata."))
+ val codec = meta.get("compressionCodec").map { codecName =>
+ codecMap.getOrElseUpdate(codecName, CompressionCodec.createCodec(new SparkConf, codecName))
+ }
+ (codec.map(_.compressedInputStream(in)).getOrElse(in), sparkVersion)
} catch {
case e: Exception =>
- logError("Exception in parsing logging info from directory %s".format(logDir), e)
- EventLoggingInfo.empty
+ in.close()
+ throw e
}
}
- /**
- * Parse the event logging information associated with the logs in the given directory.
- */
- def parseLoggingInfo(logDir: String, fileSystem: FileSystem): EventLoggingInfo = {
- parseLoggingInfo(new Path(logDir), fileSystem)
- }
-}
-
-
-/**
- * Information needed to process the event logs associated with an application.
- */
-private[spark] case class EventLoggingInfo(
- logPaths: Seq[Path],
- sparkVersion: String,
- compressionCodec: Option[CompressionCodec],
- applicationComplete: Boolean = false)
-
-private[spark] object EventLoggingInfo {
- def empty = EventLoggingInfo(Seq[Path](), "", None, applicationComplete = false)
}
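
A hedged sketch of the on-disk layout produced by initEventLog above: one key=value metadata line per entry, the end-of-header marker, then (possibly compressed) JSON events, one per line. The metadata values shown are illustrative.

    import java.io.ByteArrayOutputStream
    import com.google.common.base.Charsets

    object EventLogHeaderSketch {
      def main(args: Array[String]): Unit = {
        val out = new ByteArrayOutputStream()
        def write(entry: String): Unit = out.write(entry.getBytes(Charsets.UTF_8))

        // Header: plain "key=value" lines, terminated by the end-of-header marker.
        write("version=1.3.0\n")                                  // illustrative Spark version
        write("compressionCodec=org.apache.spark.io.LZFCompressionCodec\n")
        write("=== LOG_HEADER_END ===\n")
        // Event data would follow: one JSON-serialized SparkListenerEvent per line,
        // wrapped in the codec's compressed output stream when compression is enabled.

        print(new String(out.toByteArray, Charsets.UTF_8))
      }
    }
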
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index 54904bffdf10b..3bb54855bae44 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -158,6 +158,11 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
" INPUT_BYTES=" + metrics.bytesRead
case None => ""
}
+ val outputMetrics = taskMetrics.outputMetrics match {
+ case Some(metrics) =>
+ " OUTPUT_BYTES=" + metrics.bytesWritten
+ case None => ""
+ }
val shuffleReadMetrics = taskMetrics.shuffleReadMetrics match {
case Some(metrics) =>
" BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
@@ -173,7 +178,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
" SHUFFLE_WRITE_TIME=" + metrics.shuffleWriteTime
case None => ""
}
- stageLogInfo(stageId, status + info + executorRunTime + gcTime + inputMetrics +
+ stageLogInfo(stageId, status + info + executorRunTime + gcTime + inputMetrics + outputMetrics +
shuffleReadMetrics + writeMetrics)
}
@@ -215,7 +220,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
" STAGE_ID=" + taskEnd.stageId
stageLogInfo(taskEnd.stageId, taskStatus)
- case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId, message) =>
taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
taskEnd.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
mapId + " REDUCE_ID=" + reduceId
diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
index 36a6e6338faa6..be23056e7d423 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
@@ -17,10 +17,9 @@
package org.apache.spark.scheduler
-import java.util.concurrent.{LinkedBlockingQueue, Semaphore}
+import java.util.concurrent.atomic.AtomicBoolean
-import org.apache.spark.Logging
-import org.apache.spark.util.Utils
+import org.apache.spark.util.AsynchronousListenerBus
/**
* Asynchronously passes SparkListenerEvents to registered SparkListeners.
@@ -29,113 +28,19 @@ import org.apache.spark.util.Utils
* has started will events be actually propagated to all attached listeners. This listener bus
* is stopped when it receives a SparkListenerShutdown event, which is posted using stop().
*/
-private[spark] class LiveListenerBus extends SparkListenerBus with Logging {
-
- /* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than
- * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */
- private val EVENT_QUEUE_CAPACITY = 10000
- private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](EVENT_QUEUE_CAPACITY)
- private var queueFullErrorMessageLogged = false
- private var started = false
-
- // A counter that represents the number of events produced and consumed in the queue
- private val eventLock = new Semaphore(0)
-
- private val listenerThread = new Thread("SparkListenerBus") {
- setDaemon(true)
- override def run(): Unit = Utils.logUncaughtExceptions {
- while (true) {
- eventLock.acquire()
- // Atomically remove and process this event
- LiveListenerBus.this.synchronized {
- val event = eventQueue.poll
- if (event == SparkListenerShutdown) {
- // Get out of the while loop and shutdown the daemon thread
- return
- }
- Option(event).foreach(postToAll)
- }
- }
- }
- }
-
- /**
- * Start sending events to attached listeners.
- *
- * This first sends out all buffered events posted before this listener bus has started, then
- * listens for any additional events asynchronously while the listener bus is still running.
- * This should only be called once.
- */
- def start() {
- if (started) {
- throw new IllegalStateException("Listener bus already started!")
+private[spark] class LiveListenerBus
+ extends AsynchronousListenerBus[SparkListener, SparkListenerEvent]("SparkListenerBus")
+ with SparkListenerBus {
+
+ private val logDroppedEvent = new AtomicBoolean(false)
+
+ override def onDropEvent(event: SparkListenerEvent): Unit = {
+ if (logDroppedEvent.compareAndSet(false, true)) {
+ // Only log the following message once, to avoid filling the logs with duplicates.
+ logError("Dropping SparkListenerEvent because no remaining room in event queue. " +
+ "This likely means one of the SparkListeners is too slow and cannot keep up with " +
+ "the rate at which tasks are being started by the scheduler.")
}
- listenerThread.start()
- started = true
}
- def post(event: SparkListenerEvent) {
- val eventAdded = eventQueue.offer(event)
- if (eventAdded) {
- eventLock.release()
- } else {
- logQueueFullErrorMessage()
- }
- }
-
- /**
- * For testing only. Wait until there are no more events in the queue, or until the specified
- * time has elapsed. Return true if the queue has emptied and false is the specified time
- * elapsed before the queue emptied.
- */
- def waitUntilEmpty(timeoutMillis: Int): Boolean = {
- val finishTime = System.currentTimeMillis + timeoutMillis
- while (!queueIsEmpty) {
- if (System.currentTimeMillis > finishTime) {
- return false
- }
- /* Sleep rather than using wait/notify, because this is used only for testing and
- * wait/notify add overhead in the general case. */
- Thread.sleep(10)
- }
- true
- }
-
- /**
- * For testing only. Return whether the listener daemon thread is still alive.
- */
- def listenerThreadIsAlive: Boolean = synchronized { listenerThread.isAlive }
-
- /**
- * Return whether the event queue is empty.
- *
- * The use of synchronized here guarantees that all events that once belonged to this queue
- * have already been processed by all attached listeners, if this returns true.
- */
- def queueIsEmpty: Boolean = synchronized { eventQueue.isEmpty }
-
- /**
- * Log an error message to indicate that the event queue is full. Do this only once.
- */
- private def logQueueFullErrorMessage(): Unit = {
- if (!queueFullErrorMessageLogged) {
- if (listenerThread.isAlive) {
- logError("Dropping SparkListenerEvent because no remaining room in event queue. " +
- "This likely means one of the SparkListeners is too slow and cannot keep up with" +
- "the rate at which tasks are being started by the scheduler.")
- } else {
- logError("SparkListenerBus thread is dead! This means SparkListenerEvents have not" +
- "been (and will no longer be) propagated to listeners for some time.")
- }
- queueFullErrorMessageLogged = true
- }
- }
-
- def stop() {
- if (!started) {
- throw new IllegalStateException("Attempted to stop a listener bus that has not yet started!")
- }
- post(SparkListenerShutdown)
- listenerThread.join()
- }
}
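
A small sketch of the log-only-once idiom used in onDropEvent above: compareAndSet(false, true) succeeds for exactly one caller, so the warning is emitted a single time however many events are dropped. The message text is illustrative.

    import java.util.concurrent.atomic.AtomicBoolean

    object WarnOnceSketch {
      private val warned = new AtomicBoolean(false)

      def warnOnce(message: String): Unit = {
        if (warned.compareAndSet(false, true)) {      // only the first caller wins the CAS
          System.err.println(message)
        }
      }

      def main(args: Array[String]): Unit = {
        warnOnce("Dropping event: queue is full")     // printed
        warnOnce("Dropping event: queue is full")     // silently skipped
      }
    }
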
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index e25096ea92d70..1efce124c0a6b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -19,7 +19,10 @@ package org.apache.spark.scheduler
import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import org.roaringbitmap.RoaringBitmap
+
import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.Utils
/**
* Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the
@@ -29,7 +32,12 @@ private[spark] sealed trait MapStatus {
/** Location where this task was run. */
def location: BlockManagerId
- /** Estimated size for the reduce block, in bytes. */
+ /**
+ * Estimated size for the reduce block, in bytes.
+ *
+ * If a block is non-empty, then this method MUST return a non-zero size. This invariant is
+ * necessary for correctness, since block fetchers are allowed to skip zero-size blocks.
+ */
def getSizeForBlock(reduceId: Int): Long
}
@@ -38,7 +46,7 @@ private[spark] object MapStatus {
def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = {
if (uncompressedSizes.length > 2000) {
- new HighlyCompressedMapStatus(loc, uncompressedSizes)
+ HighlyCompressedMapStatus(loc, uncompressedSizes)
} else {
new CompressedMapStatus(loc, uncompressedSizes)
}
@@ -98,13 +106,13 @@ private[spark] class CompressedMapStatus(
MapStatus.decompressSize(compressedSizes(reduceId))
}
- override def writeExternal(out: ObjectOutput): Unit = {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
loc.writeExternal(out)
out.writeInt(compressedSizes.length)
out.write(compressedSizes)
}
- override def readExternal(in: ObjectInput): Unit = {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
loc = BlockManagerId(in)
val len = in.readInt()
compressedSizes = new Array[Byte](len)
@@ -112,35 +120,80 @@ private[spark] class CompressedMapStatus(
}
}
-
/**
- * A [[MapStatus]] implementation that only stores the average size of the blocks.
+ * A [[MapStatus]] implementation that only stores the average size of non-empty blocks,
+ * plus a bitmap for tracking which blocks are empty. During serialization, this bitmap
+ * is compressed.
*
- * @param loc location where the task is being executed.
- * @param avgSize average size of all the blocks
+ * @param loc location where the task is being executed
+ * @param numNonEmptyBlocks the number of non-empty blocks
+ * @param emptyBlocks a bitmap tracking which blocks are empty
+ * @param avgSize average size of the non-empty blocks
*/
-private[spark] class HighlyCompressedMapStatus(
+private[spark] class HighlyCompressedMapStatus private (
private[this] var loc: BlockManagerId,
+ private[this] var numNonEmptyBlocks: Int,
+ private[this] var emptyBlocks: RoaringBitmap,
private[this] var avgSize: Long)
extends MapStatus with Externalizable {
- def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) {
- this(loc, uncompressedSizes.sum / uncompressedSizes.length)
- }
+ // loc could be null when the default constructor is called during deserialization
+ require(loc == null || avgSize > 0 || numNonEmptyBlocks == 0,
+ "Average size can only be zero for map stages that produced no output")
- protected def this() = this(null, 0L) // For deserialization only
+ protected def this() = this(null, -1, null, -1) // For deserialization only
override def location: BlockManagerId = loc
- override def getSizeForBlock(reduceId: Int): Long = avgSize
+ override def getSizeForBlock(reduceId: Int): Long = {
+ if (emptyBlocks.contains(reduceId)) {
+ 0
+ } else {
+ avgSize
+ }
+ }
- override def writeExternal(out: ObjectOutput): Unit = {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
loc.writeExternal(out)
+ emptyBlocks.writeExternal(out)
out.writeLong(avgSize)
}
- override def readExternal(in: ObjectInput): Unit = {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
loc = BlockManagerId(in)
+ emptyBlocks = new RoaringBitmap()
+ emptyBlocks.readExternal(in)
avgSize = in.readLong()
}
}
+
+private[spark] object HighlyCompressedMapStatus {
+ def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = {
+ // We must keep track of which blocks are empty so that we don't report a zero-sized
+ // block as being non-empty (or vice-versa) when using the average block size.
+ var i = 0
+ var numNonEmptyBlocks: Int = 0
+ var totalSize: Long = 0
+ // From a compression standpoint, it shouldn't matter whether we track empty or non-empty
+ // blocks. From a performance standpoint, we benefit from tracking empty blocks because
+ // we expect that there will be far fewer of them, so we will perform fewer bitmap insertions.
+ val emptyBlocks = new RoaringBitmap()
+ val totalNumBlocks = uncompressedSizes.length
+ while (i < totalNumBlocks) {
+ var size = uncompressedSizes(i)
+ if (size > 0) {
+ numNonEmptyBlocks += 1
+ totalSize += size
+ } else {
+ emptyBlocks.add(i)
+ }
+ i += 1
+ }
+ val avgSize = if (numNonEmptyBlocks > 0) {
+ totalSize / numNonEmptyBlocks
+ } else {
+ 0
+ }
+ new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize)
+ }
+}
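
A worked illustration of the scheme implemented above, with made-up block sizes: empty blocks are recorded in a RoaringBitmap and every other block is reported at the average non-empty size, which preserves the non-zero-size invariant documented on getSizeForBlock.

    import org.roaringbitmap.RoaringBitmap

    object HighlyCompressedSketch {
      def main(args: Array[String]): Unit = {
        val sizes = Array[Long](0L, 100L, 0L, 300L)   // uncompressed per-reduce block sizes
        val emptyBlocks = new RoaringBitmap()
        sizes.zipWithIndex.foreach { case (size, i) => if (size == 0L) emptyBlocks.add(i) }

        val numNonEmpty = sizes.count(_ > 0L)                               // 2
        val avgSize = if (numNonEmpty > 0) sizes.sum / numNonEmpty else 0L  // (100 + 300) / 2 = 200

        def sizeForBlock(reduceId: Int): Long =
          if (emptyBlocks.contains(reduceId)) 0L else avgSize

        println((0 until sizes.length).map(sizeForBlock))   // Vector(0, 200, 0, 200)
      }
    }
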
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala
index f89724d4ea196..584f4e7789d1a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala
@@ -17,74 +17,45 @@
package org.apache.spark.scheduler
-import java.io.{BufferedInputStream, InputStream}
+import java.io.{InputStream, IOException}
import scala.io.Source
-import org.apache.hadoop.fs.{Path, FileSystem}
import org.json4s.jackson.JsonMethods._
import org.apache.spark.Logging
-import org.apache.spark.io.CompressionCodec
import org.apache.spark.util.JsonProtocol
/**
- * A SparkListenerBus that replays logged events from persisted storage.
- *
- * This assumes the given paths are valid log files, where each line can be deserialized into
- * exactly one SparkListenerEvent.
+ * A SparkListenerBus that can be used to replay events from serialized event data.
*/
-private[spark] class ReplayListenerBus(
- logPaths: Seq[Path],
- fileSystem: FileSystem,
- compressionCodec: Option[CompressionCodec])
- extends SparkListenerBus with Logging {
-
- private var replayed = false
-
- if (logPaths.length == 0) {
- logWarning("Log path provided contains no log files.")
- }
+private[spark] class ReplayListenerBus extends SparkListenerBus with Logging {
/**
- * Replay each event in the order maintained in the given logs.
- * This should only be called exactly once.
+ * Replay each event in the order maintained in the given stream. The stream is expected to
+ * contain one JSON-encoded SparkListenerEvent per line.
+ *
+ * This method can be called multiple times, but the listener behavior is undefined after any
+ * error is thrown by this method.
+ *
+ * @param logData Stream containing event log data.
+ * @param version Spark version that generated the events.
*/
- def replay() {
- assert(!replayed, "ReplayListenerBus cannot replay events more than once")
- logPaths.foreach { path =>
- // Keep track of input streams at all levels to close them later
- // This is necessary because an exception can occur in between stream initializations
- var fileStream: Option[InputStream] = None
- var bufferedStream: Option[InputStream] = None
- var compressStream: Option[InputStream] = None
- var currentLine = ""
- try {
- fileStream = Some(fileSystem.open(path))
- bufferedStream = Some(new BufferedInputStream(fileStream.get))
- compressStream = Some(wrapForCompression(bufferedStream.get))
-
- // Parse each line as an event and post the event to all attached listeners
- val lines = Source.fromInputStream(compressStream.get).getLines()
- lines.foreach { line =>
- currentLine = line
- postToAll(JsonProtocol.sparkEventFromJson(parse(line)))
- }
- } catch {
- case e: Exception =>
- logError("Exception in parsing Spark event log %s".format(path), e)
- logError("Malformed line: %s\n".format(currentLine))
- } finally {
- fileStream.foreach(_.close())
- bufferedStream.foreach(_.close())
- compressStream.foreach(_.close())
+ def replay(logData: InputStream, version: String) {
+ var currentLine: String = null
+ try {
+ val lines = Source.fromInputStream(logData).getLines()
+ lines.foreach { line =>
+ currentLine = line
+ postToAll(JsonProtocol.sparkEventFromJson(parse(line)))
}
+ } catch {
+ case ioe: IOException =>
+ throw ioe
+ case e: Exception =>
+ logError("Exception in parsing Spark event log.", e)
+ logError("Malformed line: %s\n".format(currentLine))
}
- replayed = true
}
- /** If a compression codec is specified, wrap the given stream in a compression stream. */
- private def wrapForCompression(stream: InputStream): InputStream = {
- compressionCodec.map(_.compressedInputStream(stream)).getOrElse(stream)
- }
}
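A hedged usage sketch of the new interface: the caller now owns the stream (opening, optional decompression, closing), while the bus only parses one JSON-encoded event per line. The file path and listener body are made up, and addListener is assumed to come from the ListenerBus trait the bus now mixes in.

    import java.io.{BufferedInputStream, FileInputStream}

    val bus = new ReplayListenerBus()
    bus.addListener(new SparkListener {
      override def onJobStart(jobStart: SparkListenerJobStart): Unit =
        println(s"replayed job ${jobStart.jobId} with ${jobStart.stageInfos.size} stages")
    })
    // Hypothetical, uncompressed event log; wrap the stream in a codec if the log was compressed.
    val in = new BufferedInputStream(new FileInputStream("/tmp/spark-events/app-0001"))
    try bus.replay(in, version = "1.2.0") finally in.close()
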
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 86afe3bd5265f..dd28ddb31de1f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -25,6 +25,7 @@ import scala.collection.mutable
import org.apache.spark.{Logging, TaskEndReason}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.{Distribution, Utils}
@@ -56,11 +57,23 @@ case class SparkListenerTaskEnd(
extends SparkListenerEvent
@DeveloperApi
-case class SparkListenerJobStart(jobId: Int, stageIds: Seq[Int], properties: Properties = null)
- extends SparkListenerEvent
+case class SparkListenerJobStart(
+ jobId: Int,
+ time: Long,
+ stageInfos: Seq[StageInfo],
+ properties: Properties = null)
+ extends SparkListenerEvent {
+ // Note: this is here for backwards-compatibility with older versions of this event which
+ // only stored stageIds and not StageInfos:
+ val stageIds: Seq[Int] = stageInfos.map(_.stageId)
+}
@DeveloperApi
-case class SparkListenerJobEnd(jobId: Int, jobResult: JobResult) extends SparkListenerEvent
+case class SparkListenerJobEnd(
+ jobId: Int,
+ time: Long,
+ jobResult: JobResult)
+ extends SparkListenerEvent
@DeveloperApi
case class SparkListenerEnvironmentUpdate(environmentDetails: Map[String, Seq[(String, String)]])
@@ -77,6 +90,14 @@ case class SparkListenerBlockManagerRemoved(time: Long, blockManagerId: BlockMan
@DeveloperApi
case class SparkListenerUnpersistRDD(rddId: Int) extends SparkListenerEvent
+@DeveloperApi
+case class SparkListenerExecutorAdded(time: Long, executorId: String, executorInfo: ExecutorInfo)
+ extends SparkListenerEvent
+
+@DeveloperApi
+case class SparkListenerExecutorRemoved(time: Long, executorId: String, reason: String)
+ extends SparkListenerEvent
+
/**
* Periodic updates from executors.
* @param execId executor id
@@ -95,14 +116,12 @@ case class SparkListenerApplicationStart(appName: String, appId: Option[String],
@DeveloperApi
case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent
-/** An event used in the listener to shutdown the listener daemon thread. */
-private[spark] case object SparkListenerShutdown extends SparkListenerEvent
-
/**
* :: DeveloperApi ::
* Interface for listening to events from the Spark scheduler. Note that this is an internal
- * interface which might change in different Spark releases.
+ * interface which might change in different Spark releases. Java clients should extend
+ * {@link JavaSparkListener}
*/
@DeveloperApi
trait SparkListener {
@@ -176,6 +195,16 @@ trait SparkListener {
* Called when the driver receives task metrics from an executor in a heartbeat.
*/
def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { }
+
+ /**
+ * Called when the driver registers a new executor.
+ */
+ def onExecutorAdded(executorAdded: SparkListenerExecutorAdded) { }
+
+ /**
+ * Called when the driver removes an executor.
+ */
+ def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved) { }
}
/**
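A small sketch of a listener that consumes the two new executor events; the println sink is illustrative and registration through SparkContext.addSparkListener is assumed.

    class ExecutorTracker extends SparkListener {
      override def onExecutorAdded(e: SparkListenerExecutorAdded): Unit =
        println(s"executor ${e.executorId} added on ${e.executorInfo.executorHost} " +
          s"with ${e.executorInfo.totalCores} cores at ${e.time}")

      override def onExecutorRemoved(e: SparkListenerExecutorRemoved): Unit =
        println(s"executor ${e.executorId} removed at ${e.time}: ${e.reason}")
    }
    // sc.addSparkListener(new ExecutorTracker())  // assuming an active SparkContext `sc`
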
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
index e79ffd7a3587d..fe8a19a2c0cb9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -17,74 +17,47 @@
package org.apache.spark.scheduler
-import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.Logging
-import org.apache.spark.util.Utils
+import org.apache.spark.util.ListenerBus
/**
- * A SparkListenerEvent bus that relays events to its listeners
+ * A [[SparkListenerEvent]] bus that relays [[SparkListenerEvent]]s to its listeners
*/
-private[spark] trait SparkListenerBus extends Logging {
-
- // SparkListeners attached to this event bus
- protected val sparkListeners = new ArrayBuffer[SparkListener]
- with mutable.SynchronizedBuffer[SparkListener]
-
- def addListener(listener: SparkListener) {
- sparkListeners += listener
- }
+private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkListenerEvent] {
- /**
- * Post an event to all attached listeners.
- * This does nothing if the event is SparkListenerShutdown.
- */
- def postToAll(event: SparkListenerEvent) {
+ override def onPostEvent(listener: SparkListener, event: SparkListenerEvent): Unit = {
event match {
case stageSubmitted: SparkListenerStageSubmitted =>
- foreachListener(_.onStageSubmitted(stageSubmitted))
+ listener.onStageSubmitted(stageSubmitted)
case stageCompleted: SparkListenerStageCompleted =>
- foreachListener(_.onStageCompleted(stageCompleted))
+ listener.onStageCompleted(stageCompleted)
case jobStart: SparkListenerJobStart =>
- foreachListener(_.onJobStart(jobStart))
+ listener.onJobStart(jobStart)
case jobEnd: SparkListenerJobEnd =>
- foreachListener(_.onJobEnd(jobEnd))
+ listener.onJobEnd(jobEnd)
case taskStart: SparkListenerTaskStart =>
- foreachListener(_.onTaskStart(taskStart))
+ listener.onTaskStart(taskStart)
case taskGettingResult: SparkListenerTaskGettingResult =>
- foreachListener(_.onTaskGettingResult(taskGettingResult))
+ listener.onTaskGettingResult(taskGettingResult)
case taskEnd: SparkListenerTaskEnd =>
- foreachListener(_.onTaskEnd(taskEnd))
+ listener.onTaskEnd(taskEnd)
case environmentUpdate: SparkListenerEnvironmentUpdate =>
- foreachListener(_.onEnvironmentUpdate(environmentUpdate))
+ listener.onEnvironmentUpdate(environmentUpdate)
case blockManagerAdded: SparkListenerBlockManagerAdded =>
- foreachListener(_.onBlockManagerAdded(blockManagerAdded))
+ listener.onBlockManagerAdded(blockManagerAdded)
case blockManagerRemoved: SparkListenerBlockManagerRemoved =>
- foreachListener(_.onBlockManagerRemoved(blockManagerRemoved))
+ listener.onBlockManagerRemoved(blockManagerRemoved)
case unpersistRDD: SparkListenerUnpersistRDD =>
- foreachListener(_.onUnpersistRDD(unpersistRDD))
+ listener.onUnpersistRDD(unpersistRDD)
case applicationStart: SparkListenerApplicationStart =>
- foreachListener(_.onApplicationStart(applicationStart))
+ listener.onApplicationStart(applicationStart)
case applicationEnd: SparkListenerApplicationEnd =>
- foreachListener(_.onApplicationEnd(applicationEnd))
+ listener.onApplicationEnd(applicationEnd)
case metricsUpdate: SparkListenerExecutorMetricsUpdate =>
- foreachListener(_.onExecutorMetricsUpdate(metricsUpdate))
- case SparkListenerShutdown =>
- }
- }
-
- /**
- * Apply the given function to all attached listeners, catching and logging any exception.
- */
- private def foreachListener(f: SparkListener => Unit): Unit = {
- sparkListeners.foreach { listener =>
- try {
- f(listener)
- } catch {
- case e: Exception =>
- logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e)
- }
+ listener.onExecutorMetricsUpdate(metricsUpdate)
+ case executorAdded: SparkListenerExecutorAdded =>
+ listener.onExecutorAdded(executorAdded)
+ case executorRemoved: SparkListenerExecutorRemoved =>
+ listener.onExecutorRemoved(executorRemoved)
}
}
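The per-listener dispatch above now sits behind a generic bus; a minimal sketch of that pattern follows (the real org.apache.spark.util.ListenerBus presumably also catches and logs listener exceptions, as the removed foreachListener did).

    trait MiniListenerBus[L, E] {
      private val listeners = new java.util.concurrent.CopyOnWriteArrayList[L]()
      def addListener(listener: L): Unit = listeners.add(listener)
      def postToAll(event: E): Unit = {
        val it = listeners.iterator()
        while (it.hasNext) {
          onPostEvent(it.next(), event) // subclasses decide which callback the event maps to
        }
      }
      protected def onPostEvent(listener: L, event: E): Unit
    }
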
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 071568cdfb429..cc13f57a49b89 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -102,6 +102,11 @@ private[spark] class Stage(
}
}
+ /**
+ * Removes all shuffle outputs associated with this executor. Note that this will also remove
+ * outputs which are served by an external shuffle server (if one exists), as they are still
+ * registered with this execId.
+ */
def removeOutputsOnExecutor(execId: String) {
var becameUnavailable = false
for (partition <- 0 until numPartitions) {
@@ -131,4 +136,9 @@ private[spark] class Stage(
override def toString = "Stage " + id
override def hashCode(): Int = id
+
+ override def equals(other: Any): Boolean = other match {
+ case stage: Stage => stage != null && stage.id == id
+ case _ => false
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 2552d03d18d06..847a4912eec13 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -44,10 +44,18 @@ import org.apache.spark.util.Utils
*/
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
- final def run(attemptId: Long): T = {
- context = new TaskContextImpl(stageId, partitionId, attemptId, false)
+ /**
+ * Called by Executor to run this task.
+ *
+ * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
+ * @param attemptNumber how many times this task has been attempted (0 for the first attempt)
+ * @return the result of the task
+ */
+ final def run(taskAttemptId: Long, attemptNumber: Int): T = {
+ context = new TaskContextImpl(stageId = stageId, partitionId = partitionId,
+ taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false)
TaskContextHelper.setTaskContext(context)
- context.taskMetrics.hostname = Utils.localHostName()
+ context.taskMetrics.setHostname(Utils.localHostName())
taskThread = Thread.currentThread()
if (_killed) {
kill(interruptThread = false)
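The two identifiers in the new run() signature are easy to conflate; the toy values below are purely illustrative: taskAttemptId stays unique across the whole SparkContext, while attemptNumber only counts retries of a single task.

    case class RunArgs(taskAttemptId: Long, attemptNumber: Int)
    val firstTry = RunArgs(taskAttemptId = 41L, attemptNumber = 0) // first attempt of some partition
    val retry    = RunArgs(taskAttemptId = 57L, attemptNumber = 1) // a retry gets a fresh global id
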
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
index 4c96b9e5fef60..1c7c81c488c3a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
@@ -27,6 +27,7 @@ import org.apache.spark.util.SerializableBuffer
*/
private[spark] class TaskDescription(
val taskId: Long,
+ val attemptNumber: Int,
val executorId: String,
val name: String,
val index: Int, // Index within this task's TaskSet
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index d49d8fb887007..1f114a0207f7b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -31,8 +31,8 @@ import org.apache.spark.util.Utils
private[spark] sealed trait TaskResult[T]
/** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */
-private[spark]
-case class IndirectTaskResult[T](blockId: BlockId) extends TaskResult[T] with Serializable
+private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int)
+ extends TaskResult[T] with Serializable
/** A TaskResult that contains the task's return value and accumulator updates. */
private[spark]
@@ -42,7 +42,7 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long
def this() = this(null.asInstanceOf[ByteBuffer], null, null)
- override def writeExternal(out: ObjectOutput) {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
out.writeInt(valueBytes.remaining);
Utils.writeByteBuffer(valueBytes, out)
@@ -55,7 +55,7 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long
out.writeObject(metrics)
}
- override def readExternal(in: ObjectInput) {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
val blen = in.readInt()
val byteVal = new Array[Byte](blen)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 3f345ceeaaf7a..774f3d8cdb275 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
import java.nio.ByteBuffer
+import scala.language.existentials
import scala.util.control.NonFatal
import org.apache.spark._
@@ -47,9 +48,18 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
getTaskResultExecutor.execute(new Runnable {
override def run(): Unit = Utils.logUncaughtExceptions {
try {
- val result = serializer.get().deserialize[TaskResult[_]](serializedData) match {
- case directResult: DirectTaskResult[_] => directResult
- case IndirectTaskResult(blockId) =>
+ val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match {
+ case directResult: DirectTaskResult[_] =>
+ if (!taskSetManager.canFetchMoreResults(serializedData.limit())) {
+ return
+ }
+ (directResult, serializedData.limit())
+ case IndirectTaskResult(blockId, size) =>
+ if (!taskSetManager.canFetchMoreResults(size)) {
+ // dropped by the executor if the size is larger than maxResultSize
+ sparkEnv.blockManager.master.removeBlock(blockId)
+ return
+ }
logDebug("Fetching indirect task result for TID %s".format(tid))
scheduler.handleTaskGettingResult(taskSetManager, tid)
val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
@@ -64,9 +74,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
serializedTaskResult.get)
sparkEnv.blockManager.master.removeBlock(blockId)
- deserializedResult
+ (deserializedResult, size)
}
- result.metrics.resultSize = serializedData.limit()
+
+ result.metrics.setResultSize(size)
scheduler.handleSuccessfulTask(taskSetManager, tid, result)
} catch {
case cnf: ClassNotFoundException =>
@@ -93,7 +104,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
}
} catch {
case cnd: ClassNotFoundException =>
- // Log an error but keep going here -- the task failed, so not catastropic if we can't
+ // Log an error but keep going here -- the task failed, so not catastrophic if we can't
// deserialize the reason.
val loader = Utils.getContextOrSparkClassLoader
logError(
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index a129a434c9a1a..f095915352b17 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -23,7 +23,7 @@ import org.apache.spark.storage.BlockManagerId
/**
* Low-level task scheduler interface, currently implemented exclusively by TaskSchedulerImpl.
- * This interface allows plugging in different task schedulers. Each TaskScheduler schedulers tasks
+ * This interface allows plugging in different task schedulers. Each TaskScheduler schedules tasks
* for a single SparkContext. These schedulers get sets of tasks submitted to them from the
* DAGScheduler for each stage, and are responsible for sending the tasks to the cluster, running
* them, retrying if there are failures, and mitigating stragglers. They return events to the
@@ -41,7 +41,7 @@ private[spark] trait TaskScheduler {
// Invoked after system has successfully initialized (typically in spark context).
// Yarn uses this to bootstrap allocation of resources based on preferred locations,
- // wait for slave registerations, etc.
+ // wait for slave registrations, etc.
def postStartHook() { }
// Disconnect from the cluster.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 2b39c7fc872da..79f84e70df9d5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -31,10 +31,10 @@ import scala.util.Random
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
+import org.apache.spark.scheduler.TaskLocality.TaskLocality
import org.apache.spark.util.Utils
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage.BlockManagerId
-import akka.actor.Props
/**
* Schedules tasks for multiple types of clusters by acting through a SchedulerBackend.
@@ -168,7 +168,7 @@ private[spark] class TaskSchedulerImpl(
if (!hasLaunchedTask) {
logWarning("Initial job has not accepted any resources; " +
"check your cluster UI to ensure that workers are registered " +
- "and have sufficient memory")
+ "and have sufficient resources")
} else {
this.cancel()
}
@@ -210,6 +210,40 @@ private[spark] class TaskSchedulerImpl(
.format(manager.taskSet.id, manager.parent.name))
}
+ private def resourceOfferSingleTaskSet(
+ taskSet: TaskSetManager,
+ maxLocality: TaskLocality,
+ shuffledOffers: Seq[WorkerOffer],
+ availableCpus: Array[Int],
+ tasks: Seq[ArrayBuffer[TaskDescription]]) : Boolean = {
+ var launchedTask = false
+ for (i <- 0 until shuffledOffers.size) {
+ val execId = shuffledOffers(i).executorId
+ val host = shuffledOffers(i).host
+ if (availableCpus(i) >= CPUS_PER_TASK) {
+ try {
+ for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
+ tasks(i) += task
+ val tid = task.taskId
+ taskIdToTaskSetId(tid) = taskSet.taskSet.id
+ taskIdToExecutorId(tid) = execId
+ executorsByHost(host) += execId
+ availableCpus(i) -= CPUS_PER_TASK
+ assert(availableCpus(i) >= 0)
+ launchedTask = true
+ }
+ } catch {
+ case e: TaskNotSerializableException =>
+ logError(s"Resource offer failed, task set ${taskSet.name} was not serializable")
+ // Do not offer resources for this task, but don't throw an error to allow other
+ // task sets to be submitted.
+ return launchedTask
+ }
+ }
+ }
+ return launchedTask
+ }
+
/**
* Called by cluster manager to offer resources on slaves. We respond by asking our active task
* sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so
@@ -252,23 +286,8 @@ private[spark] class TaskSchedulerImpl(
var launchedTask = false
for (taskSet <- sortedTaskSets; maxLocality <- taskSet.myLocalityLevels) {
do {
- launchedTask = false
- for (i <- 0 until shuffledOffers.size) {
- val execId = shuffledOffers(i).executorId
- val host = shuffledOffers(i).host
- if (availableCpus(i) >= CPUS_PER_TASK) {
- for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
- tasks(i) += task
- val tid = task.taskId
- taskIdToTaskSetId(tid) = taskSet.taskSet.id
- taskIdToExecutorId(tid) = execId
- executorsByHost(host) += execId
- availableCpus(i) -= CPUS_PER_TASK
- assert(availableCpus(i) >= 0)
- launchedTask = true
- }
- }
- }
+ launchedTask = resourceOfferSingleTaskSet(
+ taskSet, maxLocality, shuffledOffers, availableCpus, tasks)
} while (launchedTask)
}
@@ -342,7 +361,7 @@ private[spark] class TaskSchedulerImpl(
dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId)
}
- def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) {
+ def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long): Unit = synchronized {
taskSetManager.handleTaskGettingResult(tid)
}
@@ -395,9 +414,6 @@ private[spark] class TaskSchedulerImpl(
taskResultGetter.stop()
}
starvationTimer.cancel()
-
- // sleeping for an arbitrary 1 seconds to ensure that messages are sent out.
- Thread.sleep(1000L)
}
override def defaultParallelism() = backend.defaultParallelism()
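A condensed sketch of the loop that resourceOfferSingleTaskSet factors out, under simplified assumptions (one pass over the shuffled offers, at most one task launched per offer per pass):

    case class Offer(executorId: String, host: String, var freeCores: Int)

    def offerOnePass(offers: Seq[Offer], cpusPerTask: Int)(tryLaunch: Offer => Boolean): Boolean = {
      var launchedTask = false
      for (o <- offers if o.freeCores >= cpusPerTask) {
        if (tryLaunch(o)) {
          o.freeCores -= cpusPerTask
          launchedTask = true
        }
      }
      launchedTask
    }
    // The caller repeats passes while a task was launched, mirroring the do/while loop above.
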
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index a6c23fc85a1b0..55024ecd55e61 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -18,18 +18,19 @@
package org.apache.spark.scheduler
import java.io.NotSerializableException
+import java.nio.ByteBuffer
import java.util.Arrays
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
-import scala.math.max
-import scala.math.min
+import scala.math.{min, max}
+import scala.util.control.NonFatal
import org.apache.spark._
-import org.apache.spark.TaskState.TaskState
import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.util.{Clock, SystemClock}
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.util.{Clock, SystemClock, Utils}
/**
* Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of
@@ -68,6 +69,9 @@ private[spark] class TaskSetManager(
val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75)
val SPECULATION_MULTIPLIER = conf.getDouble("spark.speculation.multiplier", 1.5)
+ // Limit of bytes for total size of results (default is 1GB)
+ val maxResultSize = Utils.getMaxResultSize(conf)
+
// Serializer for closures and tasks.
val env = SparkEnv.get
val ser = env.closureSerializer.newInstance()
@@ -89,6 +93,8 @@ private[spark] class TaskSetManager(
var stageId = taskSet.stageId
var name = "TaskSet_" + taskSet.stageId.toString
var parent: Pool = null
+ var totalResultSize = 0L
+ var calculatedTasks = 0
val runningTasksSet = new HashSet[Long]
override def runningTasks = runningTasksSet.size
@@ -245,7 +251,7 @@ private[spark] class TaskSetManager(
* This method also cleans up any tasks in the list that have already
* been launched, since we want that to happen lazily.
*/
- private def findTaskFromList(execId: String, list: ArrayBuffer[Int]): Option[Int] = {
+ private def dequeueTaskFromList(execId: String, list: ArrayBuffer[Int]): Option[Int] = {
var indexOffset = list.size
while (indexOffset > 0) {
indexOffset -= 1
@@ -286,7 +292,7 @@ private[spark] class TaskSetManager(
* an attempt running on this host, in case the host is slow. In addition, the task should meet
* the given locality constraint.
*/
- private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
+ private def dequeueSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
: Option[(Int, TaskLocality.Value)] =
{
speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set
@@ -362,22 +368,22 @@ private[spark] class TaskSetManager(
*
* @return An option containing (task index within the task set, locality, is speculative?)
*/
- private def findTask(execId: String, host: String, maxLocality: TaskLocality.Value)
+ private def dequeueTask(execId: String, host: String, maxLocality: TaskLocality.Value)
: Option[(Int, TaskLocality.Value, Boolean)] =
{
- for (index <- findTaskFromList(execId, getPendingTasksForExecutor(execId))) {
+ for (index <- dequeueTaskFromList(execId, getPendingTasksForExecutor(execId))) {
return Some((index, TaskLocality.PROCESS_LOCAL, false))
}
if (TaskLocality.isAllowed(maxLocality, TaskLocality.NODE_LOCAL)) {
- for (index <- findTaskFromList(execId, getPendingTasksForHost(host))) {
+ for (index <- dequeueTaskFromList(execId, getPendingTasksForHost(host))) {
return Some((index, TaskLocality.NODE_LOCAL, false))
}
}
if (TaskLocality.isAllowed(maxLocality, TaskLocality.NO_PREF)) {
 // Look for noPref tasks after NODE_LOCAL to minimize cross-rack traffic
- for (index <- findTaskFromList(execId, pendingTasksWithNoPrefs)) {
+ for (index <- dequeueTaskFromList(execId, pendingTasksWithNoPrefs)) {
return Some((index, TaskLocality.PROCESS_LOCAL, false))
}
}
@@ -385,20 +391,20 @@ private[spark] class TaskSetManager(
if (TaskLocality.isAllowed(maxLocality, TaskLocality.RACK_LOCAL)) {
for {
rack <- sched.getRackForHost(host)
- index <- findTaskFromList(execId, getPendingTasksForRack(rack))
+ index <- dequeueTaskFromList(execId, getPendingTasksForRack(rack))
} {
return Some((index, TaskLocality.RACK_LOCAL, false))
}
}
if (TaskLocality.isAllowed(maxLocality, TaskLocality.ANY)) {
- for (index <- findTaskFromList(execId, allPendingTasks)) {
+ for (index <- dequeueTaskFromList(execId, allPendingTasks)) {
return Some((index, TaskLocality.ANY, false))
}
}
 // find a speculative task if all other tasks have been scheduled
- findSpeculativeTask(execId, host, maxLocality).map {
+ dequeueSpeculativeTask(execId, host, maxLocality).map {
case (taskIndex, allowedLocality) => (taskIndex, allowedLocality, true)}
}
@@ -413,6 +419,7 @@ private[spark] class TaskSetManager(
* @param host the host Id of the offered resource
* @param maxLocality the maximum locality we want to schedule the tasks at
*/
+ @throws[TaskNotSerializableException]
def resourceOffer(
execId: String,
host: String,
@@ -432,7 +439,7 @@ private[spark] class TaskSetManager(
}
}
- findTask(execId, host, allowedLocality) match {
+ dequeueTask(execId, host, allowedLocality) match {
case Some((index, taskLocality, speculative)) => {
// Found a task; do some bookkeeping and return a task description
val task = tasks(index)
@@ -452,10 +459,17 @@ private[spark] class TaskSetManager(
}
// Serialize and return the task
val startTime = clock.getTime()
- // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
- // we assume the task can be serialized without exceptions.
- val serializedTask = Task.serializeWithDependencies(
- task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+ val serializedTask: ByteBuffer = try {
+ Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+ } catch {
+ // If the task cannot be serialized, then there's no point to re-attempt the task,
+ // as it will always fail. So just abort the whole task-set.
+ case NonFatal(e) =>
+ val msg = s"Failed to serialize task $taskId, not attempting to retry it."
+ logError(msg, e)
+ abort(s"$msg Exception during serialization: $e")
+ throw new TaskNotSerializableException(e)
+ }
if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 &&
!emittedTaskSizeWarning) {
emittedTaskSizeWarning = true
@@ -473,7 +487,8 @@ private[spark] class TaskSetManager(
taskName, taskId, host, taskLocality, serializedTask.limit))
sched.dagScheduler.taskStarted(task, info)
- return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
+ return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId,
+ taskName, index, serializedTask))
}
case _ =>
}
@@ -491,13 +506,64 @@ private[spark] class TaskSetManager(
* Get the level we can launch tasks according to delay scheduling, based on current wait time.
*/
private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = {
- while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) &&
- currentLocalityIndex < myLocalityLevels.length - 1)
- {
- // Jump to the next locality level, and remove our waiting time for the current one since
- // we don't want to count it again on the next one
- lastLaunchTime += localityWaits(currentLocalityIndex)
- currentLocalityIndex += 1
+ // Remove the scheduled or finished tasks lazily
+ def tasksNeedToBeScheduledFrom(pendingTaskIds: ArrayBuffer[Int]): Boolean = {
+ var indexOffset = pendingTaskIds.size
+ while (indexOffset > 0) {
+ indexOffset -= 1
+ val index = pendingTaskIds(indexOffset)
+ if (copiesRunning(index) == 0 && !successful(index)) {
+ return true
+ } else {
+ pendingTaskIds.remove(indexOffset)
+ }
+ }
+ false
+ }
+ // Walks through the list of tasks that can be scheduled at each location and returns true
+ // if there are any tasks that still need to be scheduled. Lazily cleans up tasks that have
+ // already been scheduled.
+ def moreTasksToRunIn(pendingTasks: HashMap[String, ArrayBuffer[Int]]): Boolean = {
+ val emptyKeys = new ArrayBuffer[String]
+ val hasTasks = pendingTasks.exists {
+ case (id: String, tasks: ArrayBuffer[Int]) =>
+ if (tasksNeedToBeScheduledFrom(tasks)) {
+ true
+ } else {
+ emptyKeys += id
+ false
+ }
+ }
+ // The key could be executorId, host or rackId
+ emptyKeys.foreach(id => pendingTasks.remove(id))
+ hasTasks
+ }
+
+ while (currentLocalityIndex < myLocalityLevels.length - 1) {
+ val moreTasks = myLocalityLevels(currentLocalityIndex) match {
+ case TaskLocality.PROCESS_LOCAL => moreTasksToRunIn(pendingTasksForExecutor)
+ case TaskLocality.NODE_LOCAL => moreTasksToRunIn(pendingTasksForHost)
+ case TaskLocality.NO_PREF => pendingTasksWithNoPrefs.nonEmpty
+ case TaskLocality.RACK_LOCAL => moreTasksToRunIn(pendingTasksForRack)
+ }
+ if (!moreTasks) {
+ // This is a performance optimization: if there are no more tasks that can
+ // be scheduled at a particular locality level, there is no point in waiting
+ // for the locality wait timeout (SPARK-4939).
+ lastLaunchTime = curTime
+ logDebug(s"No tasks for locality level ${myLocalityLevels(currentLocalityIndex)}, " +
+ s"so moving to locality level ${myLocalityLevels(currentLocalityIndex + 1)}")
+ currentLocalityIndex += 1
+ } else if (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex)) {
+ // Jump to the next locality level, and reset lastLaunchTime so that the next locality
+ // wait timer doesn't immediately expire
+ lastLaunchTime += localityWaits(currentLocalityIndex)
+ currentLocalityIndex += 1
+ logDebug(s"Moving to ${myLocalityLevels(currentLocalityIndex)} after waiting for " +
+ s"${localityWaits(currentLocalityIndex)}ms")
+ } else {
+ return myLocalityLevels(currentLocalityIndex)
+ }
}
myLocalityLevels(currentLocalityIndex)
}
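A simplified model of the SPARK-4939 behaviour added above (a single shared wait instead of per-level timers; level names and pending counts are hypothetical): levels with no runnable tasks are skipped immediately rather than waiting out their locality timeout.

    def allowedLevel(levels: Seq[String], pending: Map[String, Int],
                     waitedMs: Long, localityWaitMs: Long): String = {
      var i = 0
      while (i < levels.length - 1 &&
             (pending.getOrElse(levels(i), 0) == 0 || waitedMs >= localityWaitMs)) {
        i += 1 // either nothing is left to run at this level, or we have already waited long enough
      }
      levels(i)
    }
    val levels = Seq("PROCESS_LOCAL", "NODE_LOCAL", "ANY")
    // No process-local tasks remain, so we fall through to NODE_LOCAL without waiting:
    assert(allowedLevel(levels, Map("NODE_LOCAL" -> 3), waitedMs = 0, localityWaitMs = 3000) == "NODE_LOCAL")
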
@@ -515,12 +581,33 @@ private[spark] class TaskSetManager(
index
}
+ /**
+ * Marks the task as getting its result and notifies the DAGScheduler.
+ */
def handleTaskGettingResult(tid: Long) = {
val info = taskInfos(tid)
info.markGettingResult()
sched.dagScheduler.taskGettingResult(info)
}
+ /**
+ * Check whether there is enough quota to fetch the result with `size` bytes.
+ */
+ def canFetchMoreResults(size: Long): Boolean = sched.synchronized {
+ totalResultSize += size
+ calculatedTasks += 1
+ if (maxResultSize > 0 && totalResultSize > maxResultSize) {
+ val msg = s"Total size of serialized results of ${calculatedTasks} tasks " +
+ s"(${Utils.bytesToString(totalResultSize)}) is bigger than spark.driver.maxResultSize " +
+ s"(${Utils.bytesToString(maxResultSize)})"
+ logError(msg)
+ abort(msg)
+ false
+ } else {
+ true
+ }
+ }
+
/**
* Marks the task as successful and notifies the DAGScheduler that a task has ended.
*/
@@ -635,7 +722,7 @@ private[spark] class TaskSetManager(
maybeFinishTaskSet()
}
- def abort(message: String) {
+ def abort(message: String): Unit = sched.synchronized {
// TODO: Kill running tasks if we were not terminated due to a Mesos error
sched.dagScheduler.taskSetFailed(taskSet, message)
isZombie = true
@@ -679,7 +766,7 @@ private[spark] class TaskSetManager(
// Re-enqueue pending tasks for this host based on the status of the cluster. Note
// that it's okay if we add a task to the same queue twice (if it had multiple preferred
- // locations), because findTaskFromList will skip already-running tasks.
+ // locations), because dequeueTaskFromList will skip already-running tasks.
for (index <- getPendingTasksForExecutor(execId)) {
addPendingTask(index, readding=true)
}
@@ -687,10 +774,11 @@ private[spark] class TaskSetManager(
addPendingTask(index, readding=true)
}
- // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage.
+ // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage,
+ // and we are not using an external shuffle server which could serve the shuffle outputs.
// The reason is the next stage wouldn't be able to fetch the data from this dead executor
// so we would need to rerun these tasks on other executors.
- if (tasks(0).isInstanceOf[ShuffleMapTask]) {
+ if (tasks(0).isInstanceOf[ShuffleMapTask] && !env.blockManager.externalShuffleServiceEnabled) {
for ((tid, info) <- taskInfos if info.executorId == execId) {
val index = taskInfos(tid).index
if (successful(index)) {
@@ -706,7 +794,7 @@ private[spark] class TaskSetManager(
}
// Also re-enqueue any tasks that were running on the node
for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
- handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure)
+ handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(execId))
}
// recalculate valid locality levels and waits when executor is lost
recomputeLocality()
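A standalone sketch of the accounting canFetchMoreResults performs (field names mirror the method, the class itself is illustrative): result sizes accumulate, and the first result that pushes the running total past spark.driver.maxResultSize causes the task set to be aborted.

    final class ResultQuota(maxResultSize: Long) {
      private var totalResultSize = 0L
      def canFetchMore(size: Long): Boolean = {
        totalResultSize += size
        maxResultSize <= 0 || totalResultSize <= maxResultSize // a non-positive limit disables the check
      }
    }
    val quota = new ResultQuota(maxResultSize = 1L << 30) // 1 GB, the documented default
    assert(quota.canFetchMore(512L << 20))  // the first 512 MB result fits
    assert(!quota.canFetchMore(600L << 20)) // the cumulative total now exceeds the limit
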
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index fb8160abc59db..1da6fe976da5b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -66,7 +66,19 @@ private[spark] object CoarseGrainedClusterMessages {
case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage
- case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase :String)
+ // Exchanged between the driver and the AM in Yarn client mode
+ case class AddWebUIFilter(filterName: String, filterParams: Map[String, String], proxyBase: String)
extends CoarseGrainedClusterMessage
+ // Messages exchanged between the driver and the cluster manager for executor allocation
+ // In Yarn mode, these are exchanged between the driver and the AM
+
+ case object RegisterClusterManager extends CoarseGrainedClusterMessage
+
+ // Request executors by specifying the new total number of executors desired
+ // This includes executors already pending or running
+ case class RequestExecutors(requestedTotal: Int) extends CoarseGrainedClusterMessage
+
+ case class KillExecutors(executorIds: Seq[String]) extends CoarseGrainedClusterMessage
+
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 59aed6b72fe42..103a5c053c289 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -27,11 +27,10 @@ import akka.actor._
import akka.pattern.ask
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
-import org.apache.spark.{SparkEnv, Logging, SparkException, TaskState}
-import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer}
+import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState}
+import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils}
-import org.apache.spark.ui.JettyUtils
/**
* A scheduler backend that waits for coarse grained executors to connect to it through Akka.
@@ -42,11 +41,12 @@ import org.apache.spark.ui.JettyUtils
* (spark.deploy.*).
*/
private[spark]
-class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: ActorSystem)
- extends SchedulerBackend with Logging
+class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSystem: ActorSystem)
+ extends ExecutorAllocationClient with SchedulerBackend with Logging
{
// Use an atomic variable to track total number of cores in the cluster for simplicity and speed
var totalCoreCount = new AtomicInteger(0)
+ // Total number of executors that are currently registered
var totalRegisteredExecutors = new AtomicInteger(0)
val conf = scheduler.sc.conf
private val timeout = AkkaUtils.askTimeout(conf)
@@ -61,10 +61,19 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
conf.getInt("spark.scheduler.maxRegisteredResourcesWaitingTime", 30000)
val createTime = System.currentTimeMillis()
+ private val executorDataMap = new HashMap[String, ExecutorData]
+
+ // Number of executors requested from the cluster manager that have not registered yet
+ private var numPendingExecutors = 0
+
+ private val listenerBus = scheduler.sc.listenerBus
+
+ // Executors we have requested the cluster manager to kill that have not died yet
+ private val executorsPendingToRemove = new HashSet[String]
+
class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive {
override protected def log = CoarseGrainedSchedulerBackend.this.log
private val addressToExecutorId = new HashMap[Address, String]
- private val executorDataMap = new HashMap[String, ExecutorData]
override def preStart() {
// Listen for remote client disconnection events, since they don't go through Akka's watch()
@@ -84,12 +93,23 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
} else {
logInfo("Registered executor: " + sender + " with ID " + executorId)
sender ! RegisteredExecutor
- executorDataMap.put(executorId, new ExecutorData(sender, sender.path.address,
- Utils.parseHostPort(hostPort)._1, cores, cores))
addressToExecutorId(sender.path.address) = executorId
totalCoreCount.addAndGet(cores)
totalRegisteredExecutors.addAndGet(1)
+ val (host, _) = Utils.parseHostPort(hostPort)
+ val data = new ExecutorData(sender, sender.path.address, host, cores, cores)
+ // This must be synchronized because variables mutated
+ // in this block are read when requesting executors
+ CoarseGrainedSchedulerBackend.this.synchronized {
+ executorDataMap.put(executorId, data)
+ if (numPendingExecutors > 0) {
+ numPendingExecutors -= 1
+ logDebug(s"Decremented number of pending executors ($numPendingExecutors left)")
+ }
+ }
+ listenerBus.post(
+ SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data))
makeOffers()
}
@@ -111,7 +131,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
makeOffers()
case KillTask(taskId, executorId, interruptThread) =>
- executorDataMap(executorId).executorActor ! KillTask(taskId, executorId, interruptThread)
+ executorDataMap.get(executorId) match {
+ case Some(executorInfo) =>
+ executorInfo.executorActor ! KillTask(taskId, executorId, interruptThread)
+ case None =>
+ // Ignoring the task kill since the executor is not registered.
+ logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.")
+ }
case StopDriver =>
sender ! true
@@ -128,10 +154,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
removeExecutor(executorId, reason)
sender ! true
- case AddWebUIFilter(filterName, filterParams, proxyBase) =>
- addWebUIFilter(filterName, filterParams, proxyBase)
- sender ! true
-
case DisassociatedEvent(_, address, _) =>
addressToExecutorId.get(address).foreach(removeExecutor(_,
"remote Akka client disassociated"))
@@ -183,13 +205,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
}
// Remove a disconnected slave from the cluster
- def removeExecutor(executorId: String, reason: String) {
+ def removeExecutor(executorId: String, reason: String): Unit = {
executorDataMap.get(executorId) match {
case Some(executorInfo) =>
- executorDataMap -= executorId
+ // This must be synchronized because variables mutated
+ // in this block are read when requesting executors
+ CoarseGrainedSchedulerBackend.this.synchronized {
+ executorDataMap -= executorId
+ executorsPendingToRemove -= executorId
+ }
totalCoreCount.addAndGet(-executorInfo.totalCores)
+ totalRegisteredExecutors.addAndGet(-1)
scheduler.executorLost(executorId, SlaveLost(reason))
- case None => logError(s"Asked to remove non existant executor $executorId")
+ listenerBus.post(
+ SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason))
+ case None => logError(s"Asked to remove non-existent executor $executorId")
}
}
}
@@ -274,21 +304,62 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
false
}
- // Add filters to the SparkUI
- def addWebUIFilter(filterName: String, filterParams: Map[String, String], proxyBase: String) {
- if (proxyBase != null && proxyBase.nonEmpty) {
- System.setProperty("spark.ui.proxyBase", proxyBase)
- }
+ /**
+ * Return the number of executors currently registered with this backend.
+ */
+ def numExistingExecutors: Int = executorDataMap.size
+
+ /**
+ * Request an additional number of executors from the cluster manager.
+ * Return whether the request is acknowledged.
+ */
+ final override def requestExecutors(numAdditionalExecutors: Int): Boolean = synchronized {
+ logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager")
+ logDebug(s"Number of pending executors is now $numPendingExecutors")
+ numPendingExecutors += numAdditionalExecutors
+ // Account for executors pending to be added or removed
+ val newTotal = numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size
+ doRequestTotalExecutors(newTotal)
+ }
- val hasFilter = (filterName != null && filterName.nonEmpty &&
- filterParams != null && filterParams.nonEmpty)
- if (hasFilter) {
- logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase")
- conf.set("spark.ui.filters", filterName)
- filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) }
- scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) }
+ /**
+ * Request executors from the cluster manager by specifying the total number desired,
+ * including existing pending and running executors.
+ *
+ * The semantics here guarantee that we do not over-allocate executors for this application,
+ * since a later request overrides the value of any prior request. The alternative interface
+ * of requesting a delta of executors risks double counting new executors when there are
+ * insufficient resources to satisfy the first request. We make the assumption here that the
+ * cluster manager will eventually fulfill all requests when resources free up.
+ *
+ * Return whether the request is acknowledged.
+ */
+ protected def doRequestTotalExecutors(requestedTotal: Int): Boolean = false
+
+ /**
+ * Request that the cluster manager kill the specified executors.
+ * Return whether the kill request is acknowledged.
+ */
+ final override def killExecutors(executorIds: Seq[String]): Boolean = synchronized {
+ logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}")
+ val filteredExecutorIds = new ArrayBuffer[String]
+ executorIds.foreach { id =>
+ if (executorDataMap.contains(id)) {
+ filteredExecutorIds += id
+ } else {
+ logWarning(s"Executor to kill $id does not exist!")
+ }
}
+ executorsPendingToRemove ++= filteredExecutorIds
+ doKillExecutors(filteredExecutorIds)
}
+
+ /**
+ * Kill the given list of executors through the cluster manager.
+ * Return whether the kill request is acknowledged.
+ */
+ protected def doKillExecutors(executorIds: Seq[String]): Boolean = false
+
}
private[spark] object CoarseGrainedSchedulerBackend {
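The "new total" semantics described above reduce to simple bookkeeping (field names mirror the backend, the class is illustrative): each request is translated into a desired total of existing plus pending minus pending-to-remove executors, so a later request supersedes an earlier one instead of adding to it.

    final class ExecutorBookkeeping {
      var numExistingExecutors = 0
      var numPendingExecutors = 0
      val executorsPendingToRemove = scala.collection.mutable.Set.empty[String]
      // Returns the total that would be sent to the cluster manager.
      def requestExecutors(numAdditional: Int): Int = {
        numPendingExecutors += numAdditional
        numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size
      }
    }
    val b = new ExecutorBookkeeping
    b.numExistingExecutors = 4
    assert(b.requestExecutors(2) == 6)
    assert(b.requestExecutors(1) == 7) // supersedes, rather than re-adds, the earlier request
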
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
index b71bd5783d6df..eb52ddfb1eab1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
@@ -31,7 +31,7 @@ import akka.actor.{Address, ActorRef}
private[cluster] class ExecutorData(
val executorActor: ActorRef,
val executorAddress: Address,
- val executorHost: String ,
+ override val executorHost: String,
var freeCores: Int,
- val totalCores: Int
-)
+ override val totalCores: Int
+) extends ExecutorInfo(executorHost, totalCores)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala
new file mode 100644
index 0000000000000..b4738e64c9391
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * :: DeveloperApi ::
+ * Stores information about an executor to pass from the scheduler to SparkListeners.
+ */
+@DeveloperApi
+class ExecutorInfo(
+ val executorHost: String,
+ val totalCores: Int
+) {
+
+ def canEqual(other: Any): Boolean = other.isInstanceOf[ExecutorInfo]
+
+ override def equals(other: Any): Boolean = other match {
+ case that: ExecutorInfo =>
+ (that canEqual this) &&
+ executorHost == that.executorHost &&
+ totalCores == that.totalCores
+ case _ => false
+ }
+
+ override def hashCode(): Int = {
+ val state = Seq(executorHost, totalCores)
+ state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b)
+ }
+}
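Because ExecutorInfo implements canEqual-based value equality, two instances describing the same host and core count compare equal, which is convenient when asserting on replayed listener events; the assertions below are illustrative.

    val a = new ExecutorInfo("host-1", 4)
    val b = new ExecutorInfo("host-1", 4)
    assert(a == b && a.hashCode == b.hashCode)
    assert(new ExecutorInfo("host-1", 8) != a)
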
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
index ee10aa061f4e9..06786a59524e7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
@@ -22,6 +22,7 @@ import org.apache.hadoop.fs.{Path, FileSystem}
import org.apache.spark.{Logging, SparkContext, SparkEnv}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler.TaskSchedulerImpl
+import org.apache.spark.util.AkkaUtils
private[spark] class SimrSchedulerBackend(
scheduler: TaskSchedulerImpl,
@@ -38,7 +39,8 @@ private[spark] class SimrSchedulerBackend(
override def start() {
super.start()
- val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format(
+ val driverUrl = AkkaUtils.address(
+ AkkaUtils.protocol(actorSystem),
SparkEnv.driverActorSystemName,
sc.conf.get("spark.driver.host"),
sc.conf.get("spark.driver.port"),
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index 8c7de75600b5f..d2e1680a5fd1b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -21,7 +21,7 @@ import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv}
import org.apache.spark.deploy.{ApplicationDescription, Command}
import org.apache.spark.deploy.client.{AppClient, AppClientListener}
import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{AkkaUtils, Utils}
private[spark] class SparkDeploySchedulerBackend(
scheduler: TaskSchedulerImpl,
@@ -46,7 +46,8 @@ private[spark] class SparkDeploySchedulerBackend(
super.start()
// The endpoint for executors to talk to us
- val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format(
+ val driverUrl = AkkaUtils.address(
+ AkkaUtils.protocol(actorSystem),
SparkEnv.driverActorSystemName,
conf.get("spark.driver.host"),
conf.get("spark.driver.port"),
@@ -55,19 +56,26 @@ private[spark] class SparkDeploySchedulerBackend(
"{{WORKER_URL}}")
val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions")
.map(Utils.splitCommandString).getOrElse(Seq.empty)
- val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath").toSeq.flatMap { cp =>
- cp.split(java.io.File.pathSeparator)
- }
- val libraryPathEntries =
- sc.conf.getOption("spark.executor.extraLibraryPath").toSeq.flatMap { cp =>
- cp.split(java.io.File.pathSeparator)
+ val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath")
+ .map(_.split(java.io.File.pathSeparator).toSeq).getOrElse(Nil)
+ val libraryPathEntries = sc.conf.getOption("spark.executor.extraLibraryPath")
+ .map(_.split(java.io.File.pathSeparator).toSeq).getOrElse(Nil)
+
+ // When testing, expose the parent class path to the child. This is processed by
+ // compute-classpath.{cmd,sh} and makes all needed jars available to child processes
+ // when the assembly is built with the "*-provided" profiles enabled.
+ val testingClassPath =
+ if (sys.props.contains("spark.testing")) {
+ sys.props("java.class.path").split(java.io.File.pathSeparator).toSeq
+ } else {
+ Nil
}
// Start executors with a few necessary configs for registering with the scheduler
val sparkJavaOpts = Utils.sparkJavaOpts(conf, SparkConf.isExecutorStartupConf)
val javaOpts = sparkJavaOpts ++ extraJavaOpts
val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend",
- args, sc.executorEnvs, classPathEntries, libraryPathEntries, javaOpts)
+ args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts)
val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("")
val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command,
appUIAddress, sc.eventLogDir)
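The class-path handling above reduces to splitting optional config values on the platform path separator (Nil when unset) and appending the JVM's own class path only when spark.testing is set; a minimal sketch with made-up values:

    def pathEntries(value: Option[String]): Seq[String] =
      value.map(_.split(java.io.File.pathSeparator).toSeq).getOrElse(Nil)

    val sep = java.io.File.pathSeparator
    val classPathEntries = pathEntries(Some(Seq("/opt/libs/a.jar", "/opt/libs/b.jar").mkString(sep)))
    val testingClassPath =
      if (sys.props.contains("spark.testing")) sys.props("java.class.path").split(sep).toSeq else Nil
    println(classPathEntries ++ testingClassPath)
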
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
new file mode 100644
index 0000000000000..f14aaeea0a25c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import scala.concurrent.{Future, ExecutionContext}
+
+import akka.actor.{Actor, ActorRef, Props}
+import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
+
+import org.apache.spark.SparkContext
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
+import org.apache.spark.scheduler.TaskSchedulerImpl
+import org.apache.spark.ui.JettyUtils
+import org.apache.spark.util.{AkkaUtils, Utils}
+
+import scala.util.control.NonFatal
+
+/**
+ * Abstract Yarn scheduler backend that contains common logic
+ * between the client and cluster Yarn scheduler backends.
+ */
+private[spark] abstract class YarnSchedulerBackend(
+ scheduler: TaskSchedulerImpl,
+ sc: SparkContext)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) {
+
+ if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) {
+ minRegisteredRatio = 0.8
+ }
+
+ protected var totalExpectedExecutors = 0
+
+ private val yarnSchedulerActor: ActorRef =
+ actorSystem.actorOf(
+ Props(new YarnSchedulerActor),
+ name = YarnSchedulerBackend.ACTOR_NAME)
+
+ private implicit val askTimeout = AkkaUtils.askTimeout(sc.conf)
+
+ /**
+ * Request executors from the ApplicationMaster by specifying the total number desired.
+ * This includes executors already pending or running.
+ */
+ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = {
+ AkkaUtils.askWithReply[Boolean](
+ RequestExecutors(requestedTotal), yarnSchedulerActor, askTimeout)
+ }
+
+ /**
+ * Request that the ApplicationMaster kill the specified executors.
+ */
+ override def doKillExecutors(executorIds: Seq[String]): Boolean = {
+ AkkaUtils.askWithReply[Boolean](
+ KillExecutors(executorIds), yarnSchedulerActor, askTimeout)
+ }
+
+ override def sufficientResourcesRegistered(): Boolean = {
+ totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio
+ }
+
+ /**
+ * Add filters to the SparkUI.
+ */
+ private def addWebUIFilter(
+ filterName: String,
+ filterParams: Map[String, String],
+ proxyBase: String): Unit = {
+ if (proxyBase != null && proxyBase.nonEmpty) {
+ System.setProperty("spark.ui.proxyBase", proxyBase)
+ }
+
+ val hasFilter =
+ filterName != null && filterName.nonEmpty &&
+ filterParams != null && filterParams.nonEmpty
+ if (hasFilter) {
+ logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase")
+ conf.set("spark.ui.filters", filterName)
+ filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) }
+ scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) }
+ }
+ }
+
+ /**
+ * An actor that communicates with the ApplicationMaster.
+ */
+ private class YarnSchedulerActor extends Actor {
+ private var amActor: Option[ActorRef] = None
+
+ implicit val askAmActorExecutor = ExecutionContext.fromExecutor(
+ Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-executor"))
+
+ override def preStart(): Unit = {
+ // Listen for disassociation events
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+ }
+
+ override def receive = {
+ case RegisterClusterManager =>
+ logInfo(s"ApplicationMaster registered as $sender")
+ amActor = Some(sender)
+
+ case r: RequestExecutors =>
+ amActor match {
+ case Some(actor) =>
+ val driverActor = sender
+ Future {
+ driverActor ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout)
+ } onFailure {
+ case NonFatal(e) => logError(s"Sending $r to AM was unsuccessful", e)
+ }
+ case None =>
+ logWarning("Attempted to request executors before the AM has registered!")
+ sender ! false
+ }
+
+ case k: KillExecutors =>
+ amActor match {
+ case Some(actor) =>
+ val driverActor = sender
+ Future {
+ driverActor ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout)
+ } onFailure {
+ case NonFatal(e) => logError(s"Sending $k to AM was unsuccessful", e)
+ }
+ case None =>
+ logWarning("Attempted to kill executors before the AM has registered!")
+ sender ! false
+ }
+
+ case AddWebUIFilter(filterName, filterParams, proxyBase) =>
+ addWebUIFilter(filterName, filterParams, proxyBase)
+ sender ! true
+
+ case d: DisassociatedEvent =>
+ if (amActor.isDefined && sender == amActor.get) {
+ logWarning(s"ApplicationMaster has disassociated: $d")
+ }
+ }
+ }
+}
+
+private[spark] object YarnSchedulerBackend {
+ val ACTOR_NAME = "YarnScheduler"
+}
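The new backend never blocks the driver actor directly: doRequestTotalExecutors and doKillExecutors go through the intermediary YarnSchedulerActor using AkkaUtils.askWithReply, which is essentially Akka's ask pattern plus a bounded wait. A minimal, self-contained sketch of that pattern with plain Akka 2.3-era actors follows; the Request/Responder names are illustrative stand-ins, not Spark classes.

```scala
import scala.concurrent.Await
import scala.concurrent.duration._

import akka.actor.{Actor, ActorSystem, Props}
import akka.pattern.ask
import akka.util.Timeout

// Hypothetical request type standing in for RequestExecutors / KillExecutors.
case class Request(total: Int)

// A stand-in for the ApplicationMaster side: replies with a Boolean acknowledgement.
class Responder extends Actor {
  def receive = {
    case Request(n) => sender ! (n >= 0)
  }
}

object AskPatternSketch extends App {
  val system = ActorSystem("sketch")
  val responder = system.actorOf(Props(new Responder), name = "responder")

  // AkkaUtils.askWithReply is roughly ask + Await (with retries); this is the bare pattern.
  implicit val timeout = Timeout(30.seconds)
  val granted = Await.result((responder ? Request(5)).mapTo[Boolean], timeout.duration)
  println(s"request granted: $granted")

  system.shutdown() // Akka 2.3-era API
}
```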
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index d7f88de4b40aa..0d1c2a916ca7f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -31,6 +31,7 @@ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTas
import org.apache.spark.{Logging, SparkContext, SparkEnv, SparkException}
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
+import org.apache.spark.util.{Utils, AkkaUtils}
/**
* A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds
@@ -92,7 +93,7 @@ private[spark] class CoarseMesosSchedulerBackend(
setDaemon(true)
override def run() {
val scheduler = CoarseMesosSchedulerBackend.this
- val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build()
+ val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build()
driver = new MesosSchedulerDriver(scheduler, fwInfo, master)
try { {
val ret = driver.run()
@@ -120,16 +121,18 @@ private[spark] class CoarseMesosSchedulerBackend(
environment.addVariables(
Environment.Variable.newBuilder().setName("SPARK_CLASSPATH").setValue(cp).build())
}
- val extraJavaOpts = conf.getOption("spark.executor.extraJavaOptions")
+ val extraJavaOpts = conf.get("spark.executor.extraJavaOptions", "")
- val libraryPathOption = "spark.executor.extraLibraryPath"
- val extraLibraryPath = conf.getOption(libraryPathOption).map(p => s"-Djava.library.path=$p")
- val extraOpts = Seq(extraJavaOpts, extraLibraryPath).flatten.mkString(" ")
+ // Set the environment variable through a command prefix
+ // to append to the existing value of the variable
+ val prefixEnv = conf.getOption("spark.executor.extraLibraryPath").map { p =>
+ Utils.libraryPathEnvPrefix(Seq(p))
+ }.getOrElse("")
environment.addVariables(
Environment.Variable.newBuilder()
.setName("SPARK_EXECUTOR_OPTS")
- .setValue(extraOpts)
+ .setValue(extraJavaOpts)
.build())
sc.executorEnvs.foreach { case (key, value) =>
@@ -140,7 +143,8 @@ private[spark] class CoarseMesosSchedulerBackend(
}
val command = CommandInfo.newBuilder()
.setEnvironment(environment)
- val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format(
+ val driverUrl = AkkaUtils.address(
+ AkkaUtils.protocol(sc.env.actorSystem),
SparkEnv.driverActorSystemName,
conf.get("spark.driver.host"),
conf.get("spark.driver.port"),
@@ -150,16 +154,17 @@ private[spark] class CoarseMesosSchedulerBackend(
if (uri == null) {
val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath
command.setValue(
- "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s".format(
- runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores, appId))
+ "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s".format(
+ prefixEnv, runScript, driverUrl, offer.getSlaveId.getValue,
+ offer.getHostname, numCores, appId))
} else {
// Grab everything to the first '.'. We'll use that and '*' to
// glob the directory "correctly".
val basename = uri.split('/').last.split('.').head
command.setValue(
- ("cd %s*; " +
+ ("cd %s*; %s " +
"./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s")
- .format(basename, driverUrl, offer.getSlaveId.getValue,
+ .format(basename, prefixEnv, driverUrl, offer.getSlaveId.getValue,
offer.getHostname, numCores, appId))
command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
}
@@ -238,8 +243,7 @@ private[spark] class CoarseMesosSchedulerBackend(
for (r <- res if r.getName == name) {
return r.getScalar.getValue
}
- // If we reached here, no resource with the required name was present
- throw new IllegalArgumentException("No resource called " + name + " in " + res)
+ 0
}
/** Build a Mesos resource protobuf object */
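With this change the extra library path is no longer baked into the executor JVM options as -Djava.library.path; instead a command prefix sets the native library path environment variable, so the executor appends to whatever value is already present. A simplified sketch of such a prefix follows; the exact string produced by Utils.libraryPathEnvPrefix is an assumption here, and the variable name differs per OS.

```scala
// A simplified stand-in for Utils.libraryPathEnvPrefix: build a shell prefix that
// prepends the given paths to the platform's library path variable while keeping
// whatever value is already set. The variable name below assumes Linux.
def libraryPathEnvPrefixSketch(libraryPaths: Seq[String]): String = {
  val envName = "LD_LIBRARY_PATH" // "DYLD_LIBRARY_PATH" on Mac OS X, "PATH" on Windows
  s"""$envName="${libraryPaths.mkString(":")}:$$$envName""""
}

// Prepended to the executor launch command, e.g.:
//   LD_LIBRARY_PATH="/opt/native:$LD_LIBRARY_PATH" "/path/bin/spark-class" ...
println(libraryPathEnvPrefixSketch(Seq("/opt/native")))
```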
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index e0f2fd622f54c..c3c546be6da15 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -22,14 +22,17 @@ import java.util.{ArrayList => JArrayList, List => JList}
import java.util.Collections
import scala.collection.JavaConversions._
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.collection.mutable.{HashMap, HashSet}
import org.apache.mesos.protobuf.ByteString
import org.apache.mesos.{Scheduler => MScheduler}
import org.apache.mesos._
-import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
+import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState,
+ ExecutorInfo => MesosExecutorInfo, _}
+import org.apache.spark.executor.MesosExecutorBackend
import org.apache.spark.{Logging, SparkContext, SparkException, TaskState}
+import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.scheduler._
import org.apache.spark.util.Utils
@@ -62,6 +65,9 @@ private[spark] class MesosSchedulerBackend(
var classLoader: ClassLoader = null
+ // The listener bus to publish executor added/removed events.
+ val listenerBus = sc.listenerBus
+
@volatile var appId: String = _
override def start() {
@@ -72,7 +78,7 @@ private[spark] class MesosSchedulerBackend(
setDaemon(true)
override def run() {
val scheduler = MesosSchedulerBackend.this
- val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build()
+ val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build()
driver = new MesosSchedulerDriver(scheduler, fwInfo, master)
try {
val ret = driver.run()
@@ -87,7 +93,7 @@ private[spark] class MesosSchedulerBackend(
}
}
- def createExecutorInfo(execId: String): ExecutorInfo = {
+ def createExecutorInfo(execId: String): MesosExecutorInfo = {
val executorSparkHome = sc.conf.getOption("spark.mesos.executor.home")
.orElse(sc.getSparkHome()) // Fall back to driver Spark home for backward compatibility
.getOrElse {
@@ -98,15 +104,16 @@ private[spark] class MesosSchedulerBackend(
environment.addVariables(
Environment.Variable.newBuilder().setName("SPARK_CLASSPATH").setValue(cp).build())
}
- val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions")
- val extraLibraryPath = sc.conf.getOption("spark.executor.extraLibraryPath").map { lp =>
- s"-Djava.library.path=$lp"
- }
- val extraOpts = Seq(extraJavaOpts, extraLibraryPath).flatten.mkString(" ")
+ val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").getOrElse("")
+
+ val prefixEnv = sc.conf.getOption("spark.executor.extraLibraryPath").map { p =>
+ Utils.libraryPathEnvPrefix(Seq(p))
+ }.getOrElse("")
+
environment.addVariables(
Environment.Variable.newBuilder()
.setName("SPARK_EXECUTOR_OPTS")
- .setValue(extraOpts)
+ .setValue(extraJavaOpts)
.build())
sc.executorEnvs.foreach { case (key, value) =>
environment.addVariables(Environment.Variable.newBuilder()
@@ -117,13 +124,15 @@ private[spark] class MesosSchedulerBackend(
val command = CommandInfo.newBuilder()
.setEnvironment(environment)
val uri = sc.conf.get("spark.executor.uri", null)
+ val executorBackendName = classOf[MesosExecutorBackend].getName
if (uri == null) {
- command.setValue(new File(executorSparkHome, "/sbin/spark-executor").getCanonicalPath)
+ val executorPath = new File(executorSparkHome, "/bin/spark-class").getCanonicalPath
+ command.setValue(s"$prefixEnv $executorPath $executorBackendName")
} else {
// Grab everything to the first '.'. We'll use that and '*' to
// glob the directory "correctly".
val basename = uri.split('/').last.split('.').head
- command.setValue("cd %s*; ./sbin/spark-executor".format(basename))
+ command.setValue(s"cd ${basename}*; $prefixEnv ./bin/spark-class $executorBackendName")
command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
}
val cpus = Resource.newBuilder()
@@ -139,7 +148,7 @@ private[spark] class MesosSchedulerBackend(
Value.Scalar.newBuilder()
.setValue(MemoryUtils.calculateTotalMemory(sc)).build())
.build()
- ExecutorInfo.newBuilder()
+ MesosExecutorInfo.newBuilder()
.setExecutorId(ExecutorID.newBuilder().setValue(execId).build())
.setCommand(command)
.setData(ByteString.copyFrom(createExecArg()))
@@ -164,29 +173,16 @@ private[spark] class MesosSchedulerBackend(
execArgs
}
- private def setClassLoader(): ClassLoader = {
- val oldClassLoader = Thread.currentThread.getContextClassLoader
- Thread.currentThread.setContextClassLoader(classLoader)
- oldClassLoader
- }
-
- private def restoreClassLoader(oldClassLoader: ClassLoader) {
- Thread.currentThread.setContextClassLoader(oldClassLoader)
- }
-
override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) {
- val oldClassLoader = setClassLoader()
- try {
+ inClassLoader() {
appId = frameworkId.getValue
logInfo("Registered as framework ID " + appId)
registeredLock.synchronized {
isRegistered = true
registeredLock.notifyAll()
}
- } finally {
- restoreClassLoader(oldClassLoader)
}
}
@@ -198,6 +194,16 @@ private[spark] class MesosSchedulerBackend(
}
}
+ private def inClassLoader()(fun: => Unit) = {
+ val oldClassLoader = Thread.currentThread.getContextClassLoader
+ Thread.currentThread.setContextClassLoader(classLoader)
+ try {
+ fun
+ } finally {
+ Thread.currentThread.setContextClassLoader(oldClassLoader)
+ }
+ }
+
override def disconnected(d: SchedulerDriver) {}
override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {}
@@ -208,66 +214,75 @@ private[spark] class MesosSchedulerBackend(
* tasks are balanced across the cluster.
*/
override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) {
- val oldClassLoader = setClassLoader()
- try {
- synchronized {
- // Build a big list of the offerable workers, and remember their indices so that we can
- // figure out which Offer to reply to for each worker
- val offerableWorkers = new ArrayBuffer[WorkerOffer]
- val offerableIndices = new HashMap[String, Int]
-
- def sufficientOffer(o: Offer) = {
- val mem = getResource(o.getResourcesList, "mem")
- val cpus = getResource(o.getResourcesList, "cpus")
- val slaveId = o.getSlaveId.getValue
- (mem >= MemoryUtils.calculateTotalMemory(sc) &&
- // need at least 1 for executor, 1 for task
- cpus >= 2 * scheduler.CPUS_PER_TASK) ||
- (slaveIdsWithExecutors.contains(slaveId) &&
- cpus >= scheduler.CPUS_PER_TASK)
- }
+ inClassLoader() {
+ // Fail-fast on offers we know will be rejected
+ val (usableOffers, unUsableOffers) = offers.partition { o =>
+ val mem = getResource(o.getResourcesList, "mem")
+ val cpus = getResource(o.getResourcesList, "cpus")
+ val slaveId = o.getSlaveId.getValue
+ // TODO(pwendell): Should below be 1 + scheduler.CPUS_PER_TASK?
+ (mem >= MemoryUtils.calculateTotalMemory(sc) &&
+ // need at least 1 for executor, 1 for task
+ cpus >= 2 * scheduler.CPUS_PER_TASK) ||
+ (slaveIdsWithExecutors.contains(slaveId) &&
+ cpus >= scheduler.CPUS_PER_TASK)
+ }
- for ((offer, index) <- offers.zipWithIndex if sufficientOffer(offer)) {
- val slaveId = offer.getSlaveId.getValue
- offerableIndices.put(slaveId, index)
- val cpus = if (slaveIdsWithExecutors.contains(slaveId)) {
- getResource(offer.getResourcesList, "cpus").toInt
- } else {
- // If the executor doesn't exist yet, subtract CPU for executor
- getResource(offer.getResourcesList, "cpus").toInt -
- scheduler.CPUS_PER_TASK
- }
- offerableWorkers += new WorkerOffer(
- offer.getSlaveId.getValue,
- offer.getHostname,
- cpus)
+ val workerOffers = usableOffers.map { o =>
+ val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) {
+ getResource(o.getResourcesList, "cpus").toInt
+ } else {
+ // If the executor doesn't exist yet, subtract CPU for executor
+ // TODO(pwendell): Should below just subtract "1"?
+ getResource(o.getResourcesList, "cpus").toInt -
+ scheduler.CPUS_PER_TASK
}
+ new WorkerOffer(
+ o.getSlaveId.getValue,
+ o.getHostname,
+ cpus)
+ }
+
+ val slaveIdToOffer = usableOffers.map(o => o.getSlaveId.getValue -> o).toMap
+ val slaveIdToWorkerOffer = workerOffers.map(o => o.executorId -> o).toMap
+
+ val mesosTasks = new HashMap[String, JArrayList[MesosTaskInfo]]
- // Call into the TaskSchedulerImpl
- val taskLists = scheduler.resourceOffers(offerableWorkers)
-
- // Build a list of Mesos tasks for each slave
- val mesosTasks = offers.map(o => new JArrayList[MesosTaskInfo]())
- for ((taskList, index) <- taskLists.zipWithIndex) {
- if (!taskList.isEmpty) {
- for (taskDesc <- taskList) {
- val slaveId = taskDesc.executorId
- val offerNum = offerableIndices(slaveId)
- slaveIdsWithExecutors += slaveId
- taskIdToSlaveId(taskDesc.taskId) = slaveId
- mesosTasks(offerNum).add(createMesosTask(taskDesc, slaveId))
- }
+ val slavesIdsOfAcceptedOffers = HashSet[String]()
+
+ // Call into the TaskSchedulerImpl
+ val acceptedOffers = scheduler.resourceOffers(workerOffers).filter(!_.isEmpty)
+ acceptedOffers
+ .foreach { offer =>
+ offer.foreach { taskDesc =>
+ val slaveId = taskDesc.executorId
+ slaveIdsWithExecutors += slaveId
+ slavesIdsOfAcceptedOffers += slaveId
+ taskIdToSlaveId(taskDesc.taskId) = slaveId
+ mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo])
+ .add(createMesosTask(taskDesc, slaveId))
}
}
- // Reply to the offers
- val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout?
- for (i <- 0 until offers.size) {
- d.launchTasks(Collections.singleton(offers(i).getId), mesosTasks(i), filters)
- }
+ // Reply to the offers
+ val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout?
+
+ mesosTasks.foreach { case (slaveId, tasks) =>
+ slaveIdToWorkerOffer.get(slaveId).foreach(o =>
+ listenerBus.post(SparkListenerExecutorAdded(System.currentTimeMillis(), slaveId,
+ new ExecutorInfo(o.host, o.cores)))
+ )
+ d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters)
}
- } finally {
- restoreClassLoader(oldClassLoader)
+
+ // Decline offers that weren't used
+ // NOTE: This logic assumes that we only get a single offer for each host in a given batch
+ for (o <- usableOffers if !slavesIdsOfAcceptedOffers.contains(o.getSlaveId.getValue)) {
+ d.declineOffer(o.getId)
+ }
+
+ // Decline offers we ruled out immediately
+ unUsableOffers.foreach(o => d.declineOffer(o.getId))
}
}
@@ -276,8 +291,7 @@ private[spark] class MesosSchedulerBackend(
for (r <- res if r.getName == name) {
return r.getScalar.getValue
}
- // If we reached here, no resource with the required name was present
- throw new IllegalArgumentException("No resource called " + name + " in " + res)
+ 0
}
/** Turn a Spark TaskDescription into a Mesos task */
@@ -294,7 +308,7 @@ private[spark] class MesosSchedulerBackend(
.setExecutor(createExecutorInfo(slaveId))
.setName(task.name)
.addResources(cpuResource)
- .setData(ByteString.copyFrom(task.serializedTask))
+ .setData(MesosTaskLaunchData(task.serializedTask, task.attemptNumber).toByteString)
.build()
}
@@ -307,32 +321,26 @@ private[spark] class MesosSchedulerBackend(
}
override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
- val oldClassLoader = setClassLoader()
- try {
+ inClassLoader() {
val tid = status.getTaskId.getValue.toLong
val state = TaskState.fromMesos(status.getState)
synchronized {
if (status.getState == MesosTaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) {
// We lost the executor on this slave, so remember that it's gone
- slaveIdsWithExecutors -= taskIdToSlaveId(tid)
+ removeExecutor(taskIdToSlaveId(tid), "Lost executor")
}
if (isFinished(status.getState)) {
taskIdToSlaveId.remove(tid)
}
}
scheduler.statusUpdate(tid, state, status.getData.asReadOnlyByteBuffer)
- } finally {
- restoreClassLoader(oldClassLoader)
}
}
override def error(d: SchedulerDriver, message: String) {
- val oldClassLoader = setClassLoader()
- try {
+ inClassLoader() {
logError("Mesos error: " + message)
scheduler.error(message)
- } finally {
- restoreClassLoader(oldClassLoader)
}
}
@@ -348,16 +356,21 @@ private[spark] class MesosSchedulerBackend(
override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {}
+ /**
+ * Remove executor associated with slaveId in a thread safe manner.
+ */
+ private def removeExecutor(slaveId: String, reason: String) = {
+ synchronized {
+ listenerBus.post(SparkListenerExecutorRemoved(System.currentTimeMillis(), slaveId, reason))
+ slaveIdsWithExecutors -= slaveId
+ }
+ }
+
private def recordSlaveLost(d: SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) {
- val oldClassLoader = setClassLoader()
- try {
+ inClassLoader() {
logInfo("Mesos slave lost: " + slaveId.getValue)
- synchronized {
- slaveIdsWithExecutors -= slaveId.getValue
- }
+ removeExecutor(slaveId.getValue, reason.toString)
scheduler.executorLost(slaveId.getValue, reason)
- } finally {
- restoreClassLoader(oldClassLoader)
}
}
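resourceOffers now fails fast: offers are partitioned up front into usable and unusable, tasks are launched only against accepted offers, and everything else is declined immediately. The self-contained sketch below mirrors that partition step with plain case classes; OfferSketch and the thresholds are stand-ins for the Mesos protobuf types and MemoryUtils.calculateTotalMemory.

```scala
// Stand-in for a Mesos offer: just the fields the filter looks at.
case class OfferSketch(slaveId: String, memMB: Double, cpus: Double)

val cpusPerTask = 1
val requiredMemMB = 1024.0
val slaveIdsWithExecutors = Set("slave-1")

val offers = Seq(
  OfferSketch("slave-1", 512.0, 1.0),   // existing executor, enough CPU for one task
  OfferSketch("slave-2", 2048.0, 4.0),  // fresh slave with room for executor + task
  OfferSketch("slave-3", 256.0, 1.0))   // too small on both counts

// Mirrors the fail-fast partition in resourceOffers: a fresh slave needs memory for the
// executor plus CPUs for executor and task; a slave that already has an executor only
// needs CPUs for the task.
val (usable, unusable) = offers.partition { o =>
  (o.memMB >= requiredMemMB && o.cpus >= 2 * cpusPerTask) ||
    (slaveIdsWithExecutors.contains(o.slaveId) && o.cpus >= cpusPerTask)
}

println(s"usable: ${usable.map(_.slaveId)}, declined immediately: ${unusable.map(_.slaveId)}")
// usable: List(slave-1, slave-2), declined immediately: List(slave-3)
```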
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala
new file mode 100644
index 0000000000000..5e7e6567a3e06
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster.mesos
+
+import java.nio.ByteBuffer
+
+import org.apache.mesos.protobuf.ByteString
+
+import org.apache.spark.Logging
+
+/**
+ * Wrapper for serializing the data sent when launching Mesos tasks.
+ */
+private[spark] case class MesosTaskLaunchData(
+ serializedTask: ByteBuffer,
+ attemptNumber: Int) extends Logging {
+
+ def toByteString: ByteString = {
+ val dataBuffer = ByteBuffer.allocate(4 + serializedTask.limit)
+ dataBuffer.putInt(attemptNumber)
+ dataBuffer.put(serializedTask)
+ dataBuffer.rewind
+ logDebug(s"ByteBuffer size: [${dataBuffer.remaining}]")
+ ByteString.copyFrom(dataBuffer)
+ }
+}
+
+private[spark] object MesosTaskLaunchData extends Logging {
+ def fromByteString(byteString: ByteString): MesosTaskLaunchData = {
+ val byteBuffer = byteString.asReadOnlyByteBuffer()
+ logDebug(s"ByteBuffer size: [${byteBuffer.remaining}]")
+ val attemptNumber = byteBuffer.getInt // updates the position by 4 bytes
+ val serializedTask = byteBuffer.slice() // subsequence starting at the current position
+ MesosTaskLaunchData(serializedTask, attemptNumber)
+ }
+}
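MesosTaskLaunchData frames the payload as a 4-byte attempt number followed by the serialized task bytes. A quick round-trip shows the encode/decode symmetry; since the class is private[spark], a snippet like this would live inside Spark's own code or tests.

```scala
import java.nio.ByteBuffer

import org.apache.spark.scheduler.cluster.mesos.MesosTaskLaunchData

// Encode: 4-byte attempt number followed by the task bytes.
val task = ByteBuffer.wrap(Array[Byte](1, 2, 3))
val byteString = MesosTaskLaunchData(task, attemptNumber = 7).toByteString

// Decode on the executor side: read the attempt number, then slice out the task bytes.
val decoded = MesosTaskLaunchData.fromByteString(byteString)
assert(decoded.attemptNumber == 7)
assert(decoded.serializedTask.remaining == 3)
```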
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
index 58b78f041cd85..4676b828d3d89 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -19,9 +19,11 @@ package org.apache.spark.scheduler.local
import java.nio.ByteBuffer
+import scala.concurrent.duration._
+
import akka.actor.{Actor, ActorRef, Props}
-import org.apache.spark.{Logging, SparkEnv, TaskState}
+import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.executor.{Executor, ExecutorBackend}
import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer}
@@ -41,17 +43,20 @@ private case class StopExecutor()
* and the TaskSchedulerImpl.
*/
private[spark] class LocalActor(
- scheduler: TaskSchedulerImpl,
- executorBackend: LocalBackend,
- private val totalCores: Int) extends Actor with ActorLogReceive with Logging {
+ scheduler: TaskSchedulerImpl,
+ executorBackend: LocalBackend,
+ private val totalCores: Int)
+ extends Actor with ActorLogReceive with Logging {
+
+ import context.dispatcher // to use Akka's scheduler.scheduleOnce()
private var freeCores = totalCores
- private val localExecutorId = "localhost"
+ private val localExecutorId = SparkContext.DRIVER_IDENTIFIER
private val localExecutorHostname = "localhost"
- val executor = new Executor(
- localExecutorId, localExecutorHostname, scheduler.conf.getAll, isLocal = true)
+ private val executor = new Executor(
+ localExecutorId, localExecutorHostname, SparkEnv.get, isLocal = true)
override def receiveWithLogging = {
case ReviveOffers =>
@@ -73,9 +78,15 @@ private[spark] class LocalActor(
def reviveOffers() {
val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores))
- for (task <- scheduler.resourceOffers(offers).flatten) {
+ val tasks = scheduler.resourceOffers(offers).flatten
+ for (task <- tasks) {
freeCores -= scheduler.CPUS_PER_TASK
- executor.launchTask(executorBackend, task.taskId, task.name, task.serializedTask)
+ executor.launchTask(executorBackend, taskId = task.taskId, attemptNumber = task.attemptNumber,
+ task.name, task.serializedTask)
+ }
+ if (tasks.isEmpty && scheduler.activeTaskSets.nonEmpty) {
+ // Try to reviveOffer after 1 second, because scheduler may wait for locality timeout
+ context.system.scheduler.scheduleOnce(1000 millis, self, ReviveOffers)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index 554a33ce7f1a6..1baa0e009f3ae 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -27,7 +27,8 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.ByteBufferInputStream
import org.apache.spark.util.Utils
-private[spark] class JavaSerializationStream(out: OutputStream, counterReset: Int)
+private[spark] class JavaSerializationStream(
+ out: OutputStream, counterReset: Int, extraDebugInfo: Boolean)
extends SerializationStream {
private val objOut = new ObjectOutputStream(out)
private var counter = 0
@@ -39,7 +40,12 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In
* the stream 'resets' object class descriptions have to be re-written)
*/
def writeObject[T: ClassTag](t: T): SerializationStream = {
- objOut.writeObject(t)
+ try {
+ objOut.writeObject(t)
+ } catch {
+ case e: NotSerializableException if extraDebugInfo =>
+ throw SerializationDebugger.improveException(t, e)
+ }
counter += 1
if (counterReset > 0 && counter >= counterReset) {
objOut.reset()
@@ -64,7 +70,8 @@ extends DeserializationStream {
}
-private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoader: ClassLoader)
+private[spark] class JavaSerializerInstance(
+ counterReset: Int, extraDebugInfo: Boolean, defaultClassLoader: ClassLoader)
extends SerializerInstance {
override def serialize[T: ClassTag](t: T): ByteBuffer = {
@@ -88,11 +95,11 @@ private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoade
}
override def serializeStream(s: OutputStream): SerializationStream = {
- new JavaSerializationStream(s, counterReset)
+ new JavaSerializationStream(s, counterReset, extraDebugInfo)
}
override def deserializeStream(s: InputStream): DeserializationStream = {
- new JavaDeserializationStream(s, Utils.getContextOrSparkClassLoader)
+ new JavaDeserializationStream(s, defaultClassLoader)
}
def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = {
@@ -111,17 +118,20 @@ private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoade
@DeveloperApi
class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable {
private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100)
+ private var extraDebugInfo = conf.getBoolean("spark.serializer.extraDebugInfo", true)
override def newInstance(): SerializerInstance = {
val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
- new JavaSerializerInstance(counterReset, classLoader)
+ new JavaSerializerInstance(counterReset, extraDebugInfo, classLoader)
}
- override def writeExternal(out: ObjectOutput) {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
out.writeInt(counterReset)
+ out.writeBoolean(extraDebugInfo)
}
- override def readExternal(in: ObjectInput) {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
counterReset = in.readInt()
+ extraDebugInfo = in.readBoolean()
}
}
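With extraDebugInfo enabled (the default), a NotSerializableException thrown during Java serialization is rewrapped with the serialization path computed by SerializationDebugger. A small usage sketch of the new setting; the configuration key is taken from the diff above.

```scala
import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer

// The richer error message is on by default; it can be disabled if the JVM flag
// -Dsun.io.serialization.extendedDebugInfo=true is preferred instead.
val conf = new SparkConf()
  .set("spark.serializer", "org.apache.spark.serializer.JavaSerializer")
  .set("spark.serializer.extraDebugInfo", "true")

val serializer = new JavaSerializer(conf).newInstance()

// Serializing a non-serializable object now fails with a message that includes
// the serialization stack leading to the offending value.
class NotSerializable
try {
  serializer.serialize(new NotSerializable)
} catch {
  case e: java.io.NotSerializableException => println(e.getMessage)
}
```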
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 621a951c27d07..d56e23ce4478a 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -26,9 +26,10 @@ import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializ
import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator}
import org.apache.spark._
+import org.apache.spark.api.python.PythonBroadcast
import org.apache.spark.broadcast.HttpBroadcast
import org.apache.spark.network.nio.{PutBlock, GotBlock, GetBlock}
-import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus}
import org.apache.spark.storage._
import org.apache.spark.util.BoundedPriorityQueue
import org.apache.spark.util.collection.CompactBuffer
@@ -90,6 +91,7 @@ class KryoSerializer(conf: SparkConf)
// Allow sending SerializableWritable
kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer())
kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer())
+ kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer())
try {
// Use the default classloader when calling the user registrator.
@@ -205,7 +207,8 @@ private[serializer] object KryoSerializer {
classOf[PutBlock],
classOf[GotBlock],
classOf[GetBlock],
- classOf[MapStatus],
+ classOf[CompressedMapStatus],
+ classOf[HighlyCompressedMapStatus],
classOf[CompactBuffer[_]],
classOf[BlockManagerId],
classOf[Array[Byte]],
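MapStatus is no longer registered as a single concrete class; the list now names its two concrete implementations. Kryo can only instantiate concrete classes on read and writes a compact id for registered classes, which is the point of keeping this list precise. The sketch below registers ordinary classes, since the Spark ones are private to the project.

```scala
import com.esotericsoftware.kryo.Kryo

// Registering a class lets Kryo write a small integer id instead of the fully qualified
// class name, and only concrete classes can be constructed on deserialization — hence
// CompressedMapStatus / HighlyCompressedMapStatus rather than the MapStatus trait.
val kryo = new Kryo()
kryo.register(classOf[Array[Byte]])
kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])
```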
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
new file mode 100644
index 0000000000000..cecb992579655
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
@@ -0,0 +1,307 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import java.io.{NotSerializableException, ObjectOutput, ObjectStreamClass, ObjectStreamField}
+import java.lang.reflect.{Field, Method}
+import java.security.AccessController
+
+import scala.annotation.tailrec
+import scala.collection.mutable
+
+import org.apache.spark.Logging
+
+private[serializer] object SerializationDebugger extends Logging {
+
+ /**
+ * Improve the given NotSerializableException with the serialization path leading from the given
+ * object to the problematic object. This is turned off automatically if the
+ * `sun.io.serialization.extendedDebugInfo` flag is turned on for the JVM.
+ */
+ def improveException(obj: Any, e: NotSerializableException): NotSerializableException = {
+ if (enableDebugging && reflect != null) {
+ new NotSerializableException(
+ e.getMessage + "\nSerialization stack:\n" + find(obj).map("\t- " + _).mkString("\n"))
+ } else {
+ e
+ }
+ }
+
+ /**
+ * Find the path leading to a not serializable object. This method is modeled after OpenJDK's
+ * serialization mechanism, and handles the following cases:
+ * - primitives
+ * - arrays of primitives
+ * - arrays of non-primitive objects
+ * - Serializable objects
+ * - Externalizable objects
+ * - writeReplace
+ *
+ * It does not yet handle writeObject override, but that shouldn't be too hard to do either.
+ */
+ def find(obj: Any): List[String] = {
+ new SerializationDebugger().visit(obj, List.empty)
+ }
+
+ private[serializer] var enableDebugging: Boolean = {
+ !AccessController.doPrivileged(new sun.security.action.GetBooleanAction(
+ "sun.io.serialization.extendedDebugInfo")).booleanValue()
+ }
+
+ private class SerializationDebugger {
+
+ /** A set to track the list of objects we have visited, to avoid cycles in the graph. */
+ private val visited = new mutable.HashSet[Any]
+
+ /**
+ * Visit the object and its fields and stop when we find an object that is not serializable.
+ * Return the path as a list. If everything can be serialized, return an empty list.
+ */
+ def visit(o: Any, stack: List[String]): List[String] = {
+ if (o == null) {
+ List.empty
+ } else if (visited.contains(o)) {
+ List.empty
+ } else {
+ visited += o
+ o match {
+ // Primitive value, string, and primitive arrays are always serializable
+ case _ if o.getClass.isPrimitive => List.empty
+ case _: String => List.empty
+ case _ if o.getClass.isArray && o.getClass.getComponentType.isPrimitive => List.empty
+
+ // Traverse non primitive array.
+ case a: Array[_] if o.getClass.isArray && !o.getClass.getComponentType.isPrimitive =>
+ val elem = s"array (class ${a.getClass.getName}, size ${a.length})"
+ visitArray(o.asInstanceOf[Array[_]], elem :: stack)
+
+ case e: java.io.Externalizable =>
+ val elem = s"externalizable object (class ${e.getClass.getName}, $e)"
+ visitExternalizable(e, elem :: stack)
+
+ case s: Object with java.io.Serializable =>
+ val elem = s"object (class ${s.getClass.getName}, $s)"
+ visitSerializable(s, elem :: stack)
+
+ case _ =>
+ // Found an object that is not serializable!
+ s"object not serializable (class: ${o.getClass.getName}, value: $o)" :: stack
+ }
+ }
+ }
+
+ private def visitArray(o: Array[_], stack: List[String]): List[String] = {
+ var i = 0
+ while (i < o.length) {
+ val childStack = visit(o(i), s"element of array (index: $i)" :: stack)
+ if (childStack.nonEmpty) {
+ return childStack
+ }
+ i += 1
+ }
+ return List.empty
+ }
+
+ private def visitExternalizable(o: java.io.Externalizable, stack: List[String]): List[String] =
+ {
+ val fieldList = new ListObjectOutput
+ o.writeExternal(fieldList)
+ val childObjects = fieldList.outputArray
+ var i = 0
+ while (i < childObjects.length) {
+ val childStack = visit(childObjects(i), "writeExternal data" :: stack)
+ if (childStack.nonEmpty) {
+ return childStack
+ }
+ i += 1
+ }
+ return List.empty
+ }
+
+ private def visitSerializable(o: Object, stack: List[String]): List[String] = {
+ // An object contains multiple slots in serialization.
+ // Get the slots and visit fields in all of them.
+ val (finalObj, desc) = findObjectAndDescriptor(o)
+ val slotDescs = desc.getSlotDescs
+ var i = 0
+ while (i < slotDescs.length) {
+ val slotDesc = slotDescs(i)
+ if (slotDesc.hasWriteObjectMethod) {
+ // TODO: Handle classes that specify writeObject method.
+ } else {
+ val fields: Array[ObjectStreamField] = slotDesc.getFields
+ val objFieldValues: Array[Object] = new Array[Object](slotDesc.getNumObjFields)
+ val numPrims = fields.length - objFieldValues.length
+ desc.getObjFieldValues(finalObj, objFieldValues)
+
+ var j = 0
+ while (j < objFieldValues.length) {
+ val fieldDesc = fields(numPrims + j)
+ val elem = s"field (class: ${slotDesc.getName}" +
+ s", name: ${fieldDesc.getName}" +
+ s", type: ${fieldDesc.getType})"
+ val childStack = visit(objFieldValues(j), elem :: stack)
+ if (childStack.nonEmpty) {
+ return childStack
+ }
+ j += 1
+ }
+
+ }
+ i += 1
+ }
+ return List.empty
+ }
+ }
+
+ /**
+ * Find the object to serialize and the associated [[ObjectStreamClass]]. This method handles
+ * writeReplace in Serializable. It starts with the object itself, and keeps calling the
+ * writeReplace method until there is no further writeReplace to apply.
+ */
+ @tailrec
+ private def findObjectAndDescriptor(o: Object): (Object, ObjectStreamClass) = {
+ val cl = o.getClass
+ val desc = ObjectStreamClass.lookupAny(cl)
+ if (!desc.hasWriteReplaceMethod) {
+ (o, desc)
+ } else {
+ // Follow the writeReplace substitution.
+ findObjectAndDescriptor(desc.invokeWriteReplace(o))
+ }
+ }
+
+ /**
+ * A dummy [[ObjectOutput]] that simply saves the list of objects written by a writeExternal
+ * call, and returns them through `outputArray`.
+ */
+ private class ListObjectOutput extends ObjectOutput {
+ private val output = new mutable.ArrayBuffer[Any]
+ def outputArray: Array[Any] = output.toArray
+ override def writeObject(o: Any): Unit = output += o
+ override def flush(): Unit = {}
+ override def write(i: Int): Unit = {}
+ override def write(bytes: Array[Byte]): Unit = {}
+ override def write(bytes: Array[Byte], i: Int, i1: Int): Unit = {}
+ override def close(): Unit = {}
+ override def writeFloat(v: Float): Unit = {}
+ override def writeChars(s: String): Unit = {}
+ override def writeDouble(v: Double): Unit = {}
+ override def writeUTF(s: String): Unit = {}
+ override def writeShort(i: Int): Unit = {}
+ override def writeInt(i: Int): Unit = {}
+ override def writeBoolean(b: Boolean): Unit = {}
+ override def writeBytes(s: String): Unit = {}
+ override def writeChar(i: Int): Unit = {}
+ override def writeLong(l: Long): Unit = {}
+ override def writeByte(i: Int): Unit = {}
+ }
+
+ /** An implicit class that allows us to call private methods of ObjectStreamClass. */
+ implicit class ObjectStreamClassMethods(val desc: ObjectStreamClass) extends AnyVal {
+ def getSlotDescs: Array[ObjectStreamClass] = {
+ reflect.GetClassDataLayout.invoke(desc).asInstanceOf[Array[Object]].map {
+ classDataSlot => reflect.DescField.get(classDataSlot).asInstanceOf[ObjectStreamClass]
+ }
+ }
+
+ def hasWriteObjectMethod: Boolean = {
+ reflect.HasWriteObjectMethod.invoke(desc).asInstanceOf[Boolean]
+ }
+
+ def hasWriteReplaceMethod: Boolean = {
+ reflect.HasWriteReplaceMethod.invoke(desc).asInstanceOf[Boolean]
+ }
+
+ def invokeWriteReplace(obj: Object): Object = {
+ reflect.InvokeWriteReplace.invoke(desc, obj)
+ }
+
+ def getNumObjFields: Int = {
+ reflect.GetNumObjFields.invoke(desc).asInstanceOf[Int]
+ }
+
+ def getObjFieldValues(obj: Object, out: Array[Object]): Unit = {
+ reflect.GetObjFieldValues.invoke(desc, obj, out)
+ }
+ }
+
+ /**
+ * Object to hold all the reflection objects. If we run on a JVM that we cannot understand,
+ * this field will be null and the debug helper should be disabled.
+ */
+ private val reflect: ObjectStreamClassReflection = try {
+ new ObjectStreamClassReflection
+ } catch {
+ case e: Exception =>
+ logWarning("Cannot find private methods using reflection", e)
+ null
+ }
+
+ private class ObjectStreamClassReflection {
+ /** ObjectStreamClass.getClassDataLayout */
+ val GetClassDataLayout: Method = {
+ val f = classOf[ObjectStreamClass].getDeclaredMethod("getClassDataLayout")
+ f.setAccessible(true)
+ f
+ }
+
+ /** ObjectStreamClass.hasWriteObjectMethod */
+ val HasWriteObjectMethod: Method = {
+ val f = classOf[ObjectStreamClass].getDeclaredMethod("hasWriteObjectMethod")
+ f.setAccessible(true)
+ f
+ }
+
+ /** ObjectStreamClass.hasWriteReplaceMethod */
+ val HasWriteReplaceMethod: Method = {
+ val f = classOf[ObjectStreamClass].getDeclaredMethod("hasWriteReplaceMethod")
+ f.setAccessible(true)
+ f
+ }
+
+ /** ObjectStreamClass.invokeWriteReplace */
+ val InvokeWriteReplace: Method = {
+ val f = classOf[ObjectStreamClass].getDeclaredMethod("invokeWriteReplace", classOf[Object])
+ f.setAccessible(true)
+ f
+ }
+
+ /** ObjectStreamClass.getNumObjFields */
+ val GetNumObjFields: Method = {
+ val f = classOf[ObjectStreamClass].getDeclaredMethod("getNumObjFields")
+ f.setAccessible(true)
+ f
+ }
+
+ /** ObjectStreamClass.getObjFieldValues */
+ val GetObjFieldValues: Method = {
+ val f = classOf[ObjectStreamClass].getDeclaredMethod(
+ "getObjFieldValues", classOf[Object], classOf[Array[Object]])
+ f.setAccessible(true)
+ f
+ }
+
+ /** ObjectStreamClass$ClassDataSlot.desc field */
+ val DescField: Field = {
+ val f = Class.forName("java.io.ObjectStreamClass$ClassDataSlot").getDeclaredField("desc")
+ f.setAccessible(true)
+ f
+ }
+ }
+}
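SerializationDebugger.find walks an object graph the way ObjectOutputStream would and returns the path from the root down to the first value that cannot be serialized; improveException folds that path into the exception message. Because the object is private[serializer], a call like the one below would sit inside the org.apache.spark.serializer package; the sample classes are illustrative.

```scala
// Inside org.apache.spark.serializer (the object is private[serializer]):
class Unserializable                     // no Serializable marker
case class Wrapper(bad: Unserializable)  // case classes are Serializable by default

// Lists each hop from the offending value back up to the root object.
val path: List[String] = SerializationDebugger.find(Wrapper(new Unserializable))
path.foreach(println)
// e.g.
//   object not serializable (class: Unserializable, value: ...)
//   field (class: Wrapper, name: bad, type: class Unserializable)
//   object (class Wrapper, Wrapper(...))
```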
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index a9144cdd97b8c..ca6e971d227fb 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -17,14 +17,14 @@
package org.apache.spark.serializer
-import java.io.{ByteArrayOutputStream, EOFException, InputStream, OutputStream}
+import java.io._
import java.nio.ByteBuffer
import scala.reflect.ClassTag
-import org.apache.spark.SparkEnv
+import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.util.{ByteBufferInputStream, NextIterator}
+import org.apache.spark.util.{Utils, ByteBufferInputStream, NextIterator}
/**
* :: DeveloperApi ::
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
index 71c08e9d5a8c3..be184464e0ae9 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
@@ -19,6 +19,7 @@ package org.apache.spark.shuffle
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.{FetchFailed, TaskEndReason}
+import org.apache.spark.util.Utils
/**
* Failed to fetch a shuffle block. The executor catches this exception and propagates it
@@ -30,13 +31,22 @@ private[spark] class FetchFailedException(
bmAddress: BlockManagerId,
shuffleId: Int,
mapId: Int,
- reduceId: Int)
- extends Exception {
+ reduceId: Int,
+ message: String,
+ cause: Throwable = null)
+ extends Exception(message, cause) {
- override def getMessage: String =
- "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId)
+ def this(
+ bmAddress: BlockManagerId,
+ shuffleId: Int,
+ mapId: Int,
+ reduceId: Int,
+ cause: Throwable) {
+ this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause)
+ }
- def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId)
+ def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId,
+ Utils.exceptionString(this))
}
/**
@@ -46,7 +56,4 @@ private[spark] class MetadataFetchFailedException(
shuffleId: Int,
reduceId: Int,
message: String)
- extends FetchFailedException(null, shuffleId, -1, reduceId) {
-
- override def getMessage: String = message
-}
+ extends FetchFailedException(null, shuffleId, -1, reduceId, message)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
index 439981d232349..7de2f9cbb2866 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
@@ -24,9 +24,10 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.JavaConversions._
-import org.apache.spark.{SparkEnv, SparkConf, Logging}
+import org.apache.spark.{Logging, SparkConf, SparkEnv}
import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup
import org.apache.spark.storage._
@@ -62,11 +63,14 @@ private[spark] trait ShuffleWriterGroup {
* each block stored in each file. In order to find the location of a shuffle block, we search the
* files within a ShuffleFileGroups associated with the block's reducer.
*/
-
+// Note: Changes to the format in this file should be kept in sync with
+// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getHashBasedShuffleBlockData().
private[spark]
class FileShuffleBlockManager(conf: SparkConf)
extends ShuffleBlockManager with Logging {
+ private val transportConf = SparkTransportConf.fromSparkConf(conf)
+
private lazy val blockManager = SparkEnv.get.blockManager
// Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
@@ -181,13 +185,14 @@ class FileShuffleBlockManager(conf: SparkConf)
val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId)
if (segmentOpt.isDefined) {
val segment = segmentOpt.get
- return new FileSegmentManagedBuffer(segment.file, segment.offset, segment.length)
+ return new FileSegmentManagedBuffer(
+ transportConf, segment.file, segment.offset, segment.length)
}
}
throw new IllegalStateException("Failed to find shuffle block: " + blockId)
} else {
val file = blockManager.diskBlockManager.getFile(blockId)
- new FileSegmentManagedBuffer(file, 0, file.length)
+ new FileSegmentManagedBuffer(transportConf, file, 0, file.length)
}
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
index 4ab34336d3f01..b292587d37028 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
@@ -20,8 +20,11 @@ package org.apache.spark.shuffle
import java.io._
import java.nio.ByteBuffer
-import org.apache.spark.SparkEnv
-import org.apache.spark.network.{ManagedBuffer, FileSegmentManagedBuffer}
+import com.google.common.io.ByteStreams
+
+import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.storage._
/**
@@ -33,11 +36,15 @@ import org.apache.spark.storage._
* as the filename postfix for data file, and ".index" as the filename postfix for index file.
*
*/
+// Note: Changes to the format in this file should be kept in sync with
+// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getSortBasedShuffleBlockData().
private[spark]
-class IndexShuffleBlockManager extends ShuffleBlockManager {
+class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager {
private lazy val blockManager = SparkEnv.get.blockManager
+ private val transportConf = SparkTransportConf.fromSparkConf(conf)
+
/**
* Mapping to a single shuffleBlockId with reduce ID 0.
* */
@@ -101,10 +108,11 @@ class IndexShuffleBlockManager extends ShuffleBlockManager {
val in = new DataInputStream(new FileInputStream(indexFile))
try {
- in.skip(blockId.reduceId * 8)
+ ByteStreams.skipFully(in, blockId.reduceId * 8)
val offset = in.readLong()
val nextOffset = in.readLong()
new FileSegmentManagedBuffer(
+ transportConf,
getDataFile(blockId.shuffleId, blockId.mapId),
offset,
nextOffset - offset)
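The sort-shuffle index file is a flat run of 8-byte longs: a leading 0 followed by one cumulative offset per partition, so a reduce partition's bounds are the longs at positions reduceId and reduceId + 1. The change replaces InputStream.skip, which may skip fewer bytes than requested, with ByteStreams.skipFully. A stand-alone sketch of reading one segment's bounds under that assumed layout:

```scala
import java.io.{DataInputStream, File, FileInputStream}

import com.google.common.io.ByteStreams

// Reads the (offset, length) of one reduce partition out of a sort-shuffle index file,
// assuming the layout written by writeIndexFile: a long for offset 0 followed by one
// cumulative offset per partition.
def readSegmentBounds(indexFile: File, reduceId: Int): (Long, Long) = {
  val in = new DataInputStream(new FileInputStream(indexFile))
  try {
    // skipFully keeps skipping until exactly reduceId * 8 bytes are consumed,
    // unlike InputStream.skip, which may return early.
    ByteStreams.skipFully(in, reduceId * 8L)
    val offset = in.readLong()
    val nextOffset = in.readLong()
    (offset, nextOffset - offset)
  } finally {
    in.close()
  }
}
```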
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
index 63863cc0250a3..b521f0c7fc77e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
@@ -18,8 +18,7 @@
package org.apache.spark.shuffle
import java.nio.ByteBuffer
-
-import org.apache.spark.network.ManagedBuffer
+import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.storage.ShuffleBlockId
private[spark]
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
index 801ae54086053..a44a8e1249256 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
@@ -20,8 +20,8 @@ package org.apache.spark.shuffle
import org.apache.spark.{TaskContext, ShuffleDependency}
/**
- * Pluggable interface for shuffle systems. A ShuffleManager is created in SparkEnv on both the
- * driver and executors, based on the spark.shuffle.manager setting. The driver registers shuffles
+ * Pluggable interface for shuffle systems. A ShuffleManager is created in SparkEnv on the driver
+ * and on each executor, based on the spark.shuffle.manager setting. The driver registers shuffles
* with it, and executors (or tasks running locally in the driver) can ask to read and write data.
*
* NOTE: this will be instantiated by SparkEnv so its constructor can take a SparkConf and
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
index ee91a368b76ea..3bcc7178a3d8b 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
@@ -66,8 +66,9 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
val curMem = threadMemory(threadId)
val freeMemory = maxMemory - threadMemory.values.sum
- // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads
- val maxToGrant = math.min(numBytes, (maxMemory / numActiveThreads) - curMem)
+ // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads;
+ // don't let it be negative
+ val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveThreads) - curMem))
if (curMem < maxMemory / (2 * numActiveThreads)) {
// We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking;
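The clamp matters when a thread already holds more than its 1 / numActiveThreads share: without math.max(0, ...) the grant could come out negative. A two-line check of the arithmetic with illustrative numbers:

```scala
val maxMemory = 1000L
val numActiveThreads = 4
val curMem = 400L   // this thread already holds more than its 250-byte share
val numBytes = 100L

val before = math.min(numBytes, (maxMemory / numActiveThreads) - curMem)               // -150
val after  = math.min(numBytes, math.max(0L, (maxMemory / numActiveThreads) - curMem)) // 0
println(s"before: $before, after: $after")
```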
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala
index b30e366d06006..292e48314ee10 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala
@@ -24,6 +24,10 @@ private[spark] trait ShuffleReader[K, C] {
/** Read the combined key-values for this reduce task */
def read(): Iterator[Product2[K, C]]
- /** Close this reader */
- def stop(): Unit
+ /**
+ * Close this reader.
+ * TODO: Add this back when we make the ShuffleReader a developer API that others can implement
+ * (at which point this will likely be necessary).
+ */
+ // def stop(): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index 6cf9305977a3c..e3e7434df45b0 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -19,6 +19,7 @@ package org.apache.spark.shuffle.hash
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
+import scala.util.{Failure, Success, Try}
import org.apache.spark._
import org.apache.spark.serializer.Serializer
@@ -52,21 +53,21 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
(address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
}
- def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = {
+ def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = {
val blockId = blockPair._1
val blockOption = blockPair._2
blockOption match {
- case Some(block) => {
+ case Success(block) => {
block.asInstanceOf[Iterator[T]]
}
- case None => {
+ case Failure(e) => {
blockId match {
case ShuffleBlockId(shufId, mapId, _) =>
val address = statuses(mapId.toInt)._1
- throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId)
+ throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
case _ =>
throw new SparkException(
- "Failed to get block " + blockId + ", which is not a shuffle block")
+ "Failed to get block " + blockId + ", which is not a shuffle block", e)
}
}
}
@@ -74,7 +75,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
val blockFetcherItr = new ShuffleBlockFetcherIterator(
context,
- SparkEnv.get.blockTransferService,
+ SparkEnv.get.blockManager.shuffleClient,
blockManager,
blocksByAddress,
serializer,
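Fetched blocks now arrive as (BlockId, Try[Iterator[Any]]) instead of an Option, so the original failure can be rethrown as the cause of the FetchFailedException. A minimal sketch of the Success/Failure handling with plain types standing in for BlockId and the Spark exception classes:

```scala
import scala.util.{Failure, Success, Try}

// Stand-ins for BlockId and the fetched iterator; the point is the Try plumbing that
// lets the original exception become the cause of the rethrown failure.
def unpackBlockSketch(block: (String, Try[Iterator[Int]])): Iterator[Int] = block match {
  case (_, Success(it)) => it
  case (blockId, Failure(e)) =>
    throw new RuntimeException(s"Failed to fetch block $blockId", e)
}

val ok = unpackBlockSketch(("shuffle_0_1_2", Success(Iterator(1, 2, 3))))
println(ok.sum) // 6
```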
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index 88a5f1e5ddf58..41bafabde05b9 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -45,9 +45,9 @@ private[spark] class HashShuffleReader[K, C](
} else {
new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
}
- } else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
- throw new IllegalStateException("Aggregator is empty for map-side combine")
} else {
+ require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
+
// Convert the Product2s to pairs since this is what downstream RDDs currently expect
iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
}
@@ -59,14 +59,11 @@ private[spark] class HashShuffleReader[K, C](
// the ExternalSorter won't spill to disk.
val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
sorter.insertAll(aggregatedIter)
- context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled
- context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled
+ context.taskMetrics.incMemoryBytesSpilled(sorter.memoryBytesSpilled)
+ context.taskMetrics.incDiskBytesSpilled(sorter.diskBytesSpilled)
sorter.iterator
case None =>
aggregatedIter
}
}
-
- /** Close this reader */
- override def stop(): Unit = ???
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 746ed33b54c00..755f17d6aa15a 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -56,9 +56,8 @@ private[spark] class HashShuffleWriter[K, V](
} else {
records
}
- } else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
- throw new IllegalStateException("Aggregator is empty for map-side combine")
} else {
+ require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
records
}
@@ -107,7 +106,7 @@ private[spark] class HashShuffleWriter[K, V](
writer.commitAndClose()
writer.fileSegment().length
}
- MapStatus(blockManager.blockManagerId, sizes)
+ MapStatus(blockManager.shuffleServerId, sizes)
}
private def revertWrites(): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index b727438ae7e47..bda30a56d808e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -25,7 +25,7 @@ import org.apache.spark.shuffle.hash.HashShuffleReader
private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager {
- private val indexShuffleBlockManager = new IndexShuffleBlockManager()
+ private val indexShuffleBlockManager = new IndexShuffleBlockManager(conf)
private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]()
/**
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 927481b72cf4f..27496c5a289cb 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -50,9 +50,7 @@ private[spark] class SortShuffleWriter[K, V, C](
/** Write a bunch of records to this task's output */
override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
if (dep.mapSideCombine) {
- if (!dep.aggregator.isDefined) {
- throw new IllegalStateException("Aggregator is empty for map-side combine")
- }
+ require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
sorter = new ExternalSorter[K, V, C](
dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
sorter.insertAll(records)
@@ -70,7 +68,7 @@ private[spark] class SortShuffleWriter[K, V, C](
val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
- mapStatus = MapStatus(blockManager.blockManagerId, partitionLengths)
+ mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
}
/** Close this writer, passing along whether the map completed */
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala b/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala
deleted file mode 100644
index 5b6d086630834..0000000000000
--- a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.nio.ByteBuffer
-
-
-/**
- * An interface for providing data for blocks.
- *
- * getBlockData returns either a FileSegment (for zero-copy send), or a ByteBuffer.
- *
- * Aside from unit tests, [[BlockManager]] is the main class that implements this.
- */
-private[spark] trait BlockDataProvider {
- def getBlockData(blockId: String): Either[FileSegment, ByteBuffer]
-}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 8df5ec6bde184..1f012941c85ab 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -53,6 +53,8 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {
def name = "rdd_" + rddId + "_" + splitIndex
}
+// Format of the shuffle block ids (including data and index) should be kept in sync with
+// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getBlockData().
@DeveloperApi
case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 4cc97923658bc..8bc5a1cd18b64 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -17,14 +17,12 @@
package org.apache.spark.storage
-import java.io.{File, InputStream, OutputStream, BufferedOutputStream, ByteArrayOutputStream}
+import java.io.{BufferedOutputStream, ByteArrayOutputStream, File, InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
-import scala.concurrent.ExecutionContext.Implicits.global
-
-import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.concurrent.{Await, Future}
+import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.util.Random
@@ -35,11 +33,15 @@ import org.apache.spark._
import org.apache.spark.executor._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.network.shuffle.ExternalShuffleClient
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleManager
+import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.util._
-
private[spark] sealed trait BlockValues
private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues
private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues
@@ -51,9 +53,15 @@ private[spark] class BlockResult(
readMethod: DataReadMethod.Value,
bytes: Long) {
val inputMetrics = new InputMetrics(readMethod)
- inputMetrics.bytesRead = bytes
+ inputMetrics.addBytesRead(bytes)
}
+/**
+ * Manager running on every node (driver and executors) which provides interfaces for putting and
+ * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap).
+ *
+ * Note that #initialize() must be called before the BlockManager is usable.
+ */
private[spark] class BlockManager(
executorId: String,
actorSystem: ActorSystem,
@@ -63,11 +71,11 @@ private[spark] class BlockManager(
val conf: SparkConf,
mapOutputTracker: MapOutputTracker,
shuffleManager: ShuffleManager,
- blockTransferService: BlockTransferService)
+ blockTransferService: BlockTransferService,
+ securityManager: SecurityManager,
+ numUsableCores: Int)
extends BlockDataManager with Logging {
- blockTransferService.init(this)
-
val diskBlockManager = new DiskBlockManager(this, conf)
private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
@@ -87,8 +95,37 @@ private[spark] class BlockManager(
new TachyonStore(this, tachyonBlockManager)
}
- val blockManagerId = BlockManagerId(
- executorId, blockTransferService.hostName, blockTransferService.port)
+ private[spark]
+ val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false)
+
+ // Port used by the external shuffle service. In Yarn mode, this may already be
+ // set through the Hadoop configuration, as the server is launched in the Yarn NM.
+ private val externalShuffleServicePort =
+ Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt
+
+ // Check that we're not using external shuffle service with consolidated shuffle files.
+ if (externalShuffleServiceEnabled
+ && conf.getBoolean("spark.shuffle.consolidateFiles", false)
+ && shuffleManager.isInstanceOf[HashShuffleManager]) {
+ throw new UnsupportedOperationException("Cannot use external shuffle service with consolidated"
+ + " shuffle files in hash-based shuffle. Please disable spark.shuffle.consolidateFiles or "
+ + "switch to sort-based shuffle.")
+ }
+
+ var blockManagerId: BlockManagerId = _
+
+ // Address of the server that serves this executor's shuffle files. This is either an external
+ // service, or just our own Executor's BlockManager.
+ private[spark] var shuffleServerId: BlockManagerId = _
+
+ // Client to read other executors' shuffle files. This is either an external service, or just the
+ // standard BlockTransferService to directly connect to other Executors.
+ private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
+ val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores)
+ new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled())
+ } else {
+ blockTransferService
+ }
// Whether to compress broadcast variables that are stored
private val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true)
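
A minimal configuration sketch for the settings read above (spark.shuffle.service.enabled, spark.shuffle.service.port, spark.shuffle.consolidateFiles); illustrative only, not part of the patch:

    import org.apache.spark.SparkConf

    // Illustrative only: with this configuration, shuffleClient above becomes an
    // ExternalShuffleClient instead of the regular BlockTransferService.
    val conf = new SparkConf()
      .set("spark.shuffle.service.enabled", "true")
      .set("spark.shuffle.service.port", "7337")        // default used above
      .set("spark.shuffle.consolidateFiles", "false")   // must stay off with hash-based shuffle
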
@@ -118,8 +155,6 @@ private[spark] class BlockManager(
private val peerFetchLock = new Object
private var lastPeerFetchTime = 0L
- initialize()
-
/* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
* the initialization of the compression codec until it is first used. The reason is that a Spark
* program could be using a user-defined codec in a third party jar, which is loaded in
@@ -138,17 +173,66 @@ private[spark] class BlockManager(
conf: SparkConf,
mapOutputTracker: MapOutputTracker,
shuffleManager: ShuffleManager,
- blockTransferService: BlockTransferService) = {
+ blockTransferService: BlockTransferService,
+ securityManager: SecurityManager,
+ numUsableCores: Int) = {
this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf),
- conf, mapOutputTracker, shuffleManager, blockTransferService)
+ conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores)
}
/**
- * Initialize the BlockManager. Register to the BlockManagerMaster, and start the
- * BlockManagerWorker actor.
+ * Initializes the BlockManager with the given appId. This is not performed in the constructor as
+ * the appId may not be known at BlockManager instantiation time (in particular for the driver,
+ * where it is only learned after registration with the TaskScheduler).
+ *
+ * This method initializes the BlockTransferService and ShuffleClient, registers with the
+ * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle
+ * service if configured.
*/
- private def initialize(): Unit = {
+ def initialize(appId: String): Unit = {
+ blockTransferService.init(this)
+ shuffleClient.init(appId)
+
+ blockManagerId = BlockManagerId(
+ executorId, blockTransferService.hostName, blockTransferService.port)
+
+ shuffleServerId = if (externalShuffleServiceEnabled) {
+ BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort)
+ } else {
+ blockManagerId
+ }
+
master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
+
+ // Register this executor's configuration with the local shuffle service, if one is configured.
+ if (externalShuffleServiceEnabled && !blockManagerId.isDriver) {
+ registerWithExternalShuffleServer()
+ }
+ }
+
+ private def registerWithExternalShuffleServer() {
+ logInfo("Registering executor with local external shuffle service.")
+ val shuffleConfig = new ExecutorShuffleInfo(
+ diskBlockManager.localDirs.map(_.toString),
+ diskBlockManager.subDirsPerLocalDir,
+ shuffleManager.getClass.getName)
+
+ val MAX_ATTEMPTS = 3
+ val SLEEP_TIME_SECS = 5
+
+ for (i <- 1 to MAX_ATTEMPTS) {
+ try {
+ // Synchronous and will throw an exception if we cannot connect.
+ shuffleClient.asInstanceOf[ExternalShuffleClient].registerWithShuffleServer(
+ shuffleServerId.host, shuffleServerId.port, shuffleServerId.executorId, shuffleConfig)
+ return
+ } catch {
+ case e: Exception if i < MAX_ATTEMPTS =>
+ logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}"
+ + s" more times after waiting $SLEEP_TIME_SECS seconds...", e)
+ Thread.sleep(SLEEP_TIME_SECS * 1000)
+ }
+ }
}
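
A rough usage sketch of the lifecycle described in the initialize(appId) comment above; bm stands in for a fully constructed BlockManager and the appId value is hypothetical:

    import org.apache.spark.storage.BlockManager

    // Sketch only: construction no longer starts the transfer service; the manager
    // is usable only after initialize(appId) wires up the BlockTransferService and
    // shuffle client and registers with the BlockManagerMaster.
    def startBlockManager(bm: BlockManager, appId: String): Unit = {
      bm.initialize(appId)   // e.g. "app-20141201-0001" (hypothetical)
    }
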
/**
@@ -212,21 +296,20 @@ private[spark] class BlockManager(
}
/**
- * Interface to get local block data.
- *
- * @return Some(buffer) if the block exists locally, and None if it doesn't.
+ * Interface to get local block data. Throws an exception if the block cannot be found or
+ * cannot be read successfully.
*/
- override def getBlockData(blockId: String): Option[ManagedBuffer] = {
- val bid = BlockId(blockId)
- if (bid.isShuffle) {
- Some(shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId]))
+ override def getBlockData(blockId: BlockId): ManagedBuffer = {
+ if (blockId.isShuffle) {
+ shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
} else {
- val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]]
+ val blockBytesOpt = doGetLocal(blockId, asBlockResult = false)
+ .asInstanceOf[Option[ByteBuffer]]
if (blockBytesOpt.isDefined) {
val buffer = blockBytesOpt.get
- Some(new NioByteBufferManagedBuffer(buffer))
+ new NioManagedBuffer(buffer)
} else {
- None
+ throw new BlockNotFoundException(blockId.toString)
}
}
}
@@ -234,8 +317,8 @@ private[spark] class BlockManager(
/**
* Put the block locally, using the given storage level.
*/
- override def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit = {
- putBytes(BlockId(blockId), data.nioByteBuffer(), level)
+ override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit = {
+ putBytes(blockId, data.nioByteBuffer(), level)
}
/**
@@ -340,17 +423,6 @@ private[spark] class BlockManager(
locations
}
- /**
- * A short-circuited method to get blocks directly from disk. This is used for getting
- * shuffle blocks. It is safe to do so without a lock on block info since disk store
- * never deletes (recent) items.
- */
- def getLocalShuffleFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
- val buf = shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
- val is = wrapForCompression(blockId, buf.inputStream())
- Some(serializer.newInstance().deserializeStream(is).asIterator)
- }
-
/**
* Get block from local block manager.
*/
@@ -520,7 +592,7 @@ private[spark] class BlockManager(
for (loc <- locations) {
logDebug(s"Getting remote block $blockId from $loc")
val data = blockTransferService.fetchBlockSync(
- loc.host, loc.port, blockId.toString).nioByteBuffer()
+ loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer()
if (data != null) {
if (asBlockResult) {
@@ -869,9 +941,9 @@ private[spark] class BlockManager(
data.rewind()
logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer")
blockTransferService.uploadBlockSync(
- peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel)
- logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %d ms"
- .format((System.currentTimeMillis - onePeerStartTime)))
+ peer.host, peer.port, peer.executorId, blockId, new NioManagedBuffer(data), tLevel)
+ logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %s ms"
+ .format(System.currentTimeMillis - onePeerStartTime))
peersReplicatedTo += peer
peersForReplication -= peer
replicationFailed = false
@@ -941,8 +1013,10 @@ private[spark] class BlockManager(
// If we get here, the block write failed.
logWarning(s"Block $blockId was marked as failure. Nothing to drop")
return None
+ } else if (blockInfo.get(blockId).isEmpty) {
+ logWarning(s"Block $blockId was already dropped.")
+ return None
}
-
var blockIsUpdated = false
val level = info.level
@@ -1126,7 +1200,11 @@ private[spark] class BlockManager(
}
def stop(): Unit = {
- blockTransferService.stop()
+ blockTransferService.close()
+ if (shuffleClient ne blockTransferService) {
+ // Closing should be idempotent, but maybe not for the NioBlockTransferService.
+ shuffleClient.close()
+ }
diskBlockManager.stop()
actorSystem.stop(slaveActor)
blockInfo.clear()
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
index 142285094342c..b177a59c721df 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -20,6 +20,7 @@ package org.apache.spark.storage
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import java.util.concurrent.ConcurrentHashMap
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils
@@ -59,15 +60,15 @@ class BlockManagerId private (
def port: Int = port_
- def isDriver: Boolean = (executorId == "")
+ def isDriver: Boolean = { executorId == SparkContext.DRIVER_IDENTIFIER }
- override def writeExternal(out: ObjectOutput) {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
out.writeUTF(executorId_)
out.writeUTF(host_)
out.writeInt(port_)
}
- override def readExternal(in: ObjectInput) {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
executorId_ = in.readUTF()
host_ = in.readUTF()
port_ = in.readInt()
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index d08e1419e3e41..b63c7f191155c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -88,6 +88,10 @@ class BlockManagerMaster(
askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId))
}
+ def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = {
+ askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId))
+ }
+
/**
* Remove a block from the slaves that have it. This can only be used to remove
* blocks that the driver knows about.
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 088f06e389d83..64133464d8daa 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -52,8 +52,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
private val akkaTimeout = AkkaUtils.askTimeout(conf)
- val slaveTimeout = conf.getLong("spark.storage.blockManagerSlaveTimeoutMs",
- math.max(conf.getInt("spark.executor.heartbeatInterval", 10000) * 3, 45000))
+ val slaveTimeout = conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", 120 * 1000)
val checkTimeoutInterval = conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000)
@@ -73,9 +72,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
case UpdateBlockInfo(
blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) =>
- // TODO: Ideally we want to handle all the message replies in receive instead of in the
- // individual private methods.
- updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize)
+ sender ! updateBlockInfo(
+ blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize)
case GetLocations(blockId) =>
sender ! getLocations(blockId)
@@ -86,6 +84,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
case GetPeers(blockManagerId) =>
sender ! getPeers(blockManagerId)
+ case GetActorSystemHostPortForExecutor(executorId) =>
+ sender ! getActorSystemHostPortForExecutor(executorId)
+
case GetMemoryStatus =>
sender ! memoryStatus
@@ -203,6 +204,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
}
}
listenerBus.post(SparkListenerBlockManagerRemoved(System.currentTimeMillis(), blockManagerId))
+ logInfo(s"Removing block manager $blockManagerId")
}
private def expireDeadHosts() {
@@ -327,20 +329,20 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
val time = System.currentTimeMillis()
if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
- case Some(manager) =>
- // A block manager of the same executor already exists.
- // This should never happen. Let's just quit.
- logError("Got two different block manager registrations on " + id.executorId)
- System.exit(1)
+ case Some(oldId) =>
+ // A block manager of the same executor already exists, so remove it (assumed dead)
+ logError("Got two different block manager registrations on same executor - "
+ + s"will replace old one $oldId with new one $id")
+ removeExecutor(id.executorId)
case None =>
- blockManagerIdByExecutor(id.executorId) = id
}
-
- logInfo("Registering block manager %s with %s RAM".format(
- id.hostPort, Utils.bytesToString(maxMemSize)))
-
- blockManagerInfo(id) =
- new BlockManagerInfo(id, time, maxMemSize, slaveActor)
+ logInfo("Registering block manager %s with %s RAM, %s".format(
+ id.hostPort, Utils.bytesToString(maxMemSize), id))
+
+ blockManagerIdByExecutor(id.executorId) = id
+
+ blockManagerInfo(id) = new BlockManagerInfo(
+ id, System.currentTimeMillis(), maxMemSize, slaveActor)
}
listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize))
}
@@ -351,23 +353,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
storageLevel: StorageLevel,
memSize: Long,
diskSize: Long,
- tachyonSize: Long) {
+ tachyonSize: Long): Boolean = {
if (!blockManagerInfo.contains(blockManagerId)) {
if (blockManagerId.isDriver && !isLocal) {
// We intentionally do not register the master (except in local mode),
// so we should not indicate failure.
- sender ! true
+ return true
} else {
- sender ! false
+ return false
}
- return
}
if (blockId == null) {
blockManagerInfo(blockManagerId).updateLastSeenMs()
- sender ! true
- return
+ return true
}
blockManagerInfo(blockManagerId).updateBlockInfo(
@@ -391,7 +391,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
if (locations.size == 0) {
blockLocations.remove(blockId)
}
- sender ! true
+ true
}
private def getLocations(blockId: BlockId): Seq[BlockManagerId] = {
@@ -411,6 +411,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
Seq.empty
}
}
+
+ /**
+ * Returns the hostname and port of an executor's actor system, based on the Akka address of its
+ * BlockManagerSlaveActor.
+ */
+ private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = {
+ for (
+ blockManagerId <- blockManagerIdByExecutor.get(executorId);
+ info <- blockManagerInfo.get(blockManagerId);
+ host <- info.slaveActor.path.address.host;
+ port <- info.slaveActor.path.address.port
+ ) yield {
+ (host, port)
+ }
+ }
}
@DeveloperApi
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 3db5dd9774ae8..3f32099d08cc9 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -21,6 +21,8 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput}
import akka.actor.ActorRef
+import org.apache.spark.util.Utils
+
private[spark] object BlockManagerMessages {
//////////////////////////////////////////////////////////////////////////////////
// Messages from the master to slaves.
@@ -65,7 +67,7 @@ private[spark] object BlockManagerMessages {
def this() = this(null, null, null, 0, 0, 0) // For deserialization only
- override def writeExternal(out: ObjectOutput) {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
blockManagerId.writeExternal(out)
out.writeUTF(blockId.name)
storageLevel.writeExternal(out)
@@ -74,7 +76,7 @@ private[spark] object BlockManagerMessages {
out.writeLong(tachyonSize)
}
- override def readExternal(in: ObjectInput) {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
blockManagerId = BlockManagerId(in)
blockId = BlockId(in.readUTF())
storageLevel = StorageLevel(in)
@@ -90,6 +92,8 @@ private[spark] object BlockManagerMessages {
case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
+ case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster
+
case class RemoveExecutor(execId: String) extends ToBlockManagerMaster
case object StopBlockManagerMaster extends ToBlockManagerMaster
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala
index 9ef453605f4f1..81f5f2d31dbd8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala
@@ -17,5 +17,4 @@
package org.apache.spark.storage
-
class BlockNotFoundException(blockId: String) extends Exception(s"Block $blockId not found")
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 9c469370ffe1f..3198d766fca37 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -160,14 +160,14 @@ private[spark] class DiskBlockObjectWriter(
}
finalPosition = file.length()
// In certain compression codecs, more bytes are written after close() is called
- writeMetrics.shuffleBytesWritten += (finalPosition - reportedPosition)
+ writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition)
}
// Discard current writes. We do this by flushing the outstanding writes and then
// truncating the file to its initial position.
override def revertPartialWritesAndClose() {
try {
- writeMetrics.shuffleBytesWritten -= (reportedPosition - initialPosition)
+ writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition)
if (initialized) {
objOut.flush()
@@ -212,14 +212,14 @@ private[spark] class DiskBlockObjectWriter(
*/
private def updateBytesWritten() {
val pos = channel.position()
- writeMetrics.shuffleBytesWritten += (pos - reportedPosition)
+ writeMetrics.incShuffleBytesWritten(pos - reportedPosition)
reportedPosition = pos
}
private def callWithTiming(f: => Unit) = {
val start = System.nanoTime()
f
- writeMetrics.shuffleWriteTime += (System.nanoTime() - start)
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - start)
}
// For testing
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index 6633a1db57e59..53eaedacbf291 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -17,9 +17,8 @@
package org.apache.spark.storage
-import java.io.File
-import java.text.SimpleDateFormat
-import java.util.{Date, Random, UUID}
+import java.util.UUID
+import java.io.{IOException, File}
import org.apache.spark.{SparkConf, Logging}
import org.apache.spark.executor.ExecutorExitCode
@@ -37,13 +36,13 @@ import org.apache.spark.util.Utils
private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkConf)
extends Logging {
- private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
- private val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64)
+ private[spark]
+ val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64)
/* Create one local directory for each path mentioned in spark.local.dir; then, inside this
* directory, create multiple subdirectories that we will hash files into, in order to avoid
* having really large inodes at the top level. */
- val localDirs: Array[File] = createLocalDirs(conf)
+ private[spark] val localDirs: Array[File] = createLocalDirs(conf)
if (localDirs.isEmpty) {
logError("Failed to create any local dir.")
System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
@@ -52,6 +51,9 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
addShutdownHook()
+ /** Looks up a file by hashing it into one of our local subdirectories. */
+ // This method should be kept in sync with
+ // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getFile().
def getFile(filename: String): File = {
// Figure out which local directory it hashes to, and which subdirectory in that
val hash = Utils.nonNegativeHash(filename)
@@ -67,7 +69,9 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
old
} else {
val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
- newDir.mkdir()
+ if (!newDir.exists() && !newDir.mkdir()) {
+ throw new IOException(s"Failed to create local dir in $newDir.")
+ }
subDirs(dirId)(subDirId) = newDir
newDir
}
@@ -117,39 +121,20 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
}
private def createLocalDirs(conf: SparkConf): Array[File] = {
- val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
Utils.getOrCreateLocalRootDirs(conf).flatMap { rootDir =>
- var foundLocalDir = false
- var localDir: File = null
- var localDirId: String = null
- var tries = 0
- val rand = new Random()
- while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
- tries += 1
- try {
- localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
- localDir = new File(rootDir, s"spark-local-$localDirId")
- if (!localDir.exists) {
- foundLocalDir = localDir.mkdirs()
- }
- } catch {
- case e: Exception =>
- logWarning(s"Attempt $tries to create local dir $localDir failed", e)
- }
- }
- if (!foundLocalDir) {
- logError(s"Failed $MAX_DIR_CREATION_ATTEMPTS attempts to create local dir in $rootDir." +
- " Ignoring this directory.")
- None
- } else {
+ try {
+ val localDir = Utils.createDirectory(rootDir, "blockmgr")
logInfo(s"Created local directory at $localDir")
Some(localDir)
+ } catch {
+ case e: IOException =>
+ logError(s"Failed to create local dir in $rootDir. Ignoring this directory.", e)
+ None
}
}
}
private def addShutdownHook() {
- localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
override def run(): Unit = Utils.logUncaughtExceptions {
logDebug("Shutdown hook called")
@@ -160,13 +145,16 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
/** Cleanup local dirs and stop shuffle sender. */
private[spark] def stop() {
- localDirs.foreach { localDir =>
- if (localDir.isDirectory() && localDir.exists()) {
- try {
- if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
- } catch {
- case e: Exception =>
- logError(s"Exception while deleting local spark dir: $localDir", e)
+ // Only perform cleanup if an external service is not serving our shuffle files.
+ if (!blockManager.externalShuffleServiceEnabled || blockManager.blockManagerId.isDriver) {
+ localDirs.foreach { localDir =>
+ if (localDir.isDirectory() && localDir.exists()) {
+ try {
+ if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
+ } catch {
+ case e: Exception =>
+ logError(s"Exception while deleting local spark dir: $localDir", e)
+ }
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index bac459e835a3f..61ef5ff168791 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -17,7 +17,7 @@
package org.apache.spark.storage
-import java.io.{File, FileOutputStream, RandomAccessFile}
+import java.io.{IOException, File, FileOutputStream, RandomAccessFile}
import java.nio.ByteBuffer
import java.nio.channels.FileChannel.MapMode
@@ -31,7 +31,8 @@ import org.apache.spark.util.Utils
private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager)
extends BlockStore(blockManager) with Logging {
- val minMemoryMapBytes = blockManager.conf.getLong("spark.storage.memoryMapThreshold", 2 * 4096L)
+ val minMemoryMapBytes = blockManager.conf.getLong(
+ "spark.storage.memoryMapThreshold", 2 * 1024L * 1024L)
override def getSize(blockId: BlockId): Long = {
diskManager.getFile(blockId.name).length
@@ -110,7 +111,13 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc
// For small files, directly read rather than memory map
if (length < minMemoryMapBytes) {
val buf = ByteBuffer.allocate(length.toInt)
- channel.read(buf, offset)
+ channel.position(offset)
+ while (buf.remaining() != 0) {
+ if (channel.read(buf) == -1) {
+ throw new IOException("Reached EOF before filling buffer\n" +
+ s"offset=$offset\nfile=${file.getAbsolutePath}\nbuf.remaining=${buf.remaining}")
+ }
+ }
buf.flip()
Some(buf)
} else {
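
The loop above replaces a single positional read with a read-until-full loop; a standalone sketch of the same pattern, independent of Spark (file, offset, and length are whatever the caller supplies):

    import java.io.{File, IOException, RandomAccessFile}
    import java.nio.ByteBuffer

    // Read `length` bytes starting at `offset`, failing loudly on a short read
    // instead of silently returning a partially filled buffer.
    def readFully(file: File, offset: Long, length: Int): ByteBuffer = {
      val channel = new RandomAccessFile(file, "r").getChannel
      try {
        val buf = ByteBuffer.allocate(length)
        channel.position(offset)
        while (buf.remaining() != 0) {
          if (channel.read(buf) == -1) {
            throw new IOException(s"Reached EOF before reading $length bytes from $file")
          }
        }
        buf.flip()
        buf
      } finally {
        channel.close()
      }
    }
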
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index edbc729c17ade..71305a46bf570 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -56,6 +56,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
(maxMemory * unrollFraction).toLong
}
+ // Initial memory to request before unrolling any block
+ private val unrollMemoryThreshold: Long =
+ conf.getLong("spark.storage.unrollMemoryThreshold", 1024 * 1024)
+
+ if (maxMemory < unrollMemoryThreshold) {
+ logWarning(s"Max memory ${Utils.bytesToString(maxMemory)} is less than the initial memory " +
+ s"threshold ${Utils.bytesToString(unrollMemoryThreshold)} needed to store a block in " +
+ s"memory. Please configure Spark with more memory.")
+ }
+
logInfo("MemoryStore started with capacity %s".format(Utils.bytesToString(maxMemory)))
/** Free memory not occupied by existing blocks. Note that this does not include unroll memory. */
@@ -213,7 +223,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
// Whether there is still enough memory for us to continue unrolling this block
var keepUnrolling = true
// Initial per-thread memory to request for unrolling blocks (bytes). Exposed for testing.
- val initialMemoryThreshold = conf.getLong("spark.storage.unrollMemoryThreshold", 1024 * 1024)
+ val initialMemoryThreshold = unrollMemoryThreshold
// How often to check whether we need to request more memory
val memoryCheckPeriod = 16
// Memory currently reserved by this thread for this particular unrolling operation
@@ -228,6 +238,11 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
// Request enough memory to begin unrolling
keepUnrolling = reserveUnrollMemoryForThisThread(initialMemoryThreshold)
+ if (!keepUnrolling) {
+ logWarning(s"Failed to reserve initial memory threshold of " +
+ s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.")
+ }
+
// Unroll this block safely, checking whether we have exceeded our threshold periodically
try {
while (values.hasNext && keepUnrolling) {
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 71b276b5f18e4..ab9ee4f0096bf 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -17,17 +17,18 @@
package org.apache.spark.storage
+import java.io.{InputStream, IOException}
import java.util.concurrent.LinkedBlockingQueue
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashSet
-import scala.collection.mutable.Queue
+import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
+import scala.util.{Failure, Success, Try}
-import org.apache.spark.{TaskContext, Logging}
-import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener, BlockTransferService}
+import org.apache.spark.{Logging, TaskContext}
+import org.apache.spark.network.BlockTransferService
+import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
+import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.serializer.Serializer
-import org.apache.spark.util.Utils
-
+import org.apache.spark.util.{CompletionIterator, Utils}
/**
* An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -40,8 +41,8 @@ import org.apache.spark.util.Utils
* using too much memory.
*
* @param context [[TaskContext]], used for metrics update
- * @param blockTransferService [[BlockTransferService]] for fetching remote blocks
- * @param blockManager [[BlockManager]] for reading local blocks
+ * @param shuffleClient [[ShuffleClient]] for fetching remote blocks
+ * @param blockManager [[BlockManager]] for reading local blocks
* @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
* For each block we also require the size (in bytes as a long field) in
* order to throttle the memory usage.
@@ -51,12 +52,12 @@ import org.apache.spark.util.Utils
private[spark]
final class ShuffleBlockFetcherIterator(
context: TaskContext,
- blockTransferService: BlockTransferService,
+ shuffleClient: ShuffleClient,
blockManager: BlockManager,
blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
serializer: Serializer,
maxBytesInFlight: Long)
- extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging {
+ extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging {
import ShuffleBlockFetcherIterator._
@@ -88,17 +89,53 @@ final class ShuffleBlockFetcherIterator(
*/
private[this] val results = new LinkedBlockingQueue[FetchResult]
- // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
- // the number of bytes in flight is limited to maxBytesInFlight
+ /**
+ * Current [[FetchResult]] being processed. We track this so we can release the current buffer
+ * in case of a runtime exception when processing the current buffer.
+ */
+ @volatile private[this] var currentResult: FetchResult = null
+
+ /**
+ * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
+ * the number of bytes in flight is limited to maxBytesInFlight.
+ */
private[this] val fetchRequests = new Queue[FetchRequest]
- // Current bytes in flight from our requests
+ /** Current bytes in flight from our requests */
private[this] var bytesInFlight = 0L
private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
+ /**
+ * Whether the iterator is still active. If isZombie is true, the callback interface will no
+ * longer place fetched blocks into [[results]].
+ */
+ @volatile private[this] var isZombie = false
+
initialize()
+ /**
+ * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
+ */
+ private[this] def cleanup() {
+ isZombie = true
+ // Release the current buffer if necessary
+ currentResult match {
+ case SuccessFetchResult(_, _, buf) => buf.release()
+ case _ =>
+ }
+
+ // Release buffers in the results queue
+ val iter = results.iterator()
+ while (iter.hasNext) {
+ val result = iter.next()
+ result match {
+ case SuccessFetchResult(_, _, buf) => buf.release()
+ case _ =>
+ }
+ }
+ }
+
private[this] def sendRequest(req: FetchRequest) {
logDebug("Sending request for %d blocks (%s) from %s".format(
req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
@@ -108,26 +145,26 @@ final class ShuffleBlockFetcherIterator(
val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
val blockIds = req.blocks.map(_._1.toString)
- blockTransferService.fetchBlocks(req.address.host, req.address.port, blockIds,
+ val address = req.address
+ shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
new BlockFetchingListener {
- override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
- results.put(new FetchResult(BlockId(blockId), sizeMap(blockId),
- () => serializer.newInstance().deserializeStream(
- blockManager.wrapForCompression(BlockId(blockId), data.inputStream())).asIterator
- ))
- shuffleMetrics.remoteBytesRead += data.size
- shuffleMetrics.remoteBlocksFetched += 1
- logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
+ override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
+ // Only add the buffer to results queue if the iterator is not zombie,
+ // i.e. cleanup() has not been called yet.
+ if (!isZombie) {
+ // Increment the ref count because we need to pass this to a different thread.
+ // This needs to be released after use.
+ buf.retain()
+ results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf))
+ shuffleMetrics.incRemoteBytesRead(buf.size)
+ shuffleMetrics.incRemoteBlocksFetched(1)
+ }
+ logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
- override def onBlockFetchFailure(e: Throwable): Unit = {
+ override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
- // Note that there is a chance that some blocks have been fetched successfully, but we
- // still add them to the failed queue. This is fine because when the caller see a
- // FetchFailedException, it is going to fail the entire task anyway.
- for ((blockId, size) <- req.blocks) {
- results.put(new FetchResult(blockId, -1, null))
- }
+ results.put(new FailureFetchResult(BlockId(blockId), e))
}
}
)
@@ -138,7 +175,7 @@ final class ShuffleBlockFetcherIterator(
// smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
// nodes, rather than blocking on reading output from one node.
val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
- logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
+ logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
@@ -148,7 +185,7 @@ final class ShuffleBlockFetcherIterator(
var totalBlocks = 0
for ((address, blockInfos) <- blocksByAddress) {
totalBlocks += blockInfos.size
- if (address == blockManager.blockManagerId) {
+ if (address.executorId == blockManager.blockManagerId.executorId) {
// Filter out zero-sized blocks
localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
numBlocksToFetch += localBlocks.size
@@ -185,26 +222,34 @@ final class ShuffleBlockFetcherIterator(
remoteRequests
}
+ /**
+ * Fetch the local blocks while we are fetching remote blocks. This is ok because
+ * [[ManagedBuffer]]'s memory is allocated lazily when we create the input stream, so all we
+ * track in-memory are the ManagedBuffer references themselves.
+ */
private[this] def fetchLocalBlocks() {
- // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
- // these all at once because they will just memory-map some files, so they won't consume
- // any memory that might exceed our maxBytesInFlight
- for (id <- localBlocks) {
+ val iter = localBlocks.iterator
+ while (iter.hasNext) {
+ val blockId = iter.next()
try {
- shuffleMetrics.localBlocksFetched += 1
- results.put(new FetchResult(
- id, 0, () => blockManager.getLocalShuffleFromDisk(id, serializer).get))
- logDebug("Got local block " + id)
+ val buf = blockManager.getBlockData(blockId)
+ shuffleMetrics.incLocalBlocksFetched(1)
+ buf.retain()
+ results.put(new SuccessFetchResult(blockId, 0, buf))
} catch {
case e: Exception =>
+ // If we see an exception, stop immediately.
logError(s"Error occurred while fetching local blocks", e)
- results.put(new FetchResult(id, -1, null))
+ results.put(new FailureFetchResult(blockId, e))
return
}
}
}
private[this] def initialize(): Unit = {
+ // Add a task completion callback (called in both the success and failure cases) to clean up.
+ context.addTaskCompletionListener(_ => cleanup())
+
// Split local and remote blocks.
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
@@ -221,26 +266,49 @@ final class ShuffleBlockFetcherIterator(
// Get Local Blocks
fetchLocalBlocks()
- logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+ logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}
override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
- override def next(): (BlockId, Option[Iterator[Any]]) = {
+ override def next(): (BlockId, Try[Iterator[Any]]) = {
numBlocksProcessed += 1
val startFetchWait = System.currentTimeMillis()
- val result = results.take()
+ currentResult = results.take()
+ val result = currentResult
val stopFetchWait = System.currentTimeMillis()
- shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait)
- if (!result.failed) {
- bytesInFlight -= result.size
+ shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
+
+ result match {
+ case SuccessFetchResult(_, size, _) => bytesInFlight -= size
+ case _ =>
}
// Send fetch requests up to maxBytesInFlight
while (fetchRequests.nonEmpty &&
(bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
sendRequest(fetchRequests.dequeue())
}
- (result.blockId, if (result.failed) None else Some(result.deserialize()))
+
+ val iteratorTry: Try[Iterator[Any]] = result match {
+ case FailureFetchResult(_, e) =>
+ Failure(e)
+ case SuccessFetchResult(blockId, _, buf) =>
+ // There is a chance that createInputStream can fail (e.g. fetching a local file that does
+ // not exist, SPARK-4085). In that case, we should propagate the right exception so
+ // the scheduler gets a FetchFailedException.
+ Try(buf.createInputStream()).map { is0 =>
+ val is = blockManager.wrapForCompression(blockId, is0)
+ val iter = serializer.newInstance().deserializeStream(is).asIterator
+ CompletionIterator[Any, Iterator[Any]](iter, {
+ // Once the iterator is exhausted, release the buffer and set currentResult to null
+ // so we don't release it again in cleanup.
+ currentResult = null
+ buf.release()
+ })
+ }
+ }
+
+ (result.blockId, iteratorTry)
}
}
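
For context, a hedged sketch of how a caller might unwrap the iterator's new (BlockId, Try[Iterator[Any]]) tuples; the block id is simplified to a String and the failure handling is illustrative, not the actual shuffle reader code:

    import scala.util.{Failure, Success, Try}

    // Illustrative consumer: successes are flattened into records, failures are
    // rethrown so the task fails and the scheduler can react.
    def toRecords(fetches: Iterator[(String, Try[Iterator[Any]])]): Iterator[Any] =
      fetches.flatMap {
        case (_, Success(records)) => records
        case (blockId, Failure(e)) =>
          throw new RuntimeException(s"Failed to fetch block $blockId", e)
      }
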
@@ -254,18 +322,35 @@ object ShuffleBlockFetcherIterator {
* @param blocks Sequence of tuple, where the first element is the block id,
* and the second element is the estimated size, used to calculate bytesInFlight.
*/
- class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) {
+ case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) {
val size = blocks.map(_._2).sum
}
/**
- * Result of a fetch from a remote block. A failure is represented as size == -1.
+ * Result of a fetch from a remote block.
+ */
+ private[storage] sealed trait FetchResult {
+ val blockId: BlockId
+ }
+
+ /**
+ * Result of a successful fetch of a remote block.
* @param blockId block id
* @param size estimated size of the block, used to calculate bytesInFlight.
* Note that this is NOT the exact bytes.
- * @param deserialize closure to return the result in the form of an Iterator.
+ * @param buf [[ManagedBuffer]] for the content.
*/
- class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) {
- def failed: Boolean = size == -1
+ private[storage] case class SuccessFetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer)
+ extends FetchResult {
+ require(buf != null)
+ require(size >= 0)
}
+
+ /**
+ * Result of a failed fetch of a remote block.
+ * @param blockId block id
+ * @param e the failure exception
+ */
+ private[storage] case class FailureFetchResult(blockId: BlockId, e: Throwable)
+ extends FetchResult
}
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
index 1e35abaab5353..e5e1cf5a69a19 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
@@ -18,8 +18,10 @@
package org.apache.spark.storage
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
+import java.util.concurrent.ConcurrentHashMap
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.Utils
/**
* :: DeveloperApi ::
@@ -97,12 +99,12 @@ class StorageLevel private(
ret
}
- override def writeExternal(out: ObjectOutput) {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
out.writeByte(toInt)
out.writeByte(_replication)
}
- override def readExternal(in: ObjectInput) {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
val flags = in.readByte()
_useDisk = (flags & 8) != 0
_useMemory = (flags & 4) != 0
@@ -219,8 +221,7 @@ object StorageLevel {
getCachedStorageLevel(obj)
}
- private[spark] val storageLevelCache =
- new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]()
+ private[spark] val storageLevelCache = new ConcurrentHashMap[StorageLevel, StorageLevel]()
private[spark] def getCachedStorageLevel(level: StorageLevel): StorageLevel = {
storageLevelCache.putIfAbsent(level, level)
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
index d9066f766476e..def49e80a3605 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
@@ -19,6 +19,7 @@ package org.apache.spark.storage
import scala.collection.mutable
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.scheduler._
@@ -59,10 +60,9 @@ class StorageStatusListener extends SparkListener {
val info = taskEnd.taskInfo
val metrics = taskEnd.taskMetrics
if (info != null && metrics != null) {
- val execId = formatExecutorId(info.executorId)
val updatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]())
if (updatedBlocks.length > 0) {
- updateStorageStatus(execId, updatedBlocks)
+ updateStorageStatus(info.executorId, updatedBlocks)
}
}
}
@@ -88,13 +88,4 @@ class StorageStatusListener extends SparkListener {
}
}
- /**
- * In the local mode, there is a discrepancy between the executor ID according to the
- * task ("localhost") and that according to SparkEnv (""). In the UI, this
- * results in duplicate rows for the same executor. Thus, in this mode, we aggregate
- * these two rows and use the executor ID of "" to be consistent.
- */
- def formatExecutorId(execId: String): String = {
- if (execId == "localhost") "" else execId
- }
}
diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
index 6908a59a79e60..af873034215a9 100644
--- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
@@ -148,6 +148,7 @@ private[spark] class TachyonBlockManager(
logError("Exception while deleting tachyon spark dir: " + tachyonDir, e)
}
}
+ client.close()
}
})
}
diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala
index 932b5616043b4..233d1e2b7c616 100644
--- a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala
@@ -20,6 +20,7 @@ package org.apache.spark.storage
import java.io.IOException
import java.nio.ByteBuffer
+import com.google.common.io.ByteStreams
import tachyon.client.{ReadType, WriteType}
import org.apache.spark.Logging
@@ -105,25 +106,19 @@ private[spark] class TachyonStore(
return None
}
val is = file.getInStream(ReadType.CACHE)
- var buffer: ByteBuffer = null
+ assert (is != null)
try {
- if (is != null) {
- val size = file.length
- val bs = new Array[Byte](size.asInstanceOf[Int])
- val fetchSize = is.read(bs, 0, size.asInstanceOf[Int])
- buffer = ByteBuffer.wrap(bs)
- if (fetchSize != size) {
- logWarning(s"Failed to fetch the block $blockId from Tachyon: Size $size " +
- s"is not equal to fetched size $fetchSize")
- return None
- }
- }
+ val size = file.length
+ val bs = new Array[Byte](size.asInstanceOf[Int])
+ ByteStreams.readFully(is, bs)
+ Some(ByteBuffer.wrap(bs))
} catch {
case ioe: IOException =>
logWarning(s"Failed to fetch the block $blockId from Tachyon", ioe)
- return None
+ None
+ } finally {
+ is.close()
}
- Some(buffer)
}
override def contains(blockId: BlockId): Boolean = {
diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala
new file mode 100644
index 0000000000000..27ba9e18237b5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import java.util.{Timer, TimerTask}
+
+import org.apache.spark._
+
+/**
+ * ConsoleProgressBar shows the progress of stages in the next line of the console. It polls the
+ * status of active stages from `sc.statusTracker` periodically; the progress bar appears
+ * after a stage has run for at least 500ms. If multiple stages run at the same time, their
+ * statuses are combined and shown on one line.
+ */
+private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging {
+
+ // Carriage return
+ val CR = '\r'
+ // Update period of progress bar, in milliseconds
+ val UPDATE_PERIOD = 200L
+ // Delay before showing the progress bar, in milliseconds
+ val FIRST_DELAY = 500L
+
+ // The width of the terminal
+ val TerminalWidth = if (!sys.env.getOrElse("COLUMNS", "").isEmpty) {
+ sys.env.get("COLUMNS").get.toInt
+ } else {
+ 80
+ }
+
+ var lastFinishTime = 0L
+ var lastUpdateTime = 0L
+ var lastProgressBar = ""
+
+ // Schedule a refresh thread to run periodically
+ private val timer = new Timer("refresh progress", true)
+ timer.schedule(new TimerTask{
+ override def run() {
+ refresh()
+ }
+ }, FIRST_DELAY, UPDATE_PERIOD)
+
+ /**
+ * Try to refresh the progress bar on every update cycle
+ */
+ private def refresh(): Unit = synchronized {
+ val now = System.currentTimeMillis()
+ if (now - lastFinishTime < FIRST_DELAY) {
+ return
+ }
+ val stageIds = sc.statusTracker.getActiveStageIds()
+ val stages = stageIds.map(sc.statusTracker.getStageInfo).flatten.filter(_.numTasks() > 1)
+ .filter(now - _.submissionTime() > FIRST_DELAY).sortBy(_.stageId())
+ if (stages.size > 0) {
+ show(now, stages.take(3)) // display at most 3 stages at the same time
+ }
+ }
+
+ /**
+ * Show the progress bar in the console. The progress bar is displayed on the line following
+ * your last output and keeps overwriting itself to stay on one line. Any log output pushes
+ * the progress bar to the next line, so the bar does not overwrite the logs.
+ */
+ private def show(now: Long, stages: Seq[SparkStageInfo]) {
+ val width = TerminalWidth / stages.size
+ val bar = stages.map { s =>
+ val total = s.numTasks()
+ val header = s"[Stage ${s.stageId()}:"
+ val tailer = s"(${s.numCompletedTasks()} + ${s.numActiveTasks()}) / $total]"
+ val w = width - header.size - tailer.size
+ val bar = if (w > 0) {
+ val percent = w * s.numCompletedTasks() / total
+ (0 until w).map { i =>
+ if (i < percent) "=" else if (i == percent) ">" else " "
+ }.mkString("")
+ } else {
+ ""
+ }
+ header + bar + tailer
+ }.mkString("")
+
+ // Only refresh if the bar has changed, or after 1 minute (otherwise an idle SSH connection
+ // may be closed)
+ if (bar != lastProgressBar || now - lastUpdateTime > 60 * 1000L) {
+ System.err.print(CR + bar)
+ lastUpdateTime = now
+ }
+ lastProgressBar = bar
+ }
+
+ /**
+ * Clear the progress bar if it is currently shown.
+ */
+ private def clear() {
+ if (!lastProgressBar.isEmpty) {
+ System.err.printf(CR + " " * TerminalWidth + CR)
+ lastProgressBar = ""
+ }
+ }
+
+ /**
+ * Mark all the stages as finished and clear the progress bar if it is shown, so that the
+ * progress output does not interleave with the output of jobs.
+ */
+ def finishAll(): Unit = synchronized {
+ clear()
+ lastFinishTime = System.currentTimeMillis()
+ }
+}
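
A small worked example of the layout arithmetic in show() above; all numbers are made up:

    // With a 40-character slot for a stage that has 10 tasks, 4 completed and 2 active:
    //   header  = "[Stage 3:"          // 9 characters
    //   tailer  = "(4 + 2) / 10]"      // 13 characters
    //   w       = 40 - 9 - 13          // = 18
    //   percent = 18 * 4 / 10          // = 7 (integer division)
    // giving a bar of 7 '=' signs, one '>' and 10 spaces:
    //   [Stage 3:=======>          (4 + 2) / 10]
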
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index 2a27d49d2de05..88fed833f922d 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -201,7 +201,7 @@ private[spark] object JettyUtils extends Logging {
}
}
- val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, serverName)
+ val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName)
ServerInfo(server, boundPort, collection)
}
diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
index cccd59d122a92..0c24ad2760e08 100644
--- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -21,60 +21,44 @@ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext}
import org.apache.spark.scheduler._
import org.apache.spark.storage.StorageStatusListener
import org.apache.spark.ui.JettyUtils._
-import org.apache.spark.ui.env.EnvironmentTab
-import org.apache.spark.ui.exec.ExecutorsTab
-import org.apache.spark.ui.jobs.JobProgressTab
-import org.apache.spark.ui.storage.StorageTab
+import org.apache.spark.ui.env.{EnvironmentListener, EnvironmentTab}
+import org.apache.spark.ui.exec.{ExecutorsListener, ExecutorsTab}
+import org.apache.spark.ui.jobs.{JobsTab, JobProgressListener, StagesTab}
+import org.apache.spark.ui.storage.{StorageListener, StorageTab}
/**
* Top level user interface for a Spark application.
*/
-private[spark] class SparkUI(
- val sc: SparkContext,
+private[spark] class SparkUI private (
+ val sc: Option[SparkContext],
val conf: SparkConf,
val securityManager: SecurityManager,
- val listenerBus: SparkListenerBus,
+ val environmentListener: EnvironmentListener,
+ val storageStatusListener: StorageStatusListener,
+ val executorsListener: ExecutorsListener,
+ val jobProgressListener: JobProgressListener,
+ val storageListener: StorageListener,
var appName: String,
- val basePath: String = "")
+ val basePath: String)
extends WebUI(securityManager, SparkUI.getUIPort(conf), conf, basePath, "SparkUI")
with Logging {
- def this(sc: SparkContext) = this(sc, sc.conf, sc.env.securityManager, sc.listenerBus, sc.appName)
- def this(conf: SparkConf, listenerBus: SparkListenerBus, appName: String, basePath: String) =
- this(null, conf, new SecurityManager(conf), listenerBus, appName, basePath)
-
- def this(
- conf: SparkConf,
- securityManager: SecurityManager,
- listenerBus: SparkListenerBus,
- appName: String,
- basePath: String) =
- this(null, conf, securityManager, listenerBus, appName, basePath)
-
- // If SparkContext is not provided, assume the associated application is not live
- val live = sc != null
-
- // Maintain executor storage status through Spark events
- val storageStatusListener = new StorageStatusListener
-
- initialize()
+ val killEnabled = sc.map(_.conf.getBoolean("spark.ui.killEnabled", true)).getOrElse(false)
/** Initialize all components of the server. */
def initialize() {
- listenerBus.addListener(storageStatusListener)
- val jobProgressTab = new JobProgressTab(this)
- attachTab(jobProgressTab)
+ attachTab(new JobsTab(this))
+ val stagesTab = new StagesTab(this)
+ attachTab(stagesTab)
attachTab(new StorageTab(this))
attachTab(new EnvironmentTab(this))
attachTab(new ExecutorsTab(this))
attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static"))
- attachHandler(createRedirectHandler("/", "/stages", basePath = basePath))
+ attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath))
attachHandler(
- createRedirectHandler("/stages/stage/kill", "/stages", jobProgressTab.handleKillRequest))
- if (live) {
- sc.env.metricsSystem.getServletHandlers.foreach(attachHandler)
- }
+ createRedirectHandler("/stages/stage/kill", "/stages", stagesTab.handleKillRequest))
}
+ initialize()
def getAppName = appName
@@ -83,11 +67,6 @@ private[spark] class SparkUI(
appName = name
}
- /** Register the given listener with the listener bus. */
- def registerListener(listener: SparkListener) {
- listenerBus.addListener(listener)
- }
-
/** Stop the server behind this web interface. Only valid after bind(). */
override def stop() {
super.stop()
@@ -116,4 +95,60 @@ private[spark] object SparkUI {
def getUIPort(conf: SparkConf): Int = {
conf.getInt("spark.ui.port", SparkUI.DEFAULT_PORT)
}
+
+ def createLiveUI(
+ sc: SparkContext,
+ conf: SparkConf,
+ listenerBus: SparkListenerBus,
+ jobProgressListener: JobProgressListener,
+ securityManager: SecurityManager,
+ appName: String): SparkUI = {
+ create(Some(sc), conf, listenerBus, securityManager, appName,
+ jobProgressListener = Some(jobProgressListener))
+ }
+
+ def createHistoryUI(
+ conf: SparkConf,
+ listenerBus: SparkListenerBus,
+ securityManager: SecurityManager,
+ appName: String,
+ basePath: String): SparkUI = {
+ create(None, conf, listenerBus, securityManager, appName, basePath)
+ }
+
+ /**
+ * Create a new Spark UI.
+ *
+ * @param sc optional SparkContext; this can be None when reconstituting a UI from event logs.
+ * @param jobProgressListener if supplied, this JobProgressListener will be used; otherwise, the
+ * web UI will create and register its own JobProgressListener.
+ */
+ private def create(
+ sc: Option[SparkContext],
+ conf: SparkConf,
+ listenerBus: SparkListenerBus,
+ securityManager: SecurityManager,
+ appName: String,
+ basePath: String = "",
+ jobProgressListener: Option[JobProgressListener] = None): SparkUI = {
+
+ val _jobProgressListener: JobProgressListener = jobProgressListener.getOrElse {
+ val listener = new JobProgressListener(conf)
+ listenerBus.addListener(listener)
+ listener
+ }
+
+ val environmentListener = new EnvironmentListener
+ val storageStatusListener = new StorageStatusListener
+ val executorsListener = new ExecutorsListener(storageStatusListener)
+ val storageListener = new StorageListener(storageStatusListener)
+
+ listenerBus.addListener(environmentListener)
+ listenerBus.addListener(storageStatusListener)
+ listenerBus.addListener(executorsListener)
+ listenerBus.addListener(storageListener)
+
+ new SparkUI(sc, conf, securityManager, environmentListener, storageStatusListener,
+ executorsListener, _jobProgressListener, storageListener, appName, basePath)
+ }
}
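The two factories above differ only in whether a SparkContext is available. A hypothetical sketch of how they might be invoked from code inside the org.apache.spark package (SparkUI is private[spark]); sc, conf, listenerBus, replayBus and securityManager are assumed to exist already and are not part of the patch:

    // Live UI: the driver has a SparkContext and has already created a JobProgressListener.
    val liveUi = SparkUI.createLiveUI(
      sc, conf, listenerBus, new JobProgressListener(conf), securityManager, "my-app")

    // History UI: rebuilt from event logs, so there is no SparkContext; the UI creates and
    // registers its own JobProgressListener on the replay bus.
    val historyUi = SparkUI.createHistoryUI(
      conf, replayBus, securityManager, "replayed-app", "/history/app-1")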
diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
index 9ced9b8107ebf..4307029d44fbb 100644
--- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
+++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
@@ -24,11 +24,30 @@ private[spark] object ToolTips {
scheduler delay is large, consider decreasing the size of tasks or decreasing the size
of task results."""
+ val TASK_DESERIALIZATION_TIME = "Time spent deserializing the task closure on the executor."
+
+ val SHUFFLE_READ_BLOCKED_TIME =
+ "Time that the task spent blocked waiting for shuffle data to be read from remote machines."
+
val INPUT = "Bytes read from Hadoop or from Spark storage."
+ val OUTPUT = "Bytes written to Hadoop."
+
val SHUFFLE_WRITE = "Bytes written to disk in order to be read by a shuffle in a future stage."
val SHUFFLE_READ =
"""Bytes read from remote executors. Typically less than shuffle write bytes
because this does not include shuffle data read locally."""
+
+ val GETTING_RESULT_TIME =
+ """Time that the driver spends fetching task results from workers. If this is large, consider
+ decreasing the amount of data returned from each task."""
+
+ val RESULT_SERIALIZATION_TIME =
+ """Time spent serializing the task result on the executor before sending it back to the
+ driver."""
+
+ val GC_TIME =
+ """Time that the executor spent paused for Java garbage collection while the task was
+ running."""
}
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index 32e6b15bb0999..b5022fe853c49 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -20,13 +20,14 @@ package org.apache.spark.ui
import java.text.SimpleDateFormat
import java.util.{Locale, Date}
-import scala.xml.Node
+import scala.xml.{Node, Text}
import org.apache.spark.Logging
/** Utility functions for generating XML pages with spark content. */
private[spark] object UIUtils extends Logging {
- val TABLE_CLASS = "table table-bordered table-striped table-condensed sortable"
+ val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed sortable"
+ val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped"
// SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use.
private val dateFormat = new ThreadLocal[SimpleDateFormat]() {
@@ -160,6 +161,8 @@ private[spark] object UIUtils extends Logging {
+
+
}
/** Returns a spark page with correctly formatted headers */
@@ -167,15 +170,21 @@ private[spark] object UIUtils extends Logging {
title: String,
content: => Seq[Node],
activeTab: SparkUITab,
- refreshInterval: Option[Int] = None): Seq[Node] = {
+ refreshInterval: Option[Int] = None,
+ helpText: Option[String] = None): Seq[Node] = {
val appName = activeTab.appName
val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..."
val header = activeTab.headerTabs.map { tab =>
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
index 18d2b5075aa08..fc1844600f1cb 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala
@@ -20,23 +20,25 @@ package org.apache.spark.ui
import scala.util.Random
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.SparkContext._
import org.apache.spark.scheduler.SchedulingMode
+// scalastyle:off
/**
* Continuously generates jobs that expose various features of the WebUI (internal testing tool).
*
- * Usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]
+ * Usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR] [#job set (4 jobs per set)]
*/
+// scalastyle:on
private[spark] object UIWorkloadGenerator {
val NUM_PARTITIONS = 100
val INTER_JOB_WAIT_MS = 5000
def main(args: Array[String]) {
- if (args.length < 2) {
+ if (args.length < 3) {
println(
- "usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]")
+ "usage: ./bin/spark-class org.apache.spark.ui.UIWorkloadGenerator " +
+ "[master] [FIFO|FAIR] [#job set (4 jobs per set)]")
System.exit(1)
}
@@ -46,6 +48,7 @@ private[spark] object UIWorkloadGenerator {
if (schedulingMode == SchedulingMode.FAIR) {
conf.set("spark.scheduler.mode", "FAIR")
}
+ val nJobSet = args(2).toInt
val sc = new SparkContext(conf)
def setProperties(s: String) = {
@@ -85,7 +88,7 @@ private[spark] object UIWorkloadGenerator {
("Job with delays", baseData.map(x => Thread.sleep(100)).count)
)
- while (true) {
+ (1 to nJobSet).foreach { _ =>
for ((desc, job) <- jobs) {
new Thread {
override def run() {
@@ -102,5 +105,6 @@ private[spark] object UIWorkloadGenerator {
Thread.sleep(INTER_JOB_WAIT_MS)
}
}
+ sc.stop()
}
}
diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
index 5d88ca403a674..9be65a4a39a09 100644
--- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
@@ -82,7 +82,7 @@ private[spark] abstract class WebUI(
}
/** Detach a handler from this UI. */
- def detachHandler(handler: ServletContextHandler) {
+ protected def detachHandler(handler: ServletContextHandler) {
handlers -= handler
serverInfo.foreach { info =>
info.rootHandler.removeHandler(handler)
diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala
index 0d158fbe638d3..f62260c6f6e1d 100644
--- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala
@@ -22,10 +22,8 @@ import org.apache.spark.scheduler._
import org.apache.spark.ui._
private[ui] class EnvironmentTab(parent: SparkUI) extends SparkUITab(parent, "environment") {
- val listener = new EnvironmentListener
-
+ val listener = parent.environmentListener
attachPage(new EnvironmentPage(this))
- parent.registerListener(listener)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
new file mode 100644
index 0000000000000..c82730f524eb7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.exec
+
+import java.net.URLDecoder
+import javax.servlet.http.HttpServletRequest
+
+import scala.util.Try
+import scala.xml.{Text, Node}
+
+import org.apache.spark.ui.{UIUtils, WebUIPage}
+
+private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage("threadDump") {
+
+ private val sc = parent.sc
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ val executorId = Option(request.getParameter("executorId")).map {
+ executorId =>
+ // Due to YARN-2844, "<driver>" in the url will be encoded to "%25253Cdriver%25253E" when
+ // running in yarn-cluster mode. `request.getParameter("executorId")` will return
+ // "%253Cdriver%253E". Therefore we need to decode it until we get the real id.
+ var id = executorId
+ var decodedId = URLDecoder.decode(id, "UTF-8")
+ while (id != decodedId) {
+ id = decodedId
+ decodedId = URLDecoder.decode(id, "UTF-8")
+ }
+ id
+ }.getOrElse {
+ return Text(s"Missing executorId parameter")
+ }
+ val time = System.currentTimeMillis()
+ val maybeThreadDump = sc.get.getExecutorThreadDump(executorId)
+
+ val content = maybeThreadDump.map { threadDump =>
+ val dumpRows = threadDump.map { thread =>
+
+ } else {
+ Seq.empty
+ }
+ }
}
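The decode loop in render above unwraps an executor id that intermediate proxies may have URL-encoded more than once (YARN-2844). A self-contained sketch of the same decode-until-stable idiom, outside of any Spark class:

    import java.net.URLDecoder

    // Repeatedly URL-decode until the value stops changing, so a value encoded any number
    // of times (e.g. "%25253Cdriver%25253E") collapses back to the original ("<driver>").
    def fullyDecode(raw: String): String = {
      var id = raw
      var decoded = URLDecoder.decode(id, "UTF-8")
      while (id != decoded) {
        id = decoded
        decoded = URLDecoder.decode(id, "UTF-8")
      }
      id
    }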
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
index 61eb111cd9100..dd1c2b78c4094 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
@@ -26,10 +26,15 @@ import org.apache.spark.storage.StorageStatusListener
import org.apache.spark.ui.{SparkUI, SparkUITab}
private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") {
- val listener = new ExecutorsListener(parent.storageStatusListener)
+ val listener = parent.executorsListener
+ val sc = parent.sc
+ val threadDumpEnabled =
+ sc.isDefined && parent.conf.getBoolean("spark.ui.threadDumpsEnabled", true)
- attachPage(new ExecutorsPage(this))
- parent.registerListener(listener)
+ attachPage(new ExecutorsPage(this, threadDumpEnabled))
+ if (threadDumpEnabled) {
+ attachPage(new ExecutorThreadDumpPage(this))
+ }
}
/**
@@ -43,20 +48,21 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp
val executorToTasksFailed = HashMap[String, Int]()
val executorToDuration = HashMap[String, Long]()
val executorToInputBytes = HashMap[String, Long]()
+ val executorToOutputBytes = HashMap[String, Long]()
val executorToShuffleRead = HashMap[String, Long]()
val executorToShuffleWrite = HashMap[String, Long]()
def storageStatusList = storageStatusListener.storageStatusList
override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized {
- val eid = formatExecutorId(taskStart.taskInfo.executorId)
+ val eid = taskStart.taskInfo.executorId
executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 0) + 1
}
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
val info = taskEnd.taskInfo
if (info != null) {
- val eid = formatExecutorId(info.executorId)
+ val eid = info.executorId
executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 1) - 1
executorToDuration(eid) = executorToDuration.getOrElse(eid, 0L) + info.duration
taskEnd.reason match {
@@ -73,6 +79,10 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp
executorToInputBytes(eid) =
executorToInputBytes.getOrElse(eid, 0L) + inputMetrics.bytesRead
}
+ metrics.outputMetrics.foreach { outputMetrics =>
+ executorToOutputBytes(eid) =
+ executorToOutputBytes.getOrElse(eid, 0L) + outputMetrics.bytesWritten
+ }
metrics.shuffleReadMetrics.foreach { shuffleRead =>
executorToShuffleRead(eid) =
executorToShuffleRead.getOrElse(eid, 0L) + shuffleRead.remoteBytesRead
@@ -85,6 +95,4 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp
}
}
- // This addresses executor ID inconsistencies in the local mode
- private def formatExecutorId(execId: String) = storageStatusListener.formatExecutorId(execId)
}
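The new thread-dump page is attached only when a live SparkContext exists and the spark.ui.threadDumpsEnabled property is not set to false. A hypothetical way to turn the page off when building an application's configuration; only the property name comes from the patch, the rest is assumed context:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .setAppName("example")
      // Hide the per-executor thread dump page even on a live UI.
      .set("spark.ui.threadDumpsEnabled", "false")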
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
new file mode 100644
index 0000000000000..045c69da06feb
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+import scala.xml.{Node, NodeSeq}
+
+import javax.servlet.http.HttpServletRequest
+
+import org.apache.spark.ui.{WebUIPage, UIUtils}
+import org.apache.spark.ui.jobs.UIData.JobUIData
+
+/** Page showing list of all ongoing and recently finished jobs */
+private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
+ private val startTime: Option[Long] = parent.sc.map(_.startTime)
+ private val listener = parent.listener
+
+ private def jobsTable(jobs: Seq[JobUIData]): Seq[Node] = {
+ val someJobHasJobGroup = jobs.exists(_.jobGroup.isDefined)
+
+ val columns: Seq[Node] = {
+
{if (someJobHasJobGroup) "Job Id (Job Group)" else "Job Id"}
+
Description
+
Submitted
+
Duration
+
Stages: Succeeded/Total
+
Tasks (for all stages): Succeeded/Total
+ }
+
+ def makeRow(job: JobUIData): Seq[Node] = {
+ val lastStageInfo = listener.stageIdToInfo.get(job.stageIds.max)
+ val lastStageData = lastStageInfo.flatMap { s =>
+ listener.stageIdToData.get((s.stageId, s.attemptId))
+ }
+
+ val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)")
+ val lastStageDescription = lastStageData.flatMap(_.description).getOrElse("")
+ val duration: Option[Long] = {
+ job.submissionTime.map { start =>
+ val end = job.completionTime.getOrElse(System.currentTimeMillis())
+ end - start
+ }
+ }
+ val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown")
+ val formattedSubmissionTime = job.submissionTime.map(UIUtils.formatDate).getOrElse("Unknown")
+ val detailUrl =
+ "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), job.jobId)
+
++
+ failedJobsTable
+ }
+ val helpText = """A job is triggered by an action, like "count()" or "saveAsTextFile()".""" +
+ " Click on a job's title to see information about the stages of tasks associated with" +
+ " the job."
+
+ UIUtils.headerSparkPage("Spark Jobs", content, parent, helpText = Some(helpText))
+ }
+ }
+}
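The duration computed in makeRow above falls back to the current time while a job is still running. The same logic, pulled out as a standalone sketch:

    // A job's duration is completionTime - submissionTime; for a running job the end point
    // is "now", and a job that was never submitted has no duration at all.
    def jobDuration(submissionTime: Option[Long], completionTime: Option[Long]): Option[Long] =
      submissionTime.map { start =>
        completionTime.getOrElse(System.currentTimeMillis()) - start
      }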
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala
new file mode 100644
index 0000000000000..527f960af2dfc
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.xml.{Node, NodeSeq}
+
+import org.apache.spark.scheduler.Schedulable
+import org.apache.spark.ui.{WebUIPage, UIUtils}
+
+/** Page showing list of all ongoing and recently finished stages and pools */
+private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") {
+ private val sc = parent.sc
+ private val listener = parent.listener
+ private def isFairScheduler = parent.isFairScheduler
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ listener.synchronized {
+ val activeStages = listener.activeStages.values.toSeq
+ val pendingStages = listener.pendingStages.values.toSeq
+ val completedStages = listener.completedStages.reverse.toSeq
+ val numCompletedStages = listener.numCompletedStages
+ val failedStages = listener.failedStages.reverse.toSeq
+ val numFailedStages = listener.numFailedStages
+ val now = System.currentTimeMillis
+
+ val activeStagesTable =
+ new StageTableBase(activeStages.sortBy(_.submissionTime).reverse,
+ parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler,
+ killEnabled = parent.killEnabled)
+ val pendingStagesTable =
+ new StageTableBase(pendingStages.sortBy(_.submissionTime).reverse,
+ parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler,
+ killEnabled = false)
+ val completedStagesTable =
+ new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent.basePath,
+ parent.listener, isFairScheduler = parent.isFairScheduler, killEnabled = false)
+ val failedStagesTable =
+ new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent.basePath,
+ parent.listener, isFairScheduler = parent.isFairScheduler)
+
+ // For now, pool information is only accessible in live UIs
+ val pools = sc.map(_.getAllPools).getOrElse(Seq.empty[Schedulable])
+ val poolTable = new PoolTable(pools, parent)
+
+ val shouldShowActiveStages = activeStages.nonEmpty
+ val shouldShowPendingStages = pendingStages.nonEmpty
+ val shouldShowCompletedStages = completedStages.nonEmpty
+ val shouldShowFailedStages = failedStages.nonEmpty
+
+ val summary: NodeSeq =
+
+
+ {
+ if (sc.isDefined) {
+ // Total duration is not meaningful unless the UI is live
+
+ Total Duration:
+ {UIUtils.formatDuration(now - sc.get.startTime)}
+
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
new file mode 100644
index 0000000000000..77d36209c6048
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+import scala.collection.mutable
+import scala.xml.{NodeSeq, Node}
+
+import javax.servlet.http.HttpServletRequest
+
+import org.apache.spark.JobExecutionStatus
+import org.apache.spark.scheduler.StageInfo
+import org.apache.spark.ui.{UIUtils, WebUIPage}
+
+/** Page showing statistics and stage list for a given job */
+private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") {
+ private val listener = parent.listener
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ listener.synchronized {
+ val jobId = request.getParameter("id").toInt
+ val jobDataOption = listener.jobIdToData.get(jobId)
+ if (jobDataOption.isEmpty) {
+ val content =
+
+
No information to display for job {jobId}
+
+ return UIUtils.headerSparkPage(
+ s"Details for Job $jobId", content, parent)
+ }
+ val jobData = jobDataOption.get
+ val isComplete = jobData.status != JobExecutionStatus.RUNNING
+ val stages = jobData.stageIds.map { stageId =>
+ // This could be empty if the JobProgressListener hasn't received information about the
+ // stage or if the stage information has been garbage collected
+ listener.stageIdToInfo.getOrElse(stageId,
+ new StageInfo(stageId, 0, "Unknown", 0, Seq.empty, "Unknown"))
+ }
+
+ val activeStages = mutable.Buffer[StageInfo]()
+ val completedStages = mutable.Buffer[StageInfo]()
+ // If the job is completed, then any pending stages are displayed as "skipped":
+ val pendingOrSkippedStages = mutable.Buffer[StageInfo]()
+ val failedStages = mutable.Buffer[StageInfo]()
+ for (stage <- stages) {
+ if (stage.submissionTime.isEmpty) {
+ pendingOrSkippedStages += stage
+ } else if (stage.completionTime.isDefined) {
+ if (stage.failureReason.isDefined) {
+ failedStages += stage
+ } else {
+ completedStages += stage
+ }
+ } else {
+ activeStages += stage
+ }
+ }
+
+ val activeStagesTable =
+ new StageTableBase(activeStages.sortBy(_.submissionTime).reverse,
+ parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler,
+ killEnabled = parent.killEnabled)
+ val pendingOrSkippedStagesTable =
+ new StageTableBase(pendingOrSkippedStages.sortBy(_.stageId).reverse,
+ parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler,
+ killEnabled = false)
+ val completedStagesTable =
+ new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent.basePath,
+ parent.listener, isFairScheduler = parent.isFairScheduler, killEnabled = false)
+ val failedStagesTable =
+ new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent.basePath,
+ parent.listener, isFairScheduler = parent.isFairScheduler)
+
+ val shouldShowActiveStages = activeStages.nonEmpty
+ val shouldShowPendingStages = !isComplete && pendingOrSkippedStages.nonEmpty
+ val shouldShowCompletedStages = completedStages.nonEmpty
+ val shouldShowSkippedStages = isComplete && pendingOrSkippedStages.nonEmpty
+ val shouldShowFailedStages = failedStages.nonEmpty
+
+ val summary: NodeSeq =
+
++
+ failedStagesTable.toNodeSeq
+ }
+ UIUtils.headerSparkPage(s"Details for Job $jobId", content, parent)
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index eaeb861f59e5a..4d200eeda86b9 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ui.jobs
-import scala.collection.mutable.{HashMap, ListBuffer}
+import scala.collection.mutable.{HashMap, HashSet, ListBuffer}
import org.apache.spark._
import org.apache.spark.annotation.DeveloperApi
@@ -40,29 +40,185 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
import JobProgressListener._
- // How many stages to remember
- val retainedStages = conf.getInt("spark.ui.retainedStages", DEFAULT_RETAINED_STAGES)
+ // Define a handful of type aliases so that data structures' types can serve as documentation.
+ // These type aliases are public because they're used in the types of public fields:
- // Map from stageId to StageInfo
- val activeStages = new HashMap[Int, StageInfo]
+ type JobId = Int
+ type StageId = Int
+ type StageAttemptId = Int
+ type PoolName = String
+ type ExecutorId = String
- // Map from (stageId, attemptId) to StageUIData
- val stageIdToData = new HashMap[(Int, Int), StageUIData]
+ // Jobs:
+ val activeJobs = new HashMap[JobId, JobUIData]
+ val completedJobs = ListBuffer[JobUIData]()
+ val failedJobs = ListBuffer[JobUIData]()
+ val jobIdToData = new HashMap[JobId, JobUIData]
+ // Stages:
+ val pendingStages = new HashMap[StageId, StageInfo]
+ val activeStages = new HashMap[StageId, StageInfo]
val completedStages = ListBuffer[StageInfo]()
+ val skippedStages = ListBuffer[StageInfo]()
val failedStages = ListBuffer[StageInfo]()
+ val stageIdToData = new HashMap[(StageId, StageAttemptId), StageUIData]
+ val stageIdToInfo = new HashMap[StageId, StageInfo]
+ val stageIdToActiveJobIds = new HashMap[StageId, HashSet[JobId]]
+ val poolToActiveStages = HashMap[PoolName, HashMap[StageId, StageInfo]]()
+ // Total of completed and failed stages that have ever been run. These may be greater than
+ // `completedStages.size` and `failedStages.size` if we have run more stages or jobs than
+ // JobProgressListener's retention limits.
+ var numCompletedStages = 0
+ var numFailedStages = 0
+
+ // Misc:
+ val executorIdToBlockManagerId = HashMap[ExecutorId, BlockManagerId]()
+ def blockManagerIds = executorIdToBlockManagerId.values.toSeq
- // Map from pool name to a hash map (map from stage id to StageInfo).
- val poolToActiveStages = HashMap[String, HashMap[Int, StageInfo]]()
+ var schedulingMode: Option[SchedulingMode] = None
- val executorIdToBlockManagerId = HashMap[String, BlockManagerId]()
+ // To limit the total memory usage of JobProgressListener, we only track information for a fixed
+ // number of non-active jobs and stages (there is no limit for active jobs and stages):
- var schedulingMode: Option[SchedulingMode] = None
+ val retainedStages = conf.getInt("spark.ui.retainedStages", DEFAULT_RETAINED_STAGES)
+ val retainedJobs = conf.getInt("spark.ui.retainedJobs", DEFAULT_RETAINED_JOBS)
+
+ // We can test for memory leaks by ensuring that collections that track non-active jobs and
+ // stages do not grow without bound and that collections for active jobs/stages eventually become
+ // empty once Spark is idle. Let's partition our collections into ones that should be empty
+ // once Spark is idle and ones that should have hard- or soft-limited sizes.
+ // These methods are used by unit tests, but they're defined here so that people don't forget to
+ // update the tests when adding new collections. Some collections have multiple levels of
+ // nesting, etc, so this lets us customize our notion of "size" for each structure:
+
+ // These collections should all be empty once Spark is idle (no active stages / jobs):
+ private[spark] def getSizesOfActiveStateTrackingCollections: Map[String, Int] = {
+ Map(
+ "activeStages" -> activeStages.size,
+ "activeJobs" -> activeJobs.size,
+ "poolToActiveStages" -> poolToActiveStages.values.map(_.size).sum,
+ "stageIdToActiveJobIds" -> stageIdToActiveJobIds.values.map(_.size).sum
+ )
+ }
- def blockManagerIds = executorIdToBlockManagerId.values.toSeq
+ // These collections should stop growing once we have run at least `spark.ui.retainedStages`
+ // stages and `spark.ui.retainedJobs` jobs:
+ private[spark] def getSizesOfHardSizeLimitedCollections: Map[String, Int] = {
+ Map(
+ "completedJobs" -> completedJobs.size,
+ "failedJobs" -> failedJobs.size,
+ "completedStages" -> completedStages.size,
+ "skippedStages" -> skippedStages.size,
+ "failedStages" -> failedStages.size
+ )
+ }
+
+ // These collections may grow arbitrarily, but once Spark becomes idle they should shrink back to
+ // some bound based on the `spark.ui.retainedStages` and `spark.ui.retainedJobs` settings:
+ private[spark] def getSizesOfSoftSizeLimitedCollections: Map[String, Int] = {
+ Map(
+ "jobIdToData" -> jobIdToData.size,
+ "stageIdToData" -> stageIdToData.size,
+ "stageIdToStageInfo" -> stageIdToInfo.size
+ )
+ }
+
+ /** If stages is too large, remove and garbage collect old stages */
+ private def trimStagesIfNecessary(stages: ListBuffer[StageInfo]) = synchronized {
+ if (stages.size > retainedStages) {
+ val toRemove = math.max(retainedStages / 10, 1)
+ stages.take(toRemove).foreach { s =>
+ stageIdToData.remove((s.stageId, s.attemptId))
+ stageIdToInfo.remove(s.stageId)
+ }
+ stages.trimStart(toRemove)
+ }
+ }
+
+ /** If jobs is too large, remove and garbage collect old jobs */
+ private def trimJobsIfNecessary(jobs: ListBuffer[JobUIData]) = synchronized {
+ if (jobs.size > retainedJobs) {
+ val toRemove = math.max(retainedJobs / 10, 1)
+ jobs.take(toRemove).foreach { job =>
+ jobIdToData.remove(job.jobId)
+ }
+ jobs.trimStart(toRemove)
+ }
+ }
+
+ override def onJobStart(jobStart: SparkListenerJobStart) = synchronized {
+ val jobGroup = for (
+ props <- Option(jobStart.properties);
+ group <- Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID))
+ ) yield group
+ val jobData: JobUIData =
+ new JobUIData(
+ jobId = jobStart.jobId,
+ submissionTime = Option(jobStart.time).filter(_ >= 0),
+ stageIds = jobStart.stageIds,
+ jobGroup = jobGroup,
+ status = JobExecutionStatus.RUNNING)
+ jobStart.stageInfos.foreach(x => pendingStages(x.stageId) = x)
+ // Compute (a potential underestimate of) the number of tasks that will be run by this job.
+ // This may be an underestimate because the job start event references all of the result
+ // stages' transitive stage dependencies, but some of these stages might be skipped if their
+ // output is available from earlier runs.
+ // See https://github.com/apache/spark/pull/3009 for a more extensive discussion.
+ jobData.numTasks = {
+ val allStages = jobStart.stageInfos
+ val missingStages = allStages.filter(_.completionTime.isEmpty)
+ missingStages.map(_.numTasks).sum
+ }
+ jobIdToData(jobStart.jobId) = jobData
+ activeJobs(jobStart.jobId) = jobData
+ for (stageId <- jobStart.stageIds) {
+ stageIdToActiveJobIds.getOrElseUpdate(stageId, new HashSet[StageId]).add(jobStart.jobId)
+ }
+ // If there's no information for a stage, store the StageInfo received from the scheduler
+ // so that we can display stage descriptions for pending stages:
+ for (stageInfo <- jobStart.stageInfos) {
+ stageIdToInfo.getOrElseUpdate(stageInfo.stageId, stageInfo)
+ stageIdToData.getOrElseUpdate((stageInfo.stageId, stageInfo.attemptId), new StageUIData)
+ }
+ }
+
+ override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized {
+ val jobData = activeJobs.remove(jobEnd.jobId).getOrElse {
+ logWarning(s"Job completed for unknown job ${jobEnd.jobId}")
+ new JobUIData(jobId = jobEnd.jobId)
+ }
+ jobData.completionTime = Option(jobEnd.time).filter(_ >= 0)
+
+ jobData.stageIds.foreach(pendingStages.remove)
+ jobEnd.jobResult match {
+ case JobSucceeded =>
+ completedJobs += jobData
+ trimJobsIfNecessary(completedJobs)
+ jobData.status = JobExecutionStatus.SUCCEEDED
+ case JobFailed(exception) =>
+ failedJobs += jobData
+ trimJobsIfNecessary(failedJobs)
+ jobData.status = JobExecutionStatus.FAILED
+ }
+ for (stageId <- jobData.stageIds) {
+ stageIdToActiveJobIds.get(stageId).foreach { jobsUsingStage =>
+ jobsUsingStage.remove(jobEnd.jobId)
+ stageIdToInfo.get(stageId).foreach { stageInfo =>
+ if (stageInfo.submissionTime.isEmpty) {
+ // if this stage is pending, it won't complete, so mark it as "skipped":
+ skippedStages += stageInfo
+ trimStagesIfNecessary(skippedStages)
+ jobData.numSkippedStages += 1
+ jobData.numSkippedTasks += stageInfo.numTasks
+ }
+ }
+ }
+ }
+ }
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized {
val stage = stageCompleted.stageInfo
+ stageIdToInfo(stage.stageId) = stage
val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptId), {
logWarning("Stage completed for unknown stage " + stage.stageId)
new StageUIData
@@ -78,19 +234,25 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
activeStages.remove(stage.stageId)
if (stage.failureReason.isEmpty) {
completedStages += stage
- trimIfNecessary(completedStages)
+ numCompletedStages += 1
+ trimStagesIfNecessary(completedStages)
} else {
failedStages += stage
- trimIfNecessary(failedStages)
+ numFailedStages += 1
+ trimStagesIfNecessary(failedStages)
}
- }
- /** If stages is too large, remove and garbage collect old stages */
- private def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized {
- if (stages.size > retainedStages) {
- val toRemove = math.max(retainedStages / 10, 1)
- stages.take(toRemove).foreach { s => stageIdToData.remove((s.stageId, s.attemptId)) }
- stages.trimStart(toRemove)
+ for (
+ activeJobsDependentOnStage <- stageIdToActiveJobIds.get(stage.stageId);
+ jobId <- activeJobsDependentOnStage;
+ jobData <- jobIdToData.get(jobId)
+ ) {
+ jobData.numActiveStages -= 1
+ if (stage.failureReason.isEmpty) {
+ jobData.completedStageIndices.add(stage.stageId)
+ } else {
+ jobData.numFailedStages += 1
+ }
}
}
@@ -98,11 +260,12 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = synchronized {
val stage = stageSubmitted.stageInfo
activeStages(stage.stageId) = stage
-
+ pendingStages.remove(stage.stageId)
val poolName = Option(stageSubmitted.properties).map {
p => p.getProperty("spark.scheduler.pool", DEFAULT_POOL_NAME)
}.getOrElse(DEFAULT_POOL_NAME)
+ stageIdToInfo(stage.stageId) = stage
val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptId), new StageUIData)
stageData.schedulingPool = poolName
@@ -112,6 +275,14 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashMap[Int, StageInfo])
stages(stage.stageId) = stage
+
+ for (
+ activeJobsDependentOnStage <- stageIdToActiveJobIds.get(stage.stageId);
+ jobId <- activeJobsDependentOnStage;
+ jobData <- jobIdToData.get(jobId)
+ ) {
+ jobData.numActiveStages += 1
+ }
}
override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized {
@@ -124,6 +295,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
stageData.numActiveTasks += 1
stageData.taskData.put(taskInfo.taskId, new TaskUIData(taskInfo))
}
+ for (
+ activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskStart.stageId);
+ jobId <- activeJobsDependentOnStage;
+ jobData <- jobIdToData.get(jobId)
+ ) {
+ jobData.numActiveTasks += 1
+ }
}
override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) {
@@ -134,7 +312,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
val info = taskEnd.taskInfo
// If stage attempt id is -1, it means the DAGScheduler had no idea which attempt this task
- // compeletion event is for. Let's just drop it here. This means we might have some speculation
+ // completion event is for. Let's just drop it here. This means we might have some speculation
// tasks on the web ui that's never marked as complete.
if (info != null && taskEnd.stageAttemptId != -1) {
val stageData = stageIdToData.getOrElseUpdate((taskEnd.stageId, taskEnd.stageAttemptId), {
@@ -181,6 +359,20 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
taskData.taskInfo = info
taskData.taskMetrics = metrics
taskData.errorMessage = errorMessage
+
+ for (
+ activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskEnd.stageId);
+ jobId <- activeJobsDependentOnStage;
+ jobData <- jobIdToData.get(jobId)
+ ) {
+ jobData.numActiveTasks -= 1
+ taskEnd.reason match {
+ case Success =>
+ jobData.numCompletedTasks += 1
+ case _ =>
+ jobData.numFailedTasks += 1
+ }
+ }
}
}
@@ -214,6 +406,12 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
stageData.inputBytes += inputBytesDelta
execSummary.inputBytes += inputBytesDelta
+ val outputBytesDelta =
+ (taskMetrics.outputMetrics.map(_.bytesWritten).getOrElse(0L)
+ - oldMetrics.flatMap(_.outputMetrics).map(_.bytesWritten).getOrElse(0L))
+ stageData.outputBytes += outputBytesDelta
+ execSummary.outputBytes += outputBytesDelta
+
val diskSpillDelta =
taskMetrics.diskBytesSpilled - oldMetrics.map(_.diskBytesSpilled).getOrElse(0L)
stageData.diskBytesSpilled += diskSpillDelta
@@ -277,4 +475,5 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
private object JobProgressListener {
val DEFAULT_POOL_NAME = "default"
val DEFAULT_RETAINED_STAGES = 1000
+ val DEFAULT_RETAINED_JOBS = 1000
}
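The trimming above caps memory use by keeping at most spark.ui.retainedStages stages and spark.ui.retainedJobs jobs, evicting roughly the oldest 10% (at least one entry) once a limit is crossed and cleaning up the per-entry lookup state at the same time. A standalone sketch of that policy; the buffer and map names are illustrative, not Spark's:

    import scala.collection.mutable.{HashMap, ListBuffer}

    val retained = 1000                       // e.g. conf.getInt("spark.ui.retainedStages", 1000)
    val completed = ListBuffer[Int]()         // stand-in for completedStages (oldest first)
    val idToData = new HashMap[Int, String]() // stand-in for stageIdToData

    def trimIfNecessary(): Unit = {
      if (completed.size > retained) {
        val toRemove = math.max(retained / 10, 1)
        completed.take(toRemove).foreach(idToData.remove)  // drop per-entry lookup state too
        completed.trimStart(toRemove)
      }
    }

Lowering the limits for a memory-constrained driver would use the property names shown in the patch, e.g. conf.set("spark.ui.retainedStages", "200").set("spark.ui.retainedJobs", "200").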
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala
deleted file mode 100644
index 1e02f1225d344..0000000000000
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala
+++ /dev/null
@@ -1,98 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ui.jobs
-
-import javax.servlet.http.HttpServletRequest
-
-import scala.xml.{Node, NodeSeq}
-
-import org.apache.spark.scheduler.Schedulable
-import org.apache.spark.ui.{WebUIPage, UIUtils}
-
-/** Page showing list of all ongoing and recently finished stages and pools */
-private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") {
- private val live = parent.live
- private val sc = parent.sc
- private val listener = parent.listener
- private def isFairScheduler = parent.isFairScheduler
-
- def render(request: HttpServletRequest): Seq[Node] = {
- listener.synchronized {
- val activeStages = listener.activeStages.values.toSeq
- val completedStages = listener.completedStages.reverse.toSeq
- val failedStages = listener.failedStages.reverse.toSeq
- val now = System.currentTimeMillis
-
- val activeStagesTable =
- new StageTableBase(activeStages.sortBy(_.submissionTime).reverse,
- parent, parent.killEnabled)
- val completedStagesTable =
- new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent)
- val failedStagesTable =
- new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent)
-
- // For now, pool information is only accessible in live UIs
- val pools = if (live) sc.getAllPools else Seq[Schedulable]()
- val poolTable = new PoolTable(pools, parent)
-
- val summary: NodeSeq =
-
-
- {if (live) {
- // Total duration is not meaningful unless the UI is live
-
- Total Duration:
- {UIUtils.formatDuration(now - sc.startTime)}
-
++
- failedStagesTable.toNodeSeq
-
- UIUtils.headerSparkPage("Spark Stages", content, parent)
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala
deleted file mode 100644
index c16542c9db30f..0000000000000
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala
+++ /dev/null
@@ -1,55 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ui.jobs
-
-import javax.servlet.http.HttpServletRequest
-
-import org.apache.spark.SparkConf
-import org.apache.spark.scheduler.SchedulingMode
-import org.apache.spark.ui.{SparkUI, SparkUITab}
-
-/** Web UI showing progress status of all jobs in the given SparkContext. */
-private[ui] class JobProgressTab(parent: SparkUI) extends SparkUITab(parent, "stages") {
- val live = parent.live
- val sc = parent.sc
- val conf = if (live) sc.conf else new SparkConf
- val killEnabled = conf.getBoolean("spark.ui.killEnabled", true)
- val listener = new JobProgressListener(conf)
-
- attachPage(new JobProgressPage(this))
- attachPage(new StagePage(this))
- attachPage(new PoolPage(this))
- parent.registerListener(listener)
-
- def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR)
-
- def handleKillRequest(request: HttpServletRequest) = {
- if ((killEnabled) && (parent.securityManager.checkModifyPermissions(request.getRemoteUser))) {
- val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean
- val stageId = Option(request.getParameter("id")).getOrElse("-1").toInt
- if (stageId >= 0 && killFlag && listener.activeStages.contains(stageId)) {
- sc.cancelStage(stageId)
- }
- // Do a quick pause here to give Spark time to kill the stage so it shows up as
- // killed after the refresh. Note that this will block the serving thread so the
- // time should be limited in duration.
- Thread.sleep(100)
- }
- }
-
-}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
new file mode 100644
index 0000000000000..b2bbfdee56946
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+import org.apache.spark.scheduler.SchedulingMode
+import org.apache.spark.ui.{SparkUI, SparkUITab}
+
+/** Web UI showing progress status of all jobs in the given SparkContext. */
+private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") {
+ val sc = parent.sc
+ val killEnabled = parent.killEnabled
+ def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR)
+ val listener = parent.jobProgressListener
+
+ attachPage(new AllJobsPage(this))
+ attachPage(new JobPage(this))
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
index 7a6c7d1a497ed..5fc6cc7533150 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
@@ -25,8 +25,7 @@ import org.apache.spark.scheduler.{Schedulable, StageInfo}
import org.apache.spark.ui.{WebUIPage, UIUtils}
/** Page showing specific pool details */
-private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") {
- private val live = parent.live
+private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") {
private val sc = parent.sc
private val listener = parent.listener
@@ -38,11 +37,12 @@ private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") {
case Some(s) => s.values.toSeq
case None => Seq[StageInfo]()
}
- val activeStagesTable =
- new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, parent)
+ val activeStagesTable = new StageTableBase(activeStages.sortBy(_.submissionTime).reverse,
+ parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler,
+ killEnabled = parent.killEnabled)
// For now, pool information is only accessible in live UIs
- val pools = if (live) Seq(sc.getPoolForName(poolName).get) else Seq[Schedulable]()
+ val pools = sc.map(_.getPoolForName(poolName).get).toSeq
val poolTable = new PoolTable(pools, parent)
val content =
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
index 64178e1e33d41..df1899e7a9b84 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
@@ -24,7 +24,7 @@ import org.apache.spark.scheduler.{Schedulable, StageInfo}
import org.apache.spark.ui.UIUtils
/** Table showing list of pools */
-private[ui] class PoolTable(pools: Seq[Schedulable], parent: JobProgressTab) {
+private[ui] class PoolTable(pools: Seq[Schedulable], parent: StagesTab) {
private val listener = parent.listener
def toNodeSeq: Seq[Node] = {
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 2414e4c65237e..d8be1b20b3acd 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -22,13 +22,16 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.{Node, Unparsed}
+import org.apache.commons.lang3.StringEscapeUtils
+
+import org.apache.spark.executor.TaskMetrics
import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils}
import org.apache.spark.ui.jobs.UIData._
import org.apache.spark.util.{Utils, Distribution}
-import org.apache.spark.scheduler.AccumulableInfo
+import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo}
/** Page showing statistics and task list for a given stage */
-private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
+private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
private val listener = parent.listener
def render(request: HttpServletRequest): Seq[Node] = {
@@ -52,12 +55,13 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
val numCompleted = tasks.count(_.taskInfo.finished)
val accumulables = listener.stageIdToData((stageId, stageAttemptId)).accumulables
+ val hasAccumulators = accumulables.size > 0
val hasInput = stageData.inputBytes > 0
+ val hasOutput = stageData.outputBytes > 0
val hasShuffleRead = stageData.shuffleReadBytes > 0
val hasShuffleWrite = stageData.shuffleWriteBytes > 0
val hasBytesSpilled = stageData.memoryBytesSpilled > 0 && stageData.diskBytesSpilled > 0
- // scalastyle:off
val summary =
@@ -65,55 +69,132 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
Total task time across all tasks:
{UIUtils.formatDuration(stageData.executorRunTime)}
- {if (hasInput)
+ {if (hasInput) {
+:
+ getFormattedTimeQuantiles(gettingResultTimes)
// The scheduler delay includes the network delay to send the task to the worker
// machine and to send back the result (but not the time to fetch the task result,
// if it needed to be fetched from the block manager on the worker).
val schedulerDelays = validTasks.map { case TaskUIData(info, metrics, _) =>
- val totalExecutionTime = {
- if (info.gettingResultTime > 0) {
- (info.gettingResultTime - info.launchTime).toDouble
- } else {
- (info.finishTime - info.launchTime).toDouble
- }
- }
- totalExecutionTime - metrics.get.executorRunTime
+ getSchedulerDelay(info, metrics.get).toDouble
}
val schedulerDelayTitle =
- Some(UIUtils.listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true))
+ // The summary table does not use CSS to stripe rows, which doesn't work with hidden
+ // rows (instead, JavaScript in table.js is used to stripe the non-hidden rows).
+ Some(UIUtils.listingTable(
+ quantileHeaders,
+ identity[Seq[Node]],
+ listings,
+ fixedWidth = true,
+ id = Some("task-summary-table"),
+ stripeRowsWithCss = false))
}
val executorTable = new ExecutorTable(stageId, stageAttemptId, parent)
@@ -221,6 +360,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
val content =
summary ++
+ showAdditionalMetrics ++
Summary Metrics for {numCompleted} Completed Tasks
++
{summaryTable.getOrElse("No tasks have reported metrics yet.")}
++
Aggregated Metrics by Executor
++ executorTable.toNodeSeq ++
@@ -232,7 +372,9 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
}
def taskRow(
+ hasAccumulators: Boolean,
hasInput: Boolean,
+ hasOutput: Boolean,
hasShuffleRead: Boolean,
hasShuffleWrite: Boolean,
hasBytesSpilled: Boolean)(taskData: TaskUIData): Seq[Node] = {
@@ -241,8 +383,14 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
else metrics.map(_.executorRunTime).getOrElse(1L)
val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration)
else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("")
+ val schedulerDelay = metrics.map(getSchedulerDelay(info, _)).getOrElse(0L)
val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L)
+ val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L)
val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L)
+ val gettingResultTime = info.gettingResultTime
+
+ val maybeAccumulators = info.accumulables
+ val accumulatorsReadable = maybeAccumulators.map{acc => s"${acc.name}: ${acc.update.get}"}
val maybeInput = metrics.flatMap(_.inputMetrics)
val inputSortable = maybeInput.map(_.bytesRead.toString).getOrElse("")
@@ -250,6 +398,17 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
.map(m => s"${Utils.bytesToString(m.bytesRead)} (${m.readMethod.toString.toLowerCase()})")
.getOrElse("")
+ val maybeOutput = metrics.flatMap(_.outputMetrics)
+ val outputSortable = maybeOutput.map(_.bytesWritten.toString).getOrElse("")
+ val outputReadable = maybeOutput
+ .map(m => s"${Utils.bytesToString(m.bytesWritten)}")
+ .getOrElse("")
+
+ val maybeShuffleReadBlockedTime = metrics.flatMap(_.shuffleReadMetrics).map(_.fetchWaitTime)
+ val shuffleReadBlockedTimeSortable = maybeShuffleReadBlockedTime.map(_.toString).getOrElse("")
+ val shuffleReadBlockedTimeReadable =
+ maybeShuffleReadBlockedTime.map(ms => UIUtils.formatDuration(ms)).getOrElse("")
+
val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead)
val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("")
val shuffleReadReadable = maybeShuffleRead.map(Utils.bytesToString).getOrElse("")
@@ -287,26 +446,45 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
++
@@ -43,6 +44,7 @@ private[ui] class StageTableBase(
Duration
Tasks: Succeeded/Total
Input
+
Output
Shuffle Read
-<?xml version="1.0"?>
-<?xml-stylesheet type="text/xsl" href="configuration.xsl"?>
-
-<configuration>
-
-<property>
-  <name>build.dir</name>
-  <value>${user.dir}/build</value>
-</property>
-
-<property>
-  <name>build.dir.hive</name>
-  <value>${build.dir}/hive</value>
-</property>
-
-<property>
-  <name>hadoop.tmp.dir</name>
-  <value>${build.dir.hive}/test/hadoop-${user.name}</value>
-  <description>A base for other temporary directories.</description>
-</property>
-
-<property>
-  <name>hive.exec.scratchdir</name>
-  <value>${build.dir}/scratchdir</value>
-  <description>Scratch space for Hive jobs</description>
-</property>
-
-<property>
-  <name>hive.exec.local.scratchdir</name>
-  <value>${build.dir}/localscratchdir/</value>
-  <description>Local scratch space for Hive jobs</description>
-</property>
-
-<property>
-  <name>javax.jdo.option.ConnectionURL</name>
-  <value>
-  jdbc:derby:;databaseName=../build/test/junit_metastore_db;create=true
-  </value>
-</property>
-
-<property>
-  <name>javax.jdo.option.ConnectionDriverName</name>
-  <value>org.apache.derby.jdbc.EmbeddedDriver</value>
-</property>
-
-<property>
-  <name>javax.jdo.option.ConnectionUserName</name>
-  <value>APP</value>
-</property>
-
-<property>
-  <name>javax.jdo.option.ConnectionPassword</name>
-  <value>mine</value>
-</property>
-
-<property>
-  <name>hive.metastore.warehouse.dir</name>
-  <value>${test.warehouse.dir}</value>
-</property>
-
-<property>
-  <name>hive.metastore.metadb.dir</name>
-  <value>${build.dir}/test/data/metadb/</value>
-  <description>
-  Required by metastore server or if the uris argument below is not supplied
-  </description>
-</property>
-
-<property>
-  <name>test.log.dir</name>
-  <value>${build.dir}/test/logs</value>
-</property>
-
-<property>
-  <name>test.src.dir</name>
-  <value>${build.dir}/src/test</value>
-</property>
-
-<property>
-  <name>hive.jar.path</name>
-  <value>${build.dir.hive}/ql/hive-exec-${version}.jar</value>
-</property>
-
-<property>
-  <name>hive.metastore.rawstore.impl</name>
-  <value>org.apache.hadoop.hive.metastore.ObjectStore</value>
-  <description>Name of the class that implements org.apache.hadoop.hive.metastore.rawstore interface. This class is used to store and retrieval of raw metadata objects such as table, database</description>
-</property>
-
-<property>
-  <name>hive.querylog.location</name>
-  <value>${build.dir}/tmp</value>
-  <description>Location of the structured hive logs</description>
-</property>
-
-<property>
-  <name>hive.task.progress</name>
-  <value>false</value>
-  <description>Track progress of a task</description>
-</property>
-
-<property>
-  <name>hive.support.concurrency</name>
-  <value>false</value>
-  <description>Whether hive supports concurrency or not. A zookeeper instance must be up and running for the default hive lock manager to support read-write locks.</description>
-</property>
-
-<property>
-  <name>fs.pfile.impl</name>
-  <value>org.apache.hadoop.fs.ProxyLocalFileSystem</value>
-  <description>A proxy for local file system used for cross file system testing</description>
-</property>
-
-<property>
-  <name>hive.exec.mode.local.auto</name>
-  <value>false</value>
-  <description>
-    Let hive determine whether to run in local mode automatically
-    Disabling this for tests so that minimr is not affected
-  </description>
-</property>
-
-<property>
-  <name>hive.auto.convert.join</name>
-  <value>false</value>
-  <description>Whether Hive enable the optimization about converting common join into mapjoin based on the input file size</description>
-</property>
-
-<property>
-  <name>hive.ignore.mapjoin.hint</name>
-  <value>false</value>
-  <description>Whether Hive ignores the mapjoin hint</description>
-</property>
-
-<property>
-  <name>hive.input.format</name>
-  <value>org.apache.hadoop.hive.ql.io.CombineHiveInputFormat</value>
-  <description>The default input format, if it is not specified, the system assigns it. It is set to HiveInputFormat for hadoop versions 17, 18 and 19, whereas it is set to CombineHiveInputFormat for hadoop 20. The user can always overwrite it - if there is a bug in CombineHiveInputFormat, it can always be manually set to HiveInputFormat.</description>
-</property>
-
-<property>
-  <name>hive.default.rcfile.serde</name>
-  <value>org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe</value>
-  <description>The default SerDe hive will use for the rcfile format</description>
-</property>
-
-</configuration>
diff --git a/dev/change-version-to-2.10.sh b/dev/change-version-to-2.10.sh
new file mode 100755
index 0000000000000..7473c20d28e09
--- /dev/null
+++ b/dev/change-version-to-2.10.sh
@@ -0,0 +1,20 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+find . -name 'pom.xml' | grep -v target \
+ | xargs -I {} sed -i -e 's|\(artifactId.*\)_2.11|\1_2.10|g' {}
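The sed expression above rewrites the Scala binary suffix on every artifactId in the poms. A rough Python sketch of the same capture-and-replace, applied to a single illustrative pom.xml line (the line itself is made up for the example):

    import re

    # Illustrative <artifactId> line; the real script edits every pom.xml in place.
    line = "  <artifactId>spark-core_2.11</artifactId>"

    # Keep everything up to the Scala suffix and swap _2.11 for _2.10,
    # just as the sed expression does.
    print(re.sub(r"(artifactId.*)_2\.11", r"\1_2.10", line))
    # -> "  <artifactId>spark-core_2.10</artifactId>"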
diff --git a/dev/change-version-to-2.11.sh b/dev/change-version-to-2.11.sh
new file mode 100755
index 0000000000000..3957a9f3ba258
--- /dev/null
+++ b/dev/change-version-to-2.11.sh
@@ -0,0 +1,21 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+find . -name 'pom.xml' | grep -v target \
+ | xargs -I {} sed -i -e 's|\(artifactId.*\)_2.10|\1_2.11|g' {}
diff --git a/dev/check-license b/dev/check-license
index 72b1013479964..a006f65710d6d 100755
--- a/dev/check-license
+++ b/dev/check-license
@@ -27,17 +27,17 @@ acquire_rat_jar () {
if [[ ! -f "$rat_jar" ]]; then
# Download rat launch jar if it hasn't been downloaded yet
if [ ! -f "$JAR" ]; then
- # Download
- printf "Attempting to fetch rat\n"
- JAR_DL="${JAR}.part"
- if hash curl 2>/dev/null; then
- curl --silent "${URL}" > "$JAR_DL" && mv "$JAR_DL" "$JAR"
- elif hash wget 2>/dev/null; then
- wget --quiet ${URL} -O "$JAR_DL" && mv "$JAR_DL" "$JAR"
- else
- printf "You do not have curl or wget installed, please install rat manually.\n"
- exit -1
- fi
+ # Download
+ printf "Attempting to fetch rat\n"
+ JAR_DL="${JAR}.part"
+ if [ $(command -v curl) ]; then
+ curl --silent "${URL}" > "$JAR_DL" && mv "$JAR_DL" "$JAR"
+ elif [ $(command -v wget) ]; then
+ wget --quiet ${URL} -O "$JAR_DL" && mv "$JAR_DL" "$JAR"
+ else
+ printf "You do not have curl or wget installed, please install rat manually.\n"
+ exit -1
+ fi
fi
unzip -tq $JAR &> /dev/null
diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh
index 281e8d4de6d71..607ce1c803507 100755
--- a/dev/create-release/create-release.sh
+++ b/dev/create-release/create-release.sh
@@ -22,18 +22,25 @@
# Expects to be run in a totally empty directory.
#
# Options:
-# --package-only only packages an existing release candidate
-#
+# --skip-create-release Assume the desired release tag already exists
+# --skip-publish Do not publish to Maven central
+# --skip-package Do not package and upload binary artifacts
# Would be nice to add:
# - Send output to stderr and have useful logging in stdout
-GIT_USERNAME=${GIT_USERNAME:-pwendell}
-GIT_PASSWORD=${GIT_PASSWORD:-XXX}
+# Note: The following variables must be set before use!
+ASF_USERNAME=${ASF_USERNAME:-pwendell}
+ASF_PASSWORD=${ASF_PASSWORD:-XXX}
GPG_PASSPHRASE=${GPG_PASSPHRASE:-XXX}
GIT_BRANCH=${GIT_BRANCH:-branch-1.0}
-RELEASE_VERSION=${RELEASE_VERSION:-1.0.0}
+RELEASE_VERSION=${RELEASE_VERSION:-1.2.0}
+NEXT_VERSION=${NEXT_VERSION:-1.2.1}
RC_NAME=${RC_NAME:-rc2}
-USER_NAME=${USER_NAME:-pwendell}
+
+M2_REPO=~/.m2/repository
+SPARK_REPO=$M2_REPO/org/apache/spark
+NEXUS_ROOT=https://repository.apache.org/service/local/staging
+NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads
if [ -z "$JAVA_HOME" ]; then
echo "Error: JAVA_HOME is not set, cannot proceed."
@@ -45,109 +52,201 @@ set -e
GIT_TAG=v$RELEASE_VERSION-$RC_NAME
-if [[ ! "$@" =~ --package-only ]]; then
- echo "Creating and publishing release"
+if [[ ! "$@" =~ --skip-create-release ]]; then
+ echo "Creating release commit and publishing to Apache repository"
# Artifact publishing
- git clone https://git-wip-us.apache.org/repos/asf/spark.git -b $GIT_BRANCH
- cd spark
+ git clone https://$ASF_USERNAME:$ASF_PASSWORD@git-wip-us.apache.org/repos/asf/spark.git \
+ -b $GIT_BRANCH
+ pushd spark
export MAVEN_OPTS="-Xmx3g -XX:MaxPermSize=1g -XX:ReservedCodeCacheSize=1g"
- mvn -Pyarn release:clean
-
- mvn -DskipTests \
- -Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \
- -Dusername=$GIT_USERNAME -Dpassword=$GIT_PASSWORD \
- -Dmaven.javadoc.skip=true \
- -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
- -Dtag=$GIT_TAG -DautoVersionSubmodules=true \
- -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \
- --batch-mode release:prepare
+ # Create release commits and push them to github
+  # NOTE: This is done "eagerly" i.e. we don't check if we can successfully build
+  # or test the code before we coin the release commit. This helps avoid races where
+ # other people add commits to this branch while we are in the middle of building.
+ cur_ver="${RELEASE_VERSION}-SNAPSHOT"
+ rel_ver="${RELEASE_VERSION}"
+ next_ver="${NEXT_VERSION}-SNAPSHOT"
+
+ old="^\( \{2,4\}\)${cur_ver}<\/version>$"
+ new="\1${rel_ver}<\/version>"
+ find . -name pom.xml | grep -v dev | xargs -I {} sed -i \
+ -e "s/${old}/${new}/" {}
+ find . -name package.scala | grep -v dev | xargs -I {} sed -i \
+ -e "s/${old}/${new}/" {}
+
+ git commit -a -m "Preparing Spark release $GIT_TAG"
+ echo "Creating tag $GIT_TAG at the head of $GIT_BRANCH"
+ git tag $GIT_TAG
+
+ old="^\( \{2,4\}\)${rel_ver}<\/version>$"
+ new="\1${next_ver}<\/version>"
+ find . -name pom.xml | grep -v dev | xargs -I {} sed -i \
+ -e "s/$old/$new/" {}
+ find . -name package.scala | grep -v dev | xargs -I {} sed -i \
+ -e "s/${old}/${new}/" {}
+ git commit -a -m "Preparing development version $next_ver"
+ git push origin $GIT_TAG
+ git push origin HEAD:$GIT_BRANCH
+ popd
+ rm -rf spark
+fi
- mvn -DskipTests \
- -Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \
- -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
- -Dmaven.javadoc.skip=true \
+if [[ ! "$@" =~ --skip-publish ]]; then
+ git clone https://$ASF_USERNAME:$ASF_PASSWORD@git-wip-us.apache.org/repos/asf/spark.git
+ pushd spark
+ git checkout --force $GIT_TAG
+
+ # Using Nexus API documented here:
+ # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API
+ echo "Creating Nexus staging repository"
+ repo_request="Apache Spark $GIT_TAG"
+ out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \
+ -H "Content-Type:application/xml" -v \
+ $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start)
+ staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/")
+ echo "Created Nexus staging repository: $staged_repo_id"
+
+ rm -rf $SPARK_REPO
+
+ mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
-Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \
- release:perform
+ clean install
- cd ..
+ ./dev/change-version-to-2.11.sh
+
+ mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
+ -Dscala-2.11 -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \
+ clean install
+
+ ./dev/change-version-to-2.10.sh
+
+ pushd $SPARK_REPO
+
+ # Remove any extra files generated during install
+ find . -type f |grep -v \.jar |grep -v \.pom | xargs rm
+
+ echo "Creating hash and signature files"
+ for file in $(find . -type f)
+ do
+ echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --output $file.asc --detach-sig --armour $file;
+ if [ $(command -v md5) ]; then
+ # Available on OS X; -q to keep only hash
+ md5 -q $file > $file.md5
+ else
+ # Available on Linux; cut to keep only hash
+ md5sum $file | cut -f1 -d' ' > $file.md5
+ fi
+ shasum -a 1 $file | cut -f1 -d' ' > $file.sha1
+ done
+
+ nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id
+ echo "Uplading files to $nexus_upload"
+ for file in $(find . -type f)
+ do
+ # strip leading ./
+ file_short=$(echo $file | sed -e "s/\.\///")
+ dest_url="$nexus_upload/org/apache/spark/$file_short"
+ echo " Uploading $file_short"
+ curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url
+ done
+
+ echo "Closing nexus staging repository"
+ repo_request="$staged_repo_idApache Spark $GIT_TAG"
+ out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \
+ -H "Content-Type:application/xml" -v \
+ $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish)
+ echo "Closed Nexus staging repository: $staged_repo_id"
+
+ popd
+ popd
rm -rf spark
fi
-# Source and binary tarballs
-echo "Packaging release tarballs"
-git clone https://git-wip-us.apache.org/repos/asf/spark.git
-cd spark
-git checkout --force $GIT_TAG
-release_hash=`git rev-parse HEAD`
-
-rm .gitignore
-rm -rf .git
-cd ..
-
-cp -r spark spark-$RELEASE_VERSION
-tar cvzf spark-$RELEASE_VERSION.tgz spark-$RELEASE_VERSION
-echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour --output spark-$RELEASE_VERSION.tgz.asc \
- --detach-sig spark-$RELEASE_VERSION.tgz
-echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md MD5 spark-$RELEASE_VERSION.tgz > \
- spark-$RELEASE_VERSION.tgz.md5
-echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md SHA512 spark-$RELEASE_VERSION.tgz > \
- spark-$RELEASE_VERSION.tgz.sha
-rm -rf spark-$RELEASE_VERSION
-
-make_binary_release() {
- NAME=$1
- FLAGS=$2
- cp -r spark spark-$RELEASE_VERSION-bin-$NAME
-
- cd spark-$RELEASE_VERSION-bin-$NAME
- ./make-distribution.sh --name $NAME --tgz $FLAGS
+if [[ ! "$@" =~ --skip-package ]]; then
+ # Source and binary tarballs
+ echo "Packaging release tarballs"
+ git clone https://git-wip-us.apache.org/repos/asf/spark.git
+ cd spark
+ git checkout --force $GIT_TAG
+ release_hash=`git rev-parse HEAD`
+
+ rm .gitignore
+ rm -rf .git
cd ..
- cp spark-$RELEASE_VERSION-bin-$NAME/spark-$RELEASE_VERSION-bin-$NAME.tgz .
- rm -rf spark-$RELEASE_VERSION-bin-$NAME
-
- echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour \
- --output spark-$RELEASE_VERSION-bin-$NAME.tgz.asc \
- --detach-sig spark-$RELEASE_VERSION-bin-$NAME.tgz
- echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \
- MD5 spark-$RELEASE_VERSION-bin-$NAME.tgz > \
- spark-$RELEASE_VERSION-bin-$NAME.tgz.md5
- echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \
- SHA512 spark-$RELEASE_VERSION-bin-$NAME.tgz > \
- spark-$RELEASE_VERSION-bin-$NAME.tgz.sha
-}
-
-make_binary_release "hadoop1" "-Phive -Dhadoop.version=1.0.4" &
-make_binary_release "cdh4" "-Phive -Dhadoop.version=2.0.0-mr1-cdh4.2.0" &
-make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Pyarn" &
-make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Pyarn" &
-make_binary_release "hadoop2.4-without-hive" "-Phadoop-2.4 -Pyarn" &
-make_binary_release "mapr3" "-Pmapr3 -Phive" &
-make_binary_release "mapr4" "-Pmapr4 -Pyarn -Phive" &
-wait
-
-# Copy data
-echo "Copying release tarballs"
-rc_folder=spark-$RELEASE_VERSION-$RC_NAME
-ssh $USER_NAME@people.apache.org \
- mkdir /home/$USER_NAME/public_html/$rc_folder
-scp spark-* \
- $USER_NAME@people.apache.org:/home/$USER_NAME/public_html/$rc_folder/
-
-# Docs
-cd spark
-sbt/sbt clean
-cd docs
-# Compile docs with Java 7 to use nicer format
-JAVA_HOME=$JAVA_7_HOME PRODUCTION=1 jekyll build
-echo "Copying release documentation"
-rc_docs_folder=${rc_folder}-docs
-ssh $USER_NAME@people.apache.org \
- mkdir /home/$USER_NAME/public_html/$rc_docs_folder
-rsync -r _site/* $USER_NAME@people.apache.org:/home/$USER_NAME/public_html/$rc_docs_folder
-
-echo "Release $RELEASE_VERSION completed:"
-echo "Git tag:\t $GIT_TAG"
-echo "Release commit:\t $release_hash"
-echo "Binary location:\t http://people.apache.org/~$USER_NAME/$rc_folder"
-echo "Doc location:\t http://people.apache.org/~$USER_NAME/$rc_docs_folder"
+
+ cp -r spark spark-$RELEASE_VERSION
+ tar cvzf spark-$RELEASE_VERSION.tgz spark-$RELEASE_VERSION
+ echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour --output spark-$RELEASE_VERSION.tgz.asc \
+ --detach-sig spark-$RELEASE_VERSION.tgz
+ echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md MD5 spark-$RELEASE_VERSION.tgz > \
+ spark-$RELEASE_VERSION.tgz.md5
+ echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md SHA512 spark-$RELEASE_VERSION.tgz > \
+ spark-$RELEASE_VERSION.tgz.sha
+ rm -rf spark-$RELEASE_VERSION
+
+ make_binary_release() {
+ NAME=$1
+ FLAGS=$2
+ cp -r spark spark-$RELEASE_VERSION-bin-$NAME
+
+ cd spark-$RELEASE_VERSION-bin-$NAME
+
+ # TODO There should probably be a flag to make-distribution to allow 2.11 support
+ if [[ $FLAGS == *scala-2.11* ]]; then
+ ./dev/change-version-to-2.11.sh
+ fi
+
+ ./make-distribution.sh --name $NAME --tgz $FLAGS 2>&1 | tee ../binary-release-$NAME.log
+ cd ..
+ cp spark-$RELEASE_VERSION-bin-$NAME/spark-$RELEASE_VERSION-bin-$NAME.tgz .
+ rm -rf spark-$RELEASE_VERSION-bin-$NAME
+
+ echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --armour \
+ --output spark-$RELEASE_VERSION-bin-$NAME.tgz.asc \
+ --detach-sig spark-$RELEASE_VERSION-bin-$NAME.tgz
+ echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \
+ MD5 spark-$RELEASE_VERSION-bin-$NAME.tgz > \
+ spark-$RELEASE_VERSION-bin-$NAME.tgz.md5
+ echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --print-md \
+ SHA512 spark-$RELEASE_VERSION-bin-$NAME.tgz > \
+ spark-$RELEASE_VERSION-bin-$NAME.tgz.sha
+ }
+
+
+ make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" &
+ make_binary_release "hadoop1-scala2.11" "-Phive -Dscala-2.11" &
+ make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" &
+ make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" &
+ make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" &
+ make_binary_release "mapr3" "-Pmapr3 -Phive -Phive-thriftserver" &
+ make_binary_release "mapr4" "-Pmapr4 -Pyarn -Phive -Phive-thriftserver" &
+ make_binary_release "hadoop2.4-without-hive" "-Phadoop-2.4 -Pyarn" &
+ wait
+
+ # Copy data
+ echo "Copying release tarballs"
+ rc_folder=spark-$RELEASE_VERSION-$RC_NAME
+ ssh $ASF_USERNAME@people.apache.org \
+ mkdir /home/$ASF_USERNAME/public_html/$rc_folder
+ scp spark-* \
+ $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_folder/
+
+ # Docs
+ cd spark
+ sbt/sbt clean
+ cd docs
+ # Compile docs with Java 7 to use nicer format
+ JAVA_HOME=$JAVA_7_HOME PRODUCTION=1 jekyll build
+ echo "Copying release documentation"
+ rc_docs_folder=${rc_folder}-docs
+ ssh $ASF_USERNAME@people.apache.org \
+ mkdir /home/$ASF_USERNAME/public_html/$rc_docs_folder
+ rsync -r _site/* $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_docs_folder
+
+ echo "Release $RELEASE_VERSION completed:"
+ echo "Git tag:\t $GIT_TAG"
+ echo "Release commit:\t $release_hash"
+ echo "Binary location:\t http://people.apache.org/~$ASF_USERNAME/$rc_folder"
+ echo "Doc location:\t http://people.apache.org/~$ASF_USERNAME/$rc_docs_folder"
+fi
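The publish branch above drives the Sonatype Nexus staging REST API with curl: open a staging repository, PUT each artifact together with its .asc/.md5/.sha1 files, then close the repository. A minimal sketch of the same three calls using the Python requests library; the credentials, version, artifact path, and the exact XML payload are assumptions made for illustration:

    import requests

    NEXUS_ROOT = "https://repository.apache.org/service/local/staging"
    PROFILE = "d63f592e7eac0"                  # Spark staging profile (from the script)
    AUTH = ("asf_username", "asf_password")    # placeholder credentials
    XML = {"Content-Type": "application/xml"}

    # 1. Open a staging repository; the response embeds an id such as orgapachespark-1024.
    start = ("<promoteRequest><data><description>Apache Spark v1.2.0-rc2"
             "</description></data></promoteRequest>")
    resp = requests.post(NEXUS_ROOT + "/profiles/" + PROFILE + "/start",
                         data=start, auth=AUTH, headers=XML)
    repo_id = "orgapachespark-" + resp.text.split("orgapachespark-")[1][:4]

    # 2. Upload one artifact (the script loops over every jar/pom and checksum file).
    path = "org/apache/spark/spark-core_2.10/1.2.0/spark-core_2.10-1.2.0.jar"
    with open("spark-core_2.10-1.2.0.jar", "rb") as f:
        requests.put(NEXUS_ROOT + "/deployByRepositoryId/" + repo_id + "/" + path,
                     data=f, auth=AUTH)

    # 3. Close ("finish") the repository so it can be inspected and promoted.
    finish = ("<promoteRequest><data><stagedRepositoryId>" + repo_id +
              "</stagedRepositoryId><description>Apache Spark v1.2.0-rc2"
              "</description></data></promoteRequest>")
    requests.post(NEXUS_ROOT + "/profiles/" + PROFILE + "/finish",
                  data=finish, auth=AUTH, headers=XML)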
diff --git a/dev/create-release/generate-contributors.py b/dev/create-release/generate-contributors.py
new file mode 100755
index 0000000000000..8aaa250bd7e29
--- /dev/null
+++ b/dev/create-release/generate-contributors.py
@@ -0,0 +1,244 @@
+#!/usr/bin/env python
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This script automates the process of creating release notes.
+
+import os
+import re
+import sys
+
+from releaseutils import *
+
+# You must set the following before use!
+JIRA_API_BASE = os.environ.get("JIRA_API_BASE", "https://issues.apache.org/jira")
+RELEASE_TAG = os.environ.get("RELEASE_TAG", "v1.2.0-rc2")
+PREVIOUS_RELEASE_TAG = os.environ.get("PREVIOUS_RELEASE_TAG", "v1.1.0")
+
+# If the release tags are not provided, prompt the user to provide them
+while not tag_exists(RELEASE_TAG):
+ RELEASE_TAG = raw_input("Please provide a valid release tag: ")
+while not tag_exists(PREVIOUS_RELEASE_TAG):
+ print "Please specify the previous release tag."
+ PREVIOUS_RELEASE_TAG = raw_input(\
+ "For instance, if you are releasing v1.2.0, you should specify v1.1.0: ")
+
+# Gather commits found in the new tag but not in the old tag.
+# This filters commits based on both the git hash and the PR number.
+# If either is present in the old tag, then we ignore the commit.
+print "Gathering new commits between tags %s and %s" % (PREVIOUS_RELEASE_TAG, RELEASE_TAG)
+release_commits = get_commits(RELEASE_TAG)
+previous_release_commits = get_commits(PREVIOUS_RELEASE_TAG)
+previous_release_hashes = set()
+previous_release_prs = set()
+for old_commit in previous_release_commits:
+ previous_release_hashes.add(old_commit.get_hash())
+ if old_commit.get_pr_number():
+ previous_release_prs.add(old_commit.get_pr_number())
+new_commits = []
+for this_commit in release_commits:
+ this_hash = this_commit.get_hash()
+ this_pr_number = this_commit.get_pr_number()
+ if this_hash in previous_release_hashes:
+ continue
+ if this_pr_number and this_pr_number in previous_release_prs:
+ continue
+ new_commits.append(this_commit)
+if not new_commits:
+ sys.exit("There are no new commits between %s and %s!" % (PREVIOUS_RELEASE_TAG, RELEASE_TAG))
+
+# Prompt the user for confirmation that the commit range is correct
+print "\n=================================================================================="
+print "JIRA server: %s" % JIRA_API_BASE
+print "Release tag: %s" % RELEASE_TAG
+print "Previous release tag: %s" % PREVIOUS_RELEASE_TAG
+print "Number of commits in this range: %s" % len(new_commits)
+print
+def print_indented(_list):
+ for x in _list: print " %s" % x
+if yesOrNoPrompt("Show all commits?"):
+ print_indented(new_commits)
+print "==================================================================================\n"
+if not yesOrNoPrompt("Does this look correct?"):
+ sys.exit("Ok, exiting")
+
+# Filter out special commits
+releases = []
+maintenance = []
+reverts = []
+nojiras = []
+filtered_commits = []
+def is_release(commit_title):
+ return re.findall("\[release\]", commit_title.lower()) or\
+ "preparing spark release" in commit_title.lower() or\
+ "preparing development version" in commit_title.lower() or\
+ "CHANGES.txt" in commit_title
+def is_maintenance(commit_title):
+ return "maintenance" in commit_title.lower() or\
+ "manually close" in commit_title.lower()
+def has_no_jira(commit_title):
+ return not re.findall("SPARK-[0-9]+", commit_title.upper())
+def is_revert(commit_title):
+ return "revert" in commit_title.lower()
+def is_docs(commit_title):
+ return re.findall("docs*", commit_title.lower()) or\
+ "programming guide" in commit_title.lower()
+for c in new_commits:
+ t = c.get_title()
+ if not t: continue
+ elif is_release(t): releases.append(c)
+ elif is_maintenance(t): maintenance.append(c)
+ elif is_revert(t): reverts.append(c)
+ elif is_docs(t): filtered_commits.append(c) # docs may not have JIRA numbers
+ elif has_no_jira(t): nojiras.append(c)
+ else: filtered_commits.append(c)
+
+# Warn against ignored commits
+if releases or maintenance or reverts or nojiras:
+ print "\n=================================================================================="
+ if releases: print "Found %d release commits" % len(releases)
+ if maintenance: print "Found %d maintenance commits" % len(maintenance)
+ if reverts: print "Found %d revert commits" % len(reverts)
+ if nojiras: print "Found %d commits with no JIRA" % len(nojiras)
+ print "* Warning: these commits will be ignored.\n"
+ if yesOrNoPrompt("Show ignored commits?"):
+ if releases: print "Release (%d)" % len(releases); print_indented(releases)
+ if maintenance: print "Maintenance (%d)" % len(maintenance); print_indented(maintenance)
+ if reverts: print "Revert (%d)" % len(reverts); print_indented(reverts)
+ if nojiras: print "No JIRA (%d)" % len(nojiras); print_indented(nojiras)
+ print "==================== Warning: the above commits will be ignored ==================\n"
+prompt_msg = "%d commits left to process after filtering. Ok to proceed?" % len(filtered_commits)
+if not yesOrNoPrompt(prompt_msg):
+ sys.exit("Ok, exiting.")
+
+# Keep track of warnings to tell the user at the end
+warnings = []
+
+# Mapping from the invalid author name to its associated JIRA issues
+# E.g. andrewor14 -> set("SPARK-2413", "SPARK-3551", "SPARK-3471")
+invalid_authors = {}
+
+# Populate a map that groups issues and components by author
+# It takes the form: Author name -> { Contribution type -> Spark components }
+# For instance,
+# {
+# 'Andrew Or': {
+# 'bug fixes': ['windows', 'core', 'web ui'],
+# 'improvements': ['core']
+# },
+# 'Tathagata Das' : {
+# 'bug fixes': ['streaming']
+# 'new feature': ['streaming']
+# }
+# }
+#
+author_info = {}
+jira_options = { "server": JIRA_API_BASE }
+jira_client = JIRA(options = jira_options)
+print "\n=========================== Compiling contributor list ==========================="
+for commit in filtered_commits:
+ _hash = commit.get_hash()
+ title = commit.get_title()
+ issues = re.findall("SPARK-[0-9]+", title.upper())
+ author = commit.get_author()
+ date = get_date(_hash)
+ # If the author name is invalid, keep track of it along
+ # with all associated issues so we can translate it later
+ if is_valid_author(author):
+ author = capitalize_author(author)
+ else:
+ if author not in invalid_authors:
+ invalid_authors[author] = set()
+ for issue in issues:
+ invalid_authors[author].add(issue)
+ # Parse components from the commit title, if any
+ commit_components = find_components(title, _hash)
+ # Populate or merge an issue into author_info[author]
+ def populate(issue_type, components):
+ components = components or [CORE_COMPONENT] # assume core if no components provided
+ if author not in author_info:
+ author_info[author] = {}
+ if issue_type not in author_info[author]:
+ author_info[author][issue_type] = set()
+ for component in components:
+ author_info[author][issue_type].add(component)
+ # Find issues and components associated with this commit
+ for issue in issues:
+ jira_issue = jira_client.issue(issue)
+ jira_type = jira_issue.fields.issuetype.name
+ jira_type = translate_issue_type(jira_type, issue, warnings)
+ jira_components = [translate_component(c.name, _hash, warnings)\
+ for c in jira_issue.fields.components]
+ all_components = set(jira_components + commit_components)
+ populate(jira_type, all_components)
+ # For docs without an associated JIRA, manually add it ourselves
+ if is_docs(title) and not issues:
+ populate("documentation", commit_components)
+ print " Processed commit %s authored by %s on %s" % (_hash, author, date)
+print "==================================================================================\n"
+
+# Write to contributors file ordered by author names
+# Each line takes the format " * Author name -- semi-colon delimited contributions"
+# e.g. * Andrew Or -- Bug fixes in Windows, Core, and Web UI; improvements in Core
+# e.g. * Tathagata Das -- Bug fixes and new features in Streaming
+contributors_file = open(contributors_file_name, "w")
+authors = author_info.keys()
+authors.sort()
+for author in authors:
+ contribution = ""
+ components = set()
+ issue_types = set()
+ for issue_type, comps in author_info[author].items():
+ components.update(comps)
+ issue_types.add(issue_type)
+ # If there is only one component, mention it only once
+ # e.g. Bug fixes, improvements in MLlib
+ if len(components) == 1:
+ contribution = "%s in %s" % (nice_join(issue_types), next(iter(components)))
+ # Otherwise, group contributions by issue types instead of modules
+ # e.g. Bug fixes in MLlib, Core, and Streaming; documentation in YARN
+ else:
+ contributions = ["%s in %s" % (issue_type, nice_join(comps)) \
+ for issue_type, comps in author_info[author].items()]
+ contribution = "; ".join(contributions)
+ # Do not use python's capitalize() on the whole string to preserve case
+ assert contribution
+ contribution = contribution[0].capitalize() + contribution[1:]
+ # If the author name is invalid, use an intermediate format that
+ # can be translated through translate-contributors.py later
+ # E.g. andrewor14/SPARK-3425/SPARK-1157/SPARK-6672
+ if author in invalid_authors and invalid_authors[author]:
+ author = author + "/" + "/".join(invalid_authors[author])
+ line = " * %s -- %s" % (author, contribution)
+ contributors_file.write(line + "\n")
+contributors_file.close()
+print "Contributors list is successfully written to %s!" % contributors_file_name
+
+# Prompt the user to translate author names if necessary
+if invalid_authors:
+ warnings.append("Found the following invalid authors:")
+ for a in invalid_authors:
+ warnings.append("\t%s" % a)
+ warnings.append("Please run './translate-contributors.py' to translate them.")
+
+# Log any warnings encountered in the process
+if warnings:
+ print "\n============ Warnings encountered while creating the contributor list ============"
+ for w in warnings: print w
+ print "Please correct these in the final contributors list at %s." % contributors_file_name
+ print "==================================================================================\n"
+
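To make the grouping logic above concrete, here is a small self-contained sketch of how one line of contributors.txt is rendered from the author_info map; the author and components are invented, and nice_join mirrors the helper defined in releaseutils.py:

    # Invented data in the same shape as the author_info map built above.
    author_info = {
        "Andrew Or": {
            "bug fixes": {"Windows", "Core", "Web UI"},
            "improvements": {"Core"},
        },
    }

    def nice_join(items):              # mirrors releaseutils.nice_join
        items = list(items)
        if len(items) <= 1:
            return "".join(items)
        if len(items) == 2:
            return " and ".join(items)
        return ", ".join(items[:-1]) + ", and " + items[-1]

    for author in sorted(author_info):
        parts = ["%s in %s" % (t, nice_join(sorted(comps)))
                 for t, comps in author_info[author].items()]
        contribution = "; ".join(parts)
        contribution = contribution[0].capitalize() + contribution[1:]
        print(" * %s -- %s" % (author, contribution))
        # e.g. " * Andrew Or -- Bug fixes in Core, Web UI, and Windows; improvements in Core"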
diff --git a/dev/create-release/known_translations b/dev/create-release/known_translations
new file mode 100644
index 0000000000000..b74e4ee8a330b
--- /dev/null
+++ b/dev/create-release/known_translations
@@ -0,0 +1,59 @@
+# This is a mapping of names to be translated through translate-contributors.py
+# The format expected on each line should be: <old name> - <new name>
+CodingCat - Nan Zhu
+CrazyJvm - Chao Chen
+EugenCepoi - Eugen Cepoi
+GraceH - Jie Huang
+JerryLead - Lijie Xu
+Leolh - Liu Hao
+Lewuathe - Kai Sasaki
+RongGu - Rong Gu
+Shiti - Shiti Saxena
+Victsm - Min Shen
+WangTaoTheTonic - Wang Tao
+XuTingjun - Tingjun Xu
+YanTangZhai - Yantang Zhai
+alexdebrie - Alex DeBrie
+alokito - Alok Saldanha
+anantasty - Anant Asthana
+andrewor14 - Andrew Or
+aniketbhatnagar - Aniket Bhatnagar
+arahuja - Arun Ahuja
+brkyvz - Burak Yavuz
+chesterxgchen - Chester Chen
+chiragaggarwal - Chirag Aggarwal
+chouqin - Qiping Li
+cocoatomo - Tomohiko K.
+coderfi - Fairiz Azizi
+coderxiang - Shuo Xiang
+davies - Davies Liu
+epahomov - Egor Pahomov
+falaki - Hossein Falaki
+freeman-lab - Jeremy Freeman
+industrial-sloth - Jascha Swisher
+jackylk - Jacky Li
+jayunit100 - Jay Vyas
+jerryshao - Saisai Shao
+jkbradley - Joseph Bradley
+lianhuiwang - Lianhui Wang
+lirui-intel - Rui Li
+luluorta - Lu Lu
+luogankun - Gankun Luo
+maji2014 - Derek Ma
+mccheah - Matthew Cheah
+mengxr - Xiangrui Meng
+nartz - Nathan Artz
+odedz - Oded Zimerman
+ravipesala - Ravindra Pesala
+roxchkplusony - Victor Tso
+scwf - Wang Fei
+shimingfei - Shiming Fei
+surq - Surong Quan
+suyanNone - Su Yan
+tedyu - Ted Yu
+tigerquoll - Dale Richardson
+wangxiaojing - Xiaojing Wang
+watermen - Yadong Qi
+witgo - Guoqiang Li
+xinyunh - Xinyun Huang
+zsxwing - Shixiong Zhu
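For reference, a minimal sketch of how translate-contributors.py consumes this file: every non-comment line is split on " - " into a github login and a full name (blank-line handling is added here for robustness and is not in the original script):

    known_translations = {}
    with open("known_translations") as f:
        for line in f:
            if line.startswith("#") or not line.strip():
                continue
            old_name, new_name = line.strip("\n").split(" - ")
            known_translations[old_name] = new_name

    print(known_translations.get("andrewor14"))   # -> "Andrew Or"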
diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py
new file mode 100755
index 0000000000000..26221b270394e
--- /dev/null
+++ b/dev/create-release/releaseutils.py
@@ -0,0 +1,256 @@
+#!/usr/bin/env python
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This file contains helper methods used in creating a release.
+
+import re
+import sys
+from subprocess import Popen, PIPE
+
+try:
+ from jira.client import JIRA
+ from jira.exceptions import JIRAError
+except ImportError:
+ print "This tool requires the jira-python library"
+ print "Install using 'sudo pip install jira-python'"
+ sys.exit(-1)
+
+try:
+ from github import Github
+ from github import GithubException
+except ImportError:
+ print "This tool requires the PyGithub library"
+ print "Install using 'sudo pip install PyGithub'"
+ sys.exit(-1)
+
+try:
+ import unidecode
+except ImportError:
+ print "This tool requires the unidecode library to decode obscure github usernames"
+ print "Install using 'sudo pip install unidecode'"
+ sys.exit(-1)
+
+# Contributors list file name
+contributors_file_name = "contributors.txt"
+
+# Prompt the user to answer yes or no until they do so
+def yesOrNoPrompt(msg):
+ response = raw_input("%s [y/n]: " % msg)
+ while response != "y" and response != "n":
+ return yesOrNoPrompt(msg)
+ return response == "y"
+
+# Utility functions run git commands (written with Git 1.8.5)
+def run_cmd(cmd): return Popen(cmd, stdout=PIPE).communicate()[0]
+def run_cmd_error(cmd): return Popen(cmd, stdout=PIPE, stderr=PIPE).communicate()[1]
+def get_date(commit_hash):
+ return run_cmd(["git", "show", "--quiet", "--pretty=format:%cd", commit_hash])
+def tag_exists(tag):
+ stderr = run_cmd_error(["git", "show", tag])
+ return "error" not in stderr
+
+# A type-safe representation of a commit
+class Commit:
+ def __init__(self, _hash, author, title, pr_number = None):
+ self._hash = _hash
+ self.author = author
+ self.title = title
+ self.pr_number = pr_number
+ def get_hash(self): return self._hash
+ def get_author(self): return self.author
+ def get_title(self): return self.title
+ def get_pr_number(self): return self.pr_number
+ def __str__(self):
+ closes_pr = "(Closes #%s)" % self.pr_number if self.pr_number else ""
+ return "%s %s %s %s" % (self._hash, self.author, self.title, closes_pr)
+
+# Return all commits that belong to the specified tag.
+#
+# Under the hood, this runs a `git log` on that tag and parses the fields
+# from the command output to construct a list of Commit objects. Note that
+# because certain fields reside in the commit description and cannot be parsed
+# through the Github API itself, we need to do some intelligent regex parsing
+# to extract those fields.
+#
+# This is written using Git 1.8.5.
+def get_commits(tag):
+ commit_start_marker = "|=== COMMIT START MARKER ===|"
+ commit_end_marker = "|=== COMMIT END MARKER ===|"
+ field_end_marker = "|=== COMMIT FIELD END MARKER ===|"
+ log_format =\
+ commit_start_marker + "%h" +\
+ field_end_marker + "%an" +\
+ field_end_marker + "%s" +\
+ commit_end_marker + "%b"
+ output = run_cmd(["git", "log", "--quiet", "--pretty=format:" + log_format, tag])
+ commits = []
+ raw_commits = [c for c in output.split(commit_start_marker) if c]
+ for commit in raw_commits:
+ if commit.count(commit_end_marker) != 1:
+ print "Commit end marker not found in commit: "
+ for line in commit.split("\n"): print line
+ sys.exit(1)
+ # Separate commit digest from the body
+ # From the digest we extract the hash, author and the title
+ # From the body, we extract the PR number and the github username
+ [commit_digest, commit_body] = commit.split(commit_end_marker)
+ if commit_digest.count(field_end_marker) != 2:
+ sys.exit("Unexpected format in commit: %s" % commit_digest)
+ [_hash, author, title] = commit_digest.split(field_end_marker)
+        # The PR number and github username are in the commit message
+ # itself and cannot be accessed through any Github API
+ pr_number = None
+ match = re.search("Closes #([0-9]+) from ([^/\\s]+)/", commit_body)
+ if match:
+ [pr_number, github_username] = match.groups()
+ # If the author name is not valid, use the github
+ # username so we can translate it properly later
+ if not is_valid_author(author):
+ author = github_username
+ # Guard against special characters
+ author = unidecode.unidecode(unicode(author, "UTF-8")).strip()
+ commit = Commit(_hash, author, title, pr_number)
+ commits.append(commit)
+ return commits
+
+# Maintain a mapping for translating issue types to contributions in the release notes
+# This serves an additional function of warning the user against unknown issue types
+# Note: This list is partially derived from this link:
+# https://issues.apache.org/jira/plugins/servlet/project-config/SPARK/issuetypes
+# Keep these in lower case
+known_issue_types = {
+ "bug": "bug fixes",
+ "build": "build fixes",
+ "dependency upgrade": "build fixes",
+ "improvement": "improvements",
+ "new feature": "new features",
+ "documentation": "documentation",
+ "test": "test",
+ "task": "improvement",
+ "sub-task": "improvement"
+}
+
+# Maintain a mapping for translating component names when creating the release notes
+# This serves an additional function of warning the user against unknown components
+# Note: This list is largely derived from this link:
+# https://issues.apache.org/jira/plugins/servlet/project-config/SPARK/components
+CORE_COMPONENT = "Core"
+known_components = {
+ "block manager": CORE_COMPONENT,
+ "build": CORE_COMPONENT,
+ "deploy": CORE_COMPONENT,
+ "documentation": CORE_COMPONENT,
+ "ec2": "EC2",
+ "examples": CORE_COMPONENT,
+ "graphx": "GraphX",
+ "input/output": CORE_COMPONENT,
+ "java api": "Java API",
+ "mesos": "Mesos",
+ "ml": "MLlib",
+ "mllib": "MLlib",
+ "project infra": "Project Infra",
+ "pyspark": "PySpark",
+ "shuffle": "Shuffle",
+ "spark core": CORE_COMPONENT,
+ "spark shell": CORE_COMPONENT,
+ "sql": "SQL",
+ "streaming": "Streaming",
+ "web ui": "Web UI",
+ "windows": "Windows",
+ "yarn": "YARN"
+}
+
+# Translate issue types using a format appropriate for writing contributions
+# If an unknown issue type is encountered, warn the user
+def translate_issue_type(issue_type, issue_id, warnings):
+ issue_type = issue_type.lower()
+ if issue_type in known_issue_types:
+ return known_issue_types[issue_type]
+ else:
+ warnings.append("Unknown issue type \"%s\" (see %s)" % (issue_type, issue_id))
+ return issue_type
+
+# Translate component names using a format appropriate for writing contributions
+# If an unknown component is encountered, warn the user
+def translate_component(component, commit_hash, warnings):
+ component = component.lower()
+ if component in known_components:
+ return known_components[component]
+ else:
+ warnings.append("Unknown component \"%s\" (see %s)" % (component, commit_hash))
+ return component
+
+# Parse components in the commit message
+# The returned components are already filtered and translated
+def find_components(commit, commit_hash):
+ components = re.findall("\[\w*\]", commit.lower())
+ components = [translate_component(c, commit_hash)\
+ for c in components if c in known_components]
+ return components
+
+# Join a list of strings in a human-readable manner
+# e.g. ["Juice"] -> "Juice"
+# e.g. ["Juice", "baby"] -> "Juice and baby"
+# e.g. ["Juice", "baby", "moon"] -> "Juice, baby, and moon"
+def nice_join(str_list):
+ str_list = list(str_list) # sometimes it's a set
+ if not str_list:
+ return ""
+ elif len(str_list) == 1:
+ return next(iter(str_list))
+ elif len(str_list) == 2:
+ return " and ".join(str_list)
+ else:
+ return ", ".join(str_list[:-1]) + ", and " + str_list[-1]
+
+# Return the full name of the specified user on Github
+# If the user doesn't exist, return None
+def get_github_name(author, github_client):
+ if github_client:
+ try:
+ return github_client.get_user(author).name
+ except GithubException as e:
+ # If this is not a "not found" exception
+ if e.status != 404:
+ raise e
+ return None
+
+# Return the full name of the specified user on JIRA
+# If the user doesn't exist, return None
+def get_jira_name(author, jira_client):
+ if jira_client:
+ try:
+ return jira_client.user(author).displayName
+ except JIRAError as e:
+ # If this is not a "not found" exception
+ if e.status_code != 404:
+ raise e
+ return None
+
+# Return whether the given name is in the form <First Name> <Last Name>
+def is_valid_author(author):
+ if not author: return False
+ return " " in author and not re.findall("[0-9]", author)
+
+# Capitalize the first letter of each word in the given author name
+def capitalize_author(author):
+ if not author: return None
+ words = author.split(" ")
+ words = [w[0].capitalize() + w[1:] for w in words if w]
+ return " ".join(words)
+
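A condensed sketch of the marker-based parsing that get_commits performs; it shells out to git log with a custom pretty format so that multi-line commit bodies cannot be mistaken for field separators. The tag name is illustrative, and PR-number extraction from the body is omitted:

    import subprocess

    START = "|=== COMMIT START MARKER ===|"
    FIELD = "|=== COMMIT FIELD END MARKER ===|"
    END = "|=== COMMIT END MARKER ===|"
    fmt = START + "%h" + FIELD + "%an" + FIELD + "%s" + END + "%b"

    out = subprocess.check_output(["git", "log", "--pretty=format:" + fmt, "v1.1.0"])
    for raw in out.decode("utf-8", "replace").split(START):
        if not raw:
            continue
        digest, _, body = raw.partition(END)
        commit_hash, author, title = digest.split(FIELD)
        print("%s %s %s" % (commit_hash, author, title))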
diff --git a/dev/create-release/translate-contributors.py b/dev/create-release/translate-contributors.py
new file mode 100755
index 0000000000000..86fa02d87b9a0
--- /dev/null
+++ b/dev/create-release/translate-contributors.py
@@ -0,0 +1,253 @@
+#!/usr/bin/env python
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script translates invalid authors in the contributors list generated
+# by generate-contributors.py. When the script encounters an author name that
+# is considered invalid, it searches Github and JIRA in an attempt to find
+# for replacements. This tool runs in two modes:
+#
+# (1) Interactive mode: For each invalid author name, this script presents
+# all candidate replacements to the user and awaits user response. In this
+# mode, the user may also input a custom name. This is the default.
+#
+# (2) Non-interactive mode: For each invalid author name, this script replaces
+# the name with the first valid candidate it can find. If there is none, it
+# uses the original name. This can be enabled through the --non-interactive flag.
+
+import os
+import sys
+
+from releaseutils import *
+
+# You must set the following before use!
+JIRA_API_BASE = os.environ.get("JIRA_API_BASE", "https://issues.apache.org/jira")
+JIRA_USERNAME = os.environ.get("JIRA_USERNAME", None)
+JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", None)
+GITHUB_API_TOKEN = os.environ.get("GITHUB_API_TOKEN", None)
+if not JIRA_USERNAME or not JIRA_PASSWORD:
+ sys.exit("Both JIRA_USERNAME and JIRA_PASSWORD must be set")
+if not GITHUB_API_TOKEN:
+ sys.exit("GITHUB_API_TOKEN must be set")
+
+# Write new contributors list to .final
+if not os.path.isfile(contributors_file_name):
+ print "Contributors file %s does not exist!" % contributors_file_name
+ print "Have you run ./generate-contributors.py yet?"
+ sys.exit(1)
+contributors_file = open(contributors_file_name, "r")
+warnings = []
+
+# In non-interactive mode, this script will choose the first replacement that is valid
+INTERACTIVE_MODE = True
+if len(sys.argv) > 1:
+ options = set(sys.argv[1:])
+ if "--non-interactive" in options:
+ INTERACTIVE_MODE = False
+if INTERACTIVE_MODE:
+ print "Running in interactive mode. To disable this, provide the --non-interactive flag."
+
+# Setup Github and JIRA clients
+jira_options = { "server": JIRA_API_BASE }
+jira_client = JIRA(options = jira_options, basic_auth = (JIRA_USERNAME, JIRA_PASSWORD))
+github_client = Github(GITHUB_API_TOKEN)
+
+# Load known author translations that are cached locally
+known_translations = {}
+known_translations_file_name = "known_translations"
+known_translations_file = open(known_translations_file_name, "r")
+for line in known_translations_file:
+ if line.startswith("#"): continue
+ [old_name, new_name] = line.strip("\n").split(" - ")
+ known_translations[old_name] = new_name
+known_translations_file.close()
+
+# Open again in case the user adds new mappings
+known_translations_file = open(known_translations_file_name, "a")
+
+# Generate candidates for the given author. This should only be called if the given author
+# name does not represent a full name as this operation is somewhat expensive. Under the
+# hood, it makes several calls to the Github and JIRA API servers to find the candidates.
+#
+# This returns a list of (candidate name, source) 2-tuples. E.g.
+# [
+# (NOT_FOUND, "No full name found for Github user andrewor14"),
+# ("Andrew Or", "Full name of JIRA user andrewor14"),
+# ("Andrew Orso", "Full name of SPARK-1444 assignee andrewor14"),
+# ("Andrew Ordall", "Full name of SPARK-1663 assignee andrewor14"),
+# (NOT_FOUND, "No assignee found for SPARK-1763")
+# ]
+NOT_FOUND = "Not found"
+def generate_candidates(author, issues):
+ candidates = []
+ # First check for full name of Github user
+ github_name = get_github_name(author, github_client)
+ if github_name:
+ candidates.append((github_name, "Full name of Github user %s" % author))
+ else:
+ candidates.append((NOT_FOUND, "No full name found for Github user %s" % author))
+ # Then do the same for JIRA user
+ jira_name = get_jira_name(author, jira_client)
+ if jira_name:
+ candidates.append((jira_name, "Full name of JIRA user %s" % author))
+ else:
+ candidates.append((NOT_FOUND, "No full name found for JIRA user %s" % author))
+ # Then do the same for the assignee of each of the associated JIRAs
+ # Note that a given issue may not have an assignee, or the assignee may not have a full name
+ for issue in issues:
+ try:
+ jira_issue = jira_client.issue(issue)
+ except JIRAError as e:
+ # Do not exit just because an issue is not found!
+ if e.status_code == 404:
+ warnings.append("Issue %s not found!" % issue)
+ continue
+ raise e
+ jira_assignee = jira_issue.fields.assignee
+ if jira_assignee:
+ user_name = jira_assignee.name
+ display_name = jira_assignee.displayName
+ if display_name:
+ candidates.append((display_name, "Full name of %s assignee %s" % (issue, user_name)))
+ else:
+          candidates.append((NOT_FOUND, "No full name found for %s assignee %s" % (issue, user_name)))
+ else:
+ candidates.append((NOT_FOUND, "No assignee found for %s" % issue))
+ # Guard against special characters in candidate names
+ # Note that the candidate name may already be in unicode (JIRA returns this)
+ for i, (candidate, source) in enumerate(candidates):
+ try:
+ candidate = unicode(candidate, "UTF-8")
+ except TypeError:
+ # already in unicode
+ pass
+ candidate = unidecode.unidecode(candidate).strip()
+ candidates[i] = (candidate, source)
+ return candidates
+
+# Translate each invalid author by searching for possible candidates from Github and JIRA
+# In interactive mode, this script presents the user with a list of choices and has the user
+# select from this list. Additionally, the user may also choose to enter a custom name.
+# In non-interactive mode, this script picks the first valid author name from the candidates
+# If no such name exists, the original name is used (without the JIRA numbers).
+print "\n========================== Translating contributor list =========================="
+lines = contributors_file.readlines()
+contributions = []
+for i, line in enumerate(lines):
+ temp_author = line.strip(" * ").split(" -- ")[0]
+ print "Processing author %s (%d/%d)" % (temp_author, i + 1, len(lines))
+ if not temp_author:
+ error_msg = " ERROR: Expected the following format \" * -- \"\n"
+ error_msg += " ERROR: Actual = %s" % line
+ print error_msg
+ warnings.append(error_msg)
+ contributions.append(line)
+ continue
+ author = temp_author.split("/")[0]
+ # Use the local copy of known translations where possible
+ if author in known_translations:
+ line = line.replace(temp_author, known_translations[author])
+ elif not is_valid_author(author):
+ new_author = author
+ issues = temp_author.split("/")[1:]
+ candidates = generate_candidates(author, issues)
+ # Print out potential replacement candidates along with the sources, e.g.
+ # [X] No full name found for Github user andrewor14
+ # [X] No assignee found for SPARK-1763
+ # [0] Andrew Or - Full name of JIRA user andrewor14
+ # [1] Andrew Orso - Full name of SPARK-1444 assignee andrewor14
+ # [2] Andrew Ordall - Full name of SPARK-1663 assignee andrewor14
+ # [3] andrewor14 - Raw Github username
+ # [4] Custom
+ candidate_names = []
+ bad_prompts = [] # Prompts that can't actually be selected; print these first.
+ good_prompts = [] # Prompts that contain valid choices
+ for candidate, source in candidates:
+ if candidate == NOT_FOUND:
+ bad_prompts.append(" [X] %s" % source)
+ else:
+ index = len(candidate_names)
+ candidate_names.append(candidate)
+ good_prompts.append(" [%d] %s - %s" % (index, candidate, source))
+ raw_index = len(candidate_names)
+ custom_index = len(candidate_names) + 1
+ for p in bad_prompts: print p
+ if bad_prompts: print " ---"
+ for p in good_prompts: print p
+ # In interactive mode, additionally provide "custom" option and await user response
+ if INTERACTIVE_MODE:
+ print " [%d] %s - Raw Github username" % (raw_index, author)
+ print " [%d] Custom" % custom_index
+ response = raw_input(" Your choice: ")
+ last_index = custom_index
+ while not response.isdigit() or int(response) > last_index:
+ response = raw_input(" Please enter an integer between 0 and %d: " % last_index)
+ response = int(response)
+ if response == custom_index:
+ new_author = raw_input(" Please type a custom name for this author: ")
+ elif response != raw_index:
+ new_author = candidate_names[response]
+ # In non-interactive mode, just pick the first candidate
+ else:
+ valid_candidate_names = [name for name, _ in candidates\
+ if is_valid_author(name) and name != NOT_FOUND]
+ if valid_candidate_names:
+ new_author = valid_candidate_names[0]
+ # Finally, capitalize the author and replace the original one with it
+ # If the final replacement is still invalid, log a warning
+ if is_valid_author(new_author):
+ new_author = capitalize_author(new_author)
+ else:
+ warnings.append("Unable to find a valid name %s for author %s" % (author, temp_author))
+ print " * Replacing %s with %s" % (author, new_author)
+ # If we are in interactive mode, prompt the user whether we want to remember this new mapping
+ if INTERACTIVE_MODE and\
+ author not in known_translations and\
+ yesOrNoPrompt(" Add mapping %s -> %s to known translations file?" % (author, new_author)):
+ known_translations_file.write("%s - %s\n" % (author, new_author))
+ known_translations_file.flush()
+ line = line.replace(temp_author, author)
+ contributions.append(line)
+print "==================================================================================\n"
+contributors_file.close()
+known_translations_file.close()
+
+# Sort the contributions before writing them to the new file.
+# Additionally, check if there are any duplicate author rows.
+# This could happen if the same user has both a valid full
+# name (e.g. Andrew Or) and an invalid one (andrewor14).
+# If so, warn the user about this at the end.
+contributions.sort()
+all_authors = set()
+new_contributors_file_name = contributors_file_name + ".final"
+new_contributors_file = open(new_contributors_file_name, "w")
+for line in contributions:
+ author = line.strip(" * ").split(" -- ")[0]
+ if author in all_authors:
+ warnings.append("Detected duplicate author name %s. Please merge these manually." % author)
+ all_authors.add(author)
+ new_contributors_file.write(line)
+new_contributors_file.close()
+
+print "Translated contributors list successfully written to %s!" % new_contributors_file_name
+
+# Log any warnings encountered in the process
+if warnings:
+ print "\n========== Warnings encountered while translating the contributor list ==========="
+ for w in warnings: print w
+ print "Please manually correct these in the final contributors list at %s." % new_contributors_file_name
+ print "==================================================================================\n"
+
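The non-interactive branch above simply takes the first plausible candidate; a tiny sketch of that selection, with an invented candidate list and the is_valid_author test from releaseutils.py inlined:

    NOT_FOUND = "Not found"
    candidates = [
        (NOT_FOUND, "No full name found for Github user andrewor14"),
        ("Andrew Or", "Full name of JIRA user andrewor14"),
        ("Andrew Orso", "Full name of SPARK-1444 assignee andrewor14"),
    ]

    def is_valid_author(author):       # inlined from releaseutils.py
        return bool(author) and " " in author and not any(c.isdigit() for c in author)

    valid = [name for name, _ in candidates
             if name != NOT_FOUND and is_valid_author(name)]
    new_author = valid[0] if valid else "andrewor14"   # fall back to the raw login
    print(new_author)                                  # -> "Andrew Or"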
diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py
index 02ac20984add9..dfa924d2aa0ba 100755
--- a/dev/merge_spark_pr.py
+++ b/dev/merge_spark_pr.py
@@ -214,15 +214,10 @@ def fix_version_from_branch(branch, versions):
return filter(lambda x: x.name.startswith(branch_ver), versions)[-1]
-def resolve_jira(title, merge_branches, comment):
+def resolve_jira_issue(merge_branches, comment, default_jira_id=""):
asf_jira = jira.client.JIRA({'server': JIRA_API_BASE},
basic_auth=(JIRA_USERNAME, JIRA_PASSWORD))
- default_jira_id = ""
- search = re.findall("SPARK-[0-9]{4,5}", title)
- if len(search) > 0:
- default_jira_id = search[0]
-
jira_id = raw_input("Enter a JIRA id [%s]: " % default_jira_id)
if jira_id == "":
jira_id = default_jira_id
@@ -280,6 +275,15 @@ def get_version_json(version_str):
print "Succesfully resolved %s with fixVersions=%s!" % (jira_id, fix_versions)
+def resolve_jira_issues(title, merge_branches, comment):
+ jira_ids = re.findall("SPARK-[0-9]{4,5}", title)
+
+ if len(jira_ids) == 0:
+ resolve_jira_issue(merge_branches, comment)
+ for jira_id in jira_ids:
+ resolve_jira_issue(merge_branches, comment, jira_id)
+
+
branches = get_json("%s/branches" % GITHUB_API_BASE)
branch_names = filter(lambda x: x.startswith("branch-"), [x['name'] for x in branches])
# Assumes branch names can be sorted lexicographically
@@ -338,7 +342,7 @@ def get_version_json(version_str):
if JIRA_USERNAME and JIRA_PASSWORD:
continue_maybe("Would you like to update an associated JIRA?")
jira_comment = "Issue resolved by pull request %s\n[%s/%s]" % (pr_num, GITHUB_BASE, pr_num)
- resolve_jira(title, merged_refs, jira_comment)
+ resolve_jira_issues(title, merged_refs, jira_comment)
else:
print "JIRA_USERNAME and JIRA_PASSWORD not set"
print "Exiting without trying to close the associated JIRA."
diff --git a/dev/mima b/dev/mima
index 40603166c21ae..bed5cd042634e 100755
--- a/dev/mima
+++ b/dev/mima
@@ -24,13 +24,13 @@ set -e
FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
cd "$FWDIR"
-echo -e "q\n" | sbt/sbt oldDeps/update
+echo -e "q\n" | build/sbt oldDeps/update
rm -f .generated-mima*
-# Generate Mima Ignore is called twice, first with latest built jars
+# Generate Mima Ignore is called twice, first with latest built jars
# on the classpath and then again with previous version jars on the classpath.
# Because of a bug in GenerateMIMAIgnore that when old jars are ahead on classpath
-# it did not process the new classes (which are in assembly jar).
+# it did not process the new classes (which are in assembly jar).
./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore
export SPARK_CLASSPATH="`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"`"
@@ -38,7 +38,7 @@ echo "SPARK_CLASSPATH=$SPARK_CLASSPATH"
./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore
-echo -e "q\n" | sbt/sbt mima-report-binary-issues | grep -v -e "info.*Resolving"
+echo -e "q\n" | build/sbt mima-report-binary-issues | grep -v -e "info.*Resolving"
ret_val=$?
if [ $ret_val != 0 ]; then
diff --git a/dev/run-tests b/dev/run-tests
index f47fcf66ff7e7..2257a566bb1bb 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -21,8 +21,10 @@
FWDIR="$(cd "`dirname $0`"/..; pwd)"
cd "$FWDIR"
-# Remove work directory
+# Clean up work directory and caches
rm -rf ./work
+rm -rf ~/.ivy2/local/org.apache.spark
+rm -rf ~/.ivy2/cache/org.apache.spark
source "$FWDIR/dev/run-tests-codes.sh"
@@ -59,17 +61,17 @@ export SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Pkinesis-asl"
{
if test -x "$JAVA_HOME/bin/java"; then
declare java_cmd="$JAVA_HOME/bin/java"
- else
+ else
declare java_cmd=java
fi
-
+
# We can't use sed -r -e due to OS X / BSD compatibility; hence, all the parentheses.
JAVA_VERSION=$(
$java_cmd -version 2>&1 \
| grep -e "^java version" --max-count=1 \
| sed "s/java version \"\(.*\)\.\(.*\)\.\(.*\)\"/\1\2/"
)
-
+
if [ "$JAVA_VERSION" -lt 18 ]; then
echo "[warn] Java 8 tests will not run because JDK version is < 1.8."
fi
@@ -79,7 +81,7 @@ export SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Pkinesis-asl"
# Partial solution for SPARK-1455.
if [ -n "$AMPLAB_JENKINS" ]; then
git fetch origin master:master
-
+
sql_diffs=$(
git diff --name-only master \
| grep -e "^sql/" -e "^bin/spark-sql" -e "^sbin/start-thriftserver.sh"
@@ -93,7 +95,7 @@ if [ -n "$AMPLAB_JENKINS" ]; then
if [ -n "$sql_diffs" ]; then
echo "[info] Detected changes in SQL. Will run Hive test suite."
_RUN_SQL_TESTS=true
-
+
if [ -z "$non_sql_diffs" ]; then
echo "[info] Detected no changes except in SQL. Will only run SQL tests."
_SQL_TESTS_ONLY=true
@@ -139,20 +141,28 @@ echo "========================================================================="
CURRENT_BLOCK=$BLOCK_BUILD
{
- # We always build with Hive because the PySpark Spark SQL tests need it.
- BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive"
-
- echo "[info] Building Spark with these arguments: $BUILD_MVN_PROFILE_ARGS"
# NOTE: echo "q" is needed because sbt on encountering a build file with failure
- #+ (either resolution or compilation) prompts the user for input either q, r, etc
- #+ to quit or retry. This echo is there to make it not block.
- # NOTE: Do not quote $BUILD_MVN_PROFILE_ARGS or else it will be interpreted as a
- #+ single argument!
+ # (either resolution or compilation) prompts the user for input either q, r, etc
+ # to quit or retry. This echo is there to make it not block.
+ # NOTE: Do not quote $BUILD_MVN_PROFILE_ARGS or else it will be interpreted as a
+ # single argument!
# QUESTION: Why doesn't 'yes "q"' work?
# QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work?
+ # First build with Hive 0.12.0 to ensure patches do not break the Hive 0.12.0 build
+ HIVE_12_BUILD_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver -Phive-0.12.0"
+ echo "[info] Compile with Hive 0.12.0"
echo -e "q\n" \
- | sbt/sbt $BUILD_MVN_PROFILE_ARGS clean package assembly/assembly \
+ | build/sbt $HIVE_12_BUILD_ARGS clean hive/compile hive-thriftserver/compile \
+ | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including"
+
+ # Then build with default Hive version (0.13.1) because tests are based on this version
+ echo "[info] Compile with Hive 0.13.1"
+ rm -rf lib_managed
+ echo "[info] Building Spark with these arguments: $SBT_MAVEN_PROFILES_ARGS"\
+ " -Phive -Phive-thriftserver"
+ echo -e "q\n" \
+ | build/sbt $SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver package assembly/assembly \
| grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including"
}
@@ -167,29 +177,29 @@ CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS
# If the Spark SQL tests are enabled, run the tests with the Hive profiles enabled.
# This must be a single argument, as it is.
if [ -n "$_RUN_SQL_TESTS" ]; then
- SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive"
+ SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver"
fi
-
+
if [ -n "$_SQL_TESTS_ONLY" ]; then
# This must be an array of individual arguments. Otherwise, having one long string
- #+ will be interpreted as a single test, which doesn't work.
- SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "hive-thriftserver/test")
+ # will be interpreted as a single test, which doesn't work.
+ SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "mllib/test")
else
SBT_MAVEN_TEST_ARGS=("test")
fi
-
+
echo "[info] Running Spark tests with these arguments: $SBT_MAVEN_PROFILES_ARGS ${SBT_MAVEN_TEST_ARGS[@]}"
-
+
# NOTE: echo "q" is needed because sbt on encountering a build file with failure
- #+ (either resolution or compilation) prompts the user for input either q, r, etc
- #+ to quit or retry. This echo is there to make it not block.
- # NOTE: Do not quote $SBT_MAVEN_PROFILES_ARGS or else it will be interpreted as a
- #+ single argument!
- #+ "${SBT_MAVEN_TEST_ARGS[@]}" is cool because it's an array.
+ # (either resolution or compilation) prompts the user for input either q, r, etc
+ # to quit or retry. This echo is there to make it not block.
+ # NOTE: Do not quote $SBT_MAVEN_PROFILES_ARGS or else it will be interpreted as a
+ # single argument!
+ # "${SBT_MAVEN_TEST_ARGS[@]}" is cool because it's an array.
# QUESTION: Why doesn't 'yes "q"' work?
# QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work?
echo -e "q\n" \
- | sbt/sbt $SBT_MAVEN_PROFILES_ARGS "${SBT_MAVEN_TEST_ARGS[@]}" \
+ | build/sbt $SBT_MAVEN_PROFILES_ARGS "${SBT_MAVEN_TEST_ARGS[@]}" \
| grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including"
}
@@ -204,7 +214,7 @@ CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS
echo ""
echo "========================================================================="
-echo "Detecting binary incompatibilites with MiMa"
+echo "Detecting binary incompatibilities with MiMa"
echo "========================================================================="
CURRENT_BLOCK=$BLOCK_MIMA
diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins
index 451f3b771cc76..6a849e4f77207 100755
--- a/dev/run-tests-jenkins
+++ b/dev/run-tests-jenkins
@@ -53,9 +53,9 @@ function post_message () {
local message=$1
local data="{\"body\": \"$message\"}"
local HTTP_CODE_HEADER="HTTP Response Code: "
-
+
echo "Attempting to post to Github..."
-
+
local curl_output=$(
curl `#--dump-header -` \
--silent \
@@ -75,12 +75,12 @@ function post_message () {
echo " > data: ${data}" >&2
# exit $curl_status
fi
-
+
local api_response=$(
echo "${curl_output}" \
| grep -v -e "^${HTTP_CODE_HEADER}"
)
-
+
local http_code=$(
echo "${curl_output}" \
| grep -e "^${HTTP_CODE_HEADER}" \
@@ -92,12 +92,45 @@ function post_message () {
echo " > api_response: ${api_response}" >&2
echo " > data: ${data}" >&2
fi
-
+
if [ "$curl_status" -eq 0 ] && [ "$http_code" -eq "201" ]; then
echo " > Post successful."
fi
}
+function send_archived_logs () {
+ echo "Archiving unit tests logs..."
+
+ local log_files=$(
+ find .\
+ -name "unit-tests.log" -o\
+ -path "./sql/hive/target/HiveCompatibilitySuite.failed" -o\
+ -path "./sql/hive/target/HiveCompatibilitySuite.hiveFailed" -o\
+ -path "./sql/hive/target/HiveCompatibilitySuite.wrong"
+ )
+
+ if [ -z "$log_files" ]; then
+ echo "> No log files found." >&2
+ else
+ local log_archive="unit-tests-logs.tar.gz"
+ echo "$log_files" | xargs tar czf ${log_archive}
+
+ local jenkins_build_dir=${JENKINS_HOME}/jobs/${JOB_NAME}/builds/${BUILD_NUMBER}
+ local scp_output=$(scp ${log_archive} amp-jenkins-master:${jenkins_build_dir}/${log_archive})
+ local scp_status="$?"
+
+ if [ "$scp_status" -ne 0 ]; then
+ echo "Failed to send archived unit tests logs to Jenkins master." >&2
+ echo "> scp_status: ${scp_status}" >&2
+ echo "> scp_output: ${scp_output}" >&2
+ else
+ echo "> Send successful."
+ fi
+
+ rm -f ${log_archive}
+ fi
+}
+
# We diff master...$ghprbActualCommit because that gets us changes introduced in the PR
#+ and not anything else added to master since the PR was branched.
@@ -109,7 +142,7 @@ function post_message () {
else
merge_note=" * This patch merges cleanly."
fi
-
+
source_files=$(
git diff master...$ghprbActualCommit --name-only `# diff patch against master from branch point` \
| grep -v -e "\/test" `# ignore files in test directories` \
@@ -144,12 +177,12 @@ function post_message () {
# post start message
{
start_message="\
- [QA tests have started](${BUILD_URL}consoleFull) for \
+ [Test build ${BUILD_DISPLAY_NAME} has started](${BUILD_URL}consoleFull) for \
PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})."
-
+
start_message="${start_message}\n${merge_note}"
# start_message="${start_message}\n${public_classes_note}"
-
+
post_message "$start_message"
}
@@ -159,7 +192,7 @@ function post_message () {
test_result="$?"
if [ "$test_result" -eq "124" ]; then
- fail_message="**[Tests timed out](${BUILD_URL}consoleFull)** \
+ fail_message="**[Test build ${BUILD_DISPLAY_NAME} timed out](${BUILD_URL}consoleFull)** \
for PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL}) \
after a configured wait of \`${TESTS_TIMEOUT}\`."
@@ -187,15 +220,17 @@ function post_message () {
else
failing_test="some tests"
fi
-
+
test_result_note=" * This patch **fails $failing_test**."
fi
+
+ send_archived_logs
}
# post end message
{
result_message="\
- [QA tests have finished](${BUILD_URL}consoleFull) for \
+ [Test build ${BUILD_DISPLAY_NAME} has finished](${BUILD_URL}consoleFull) for \
PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})."
result_message="${result_message}\n${test_result_note}"
diff --git a/dev/scalastyle b/dev/scalastyle
index c3b356bcb3c06..86919227ed1ab 100755
--- a/dev/scalastyle
+++ b/dev/scalastyle
@@ -17,15 +17,12 @@
# limitations under the License.
#
-echo -e "q\n" | sbt/sbt -Phive scalastyle > scalastyle.txt
-# Check style with YARN alpha built too
-echo -e "q\n" | sbt/sbt -Pyarn-alpha -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \
- >> scalastyle.txt
+echo -e "q\n" | build/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt
# Check style with YARN built too
-echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 yarn/scalastyle \
+echo -e "q\n" | build/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 scalastyle \
>> scalastyle.txt
-ERRORS=$(cat scalastyle.txt | grep -e "\<error\>")
+ERRORS=$(cat scalastyle.txt | awk '{if($1~/error/)print}')
rm scalastyle.txt
if test ! -z "$ERRORS"; then
diff --git a/docs/README.md b/docs/README.md
index d2d58e435d4c4..8a54724c4beae 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -21,7 +21,7 @@ read those text files directly if you want. Start with index.md.
The markdown code can be compiled to HTML using the [Jekyll tool](http://jekyllrb.com).
`Jekyll` and a few dependencies must be installed for this to work. We recommend
-installing via the Ruby Gem dependency manager. Since the exact HTML output
+installing via the Ruby Gem dependency manager. Since the exact HTML output
varies between versions of Jekyll and its dependencies, we list specific versions here
in some cases:
@@ -43,7 +43,7 @@ You can modify the default Jekyll build as follows:
## Pygments
We also use pygments (http://pygments.org) for syntax highlighting in documentation markdown pages,
-so you will also need to install that (it requires Python) by running `sudo easy_install Pygments`.
+so you will also need to install that (it requires Python) by running `sudo pip install Pygments`.
To mark a block of code in your markdown to be syntax highlighted by jekyll during the compile
phase, use the following syntax:
@@ -53,9 +53,14 @@ phase, use the following sytax:
// supported languages too.
{% endhighlight %}
+## Sphinx
+
+We use Sphinx to generate Python API docs, so you will need to install it by running
+`sudo pip install sphinx`.
+
## API Docs (Scaladoc and Sphinx)
-You can build just the Spark scaladoc by running `sbt/sbt doc` from the SPARK_PROJECT_ROOT directory.
+You can build just the Spark scaladoc by running `build/sbt doc` from the SPARK_PROJECT_ROOT directory.
Similarly, you can build just the PySpark docs by running `make html` from the
SPARK_PROJECT_ROOT/python/docs directory. Documentation is only generated for classes that are listed as
@@ -63,7 +68,7 @@ public in `__init__.py`.
When you run `jekyll` in the `docs` directory, it will also copy over the scaladoc for the various
Spark subprojects into the `docs` directory (and then also into the `_site` directory). We use a
-jekyll plugin to run `sbt/sbt doc` before building the site so if you haven't run it (recently) it
+jekyll plugin to run `build/sbt doc` before building the site so if you haven't run it (recently) it
may take some time as it generates all of the scaladoc. The jekyll plugin also generates the
PySpark docs [Sphinx](http://sphinx-doc.org/).
diff --git a/docs/_config.yml b/docs/_config.yml
index f4bf242ac191b..e2db274e1f619 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -11,12 +11,12 @@ kramdown:
include:
- _static
-# These allow the documentation to be updated with nerw releases
+# These allow the documentation to be updated with newer releases
# of Spark, Scala, and Mesos.
-SPARK_VERSION: 1.0.0-SNAPSHOT
-SPARK_VERSION_SHORT: 1.0.0
+SPARK_VERSION: 1.3.0-SNAPSHOT
+SPARK_VERSION_SHORT: 1.3.0
SCALA_BINARY_VERSION: "2.10"
SCALA_VERSION: "2.10.4"
-MESOS_VERSION: 0.18.1
+MESOS_VERSION: 0.21.0
SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK
SPARK_GITHUB_URL: https://github.com/apache/spark
diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html
index 627ed37de4a9c..8841f7675d35e 100755
--- a/docs/_layouts/global.html
+++ b/docs/_layouts/global.html
@@ -33,7 +33,7 @@
diff --git a/pom.xml b/pom.xml
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-shade-plugin</artifactId>
+        <version>2.2</version>
+        <configuration>
+          <shadedArtifactAttached>false</shadedArtifactAttached>
+          <artifactSet>
+            <includes>
+              <include>org.spark-project.spark:unused</include>
+
+              <include>org.eclipse.jetty:jetty-io</include>
+              <include>org.eclipse.jetty:jetty-http</include>
+              <include>org.eclipse.jetty:jetty-continuation</include>
+              <include>org.eclipse.jetty:jetty-servlet</include>
+              <include>org.eclipse.jetty:jetty-plus</include>
+              <include>org.eclipse.jetty:jetty-security</include>
+              <include>org.eclipse.jetty:jetty-util</include>
+              <include>org.eclipse.jetty:jetty-server</include>
+              <include>com.google.guava:guava</include>
+            </includes>
+          </artifactSet>
+          <relocations>
+            <relocation>
+              <pattern>org.eclipse.jetty</pattern>
+              <shadedPattern>org.spark-project.jetty</shadedPattern>
+              <includes>
+                <include>org.eclipse.jetty.**</include>
+              </includes>
+            </relocation>
+            <relocation>
+              <pattern>com.google.common</pattern>
+              <shadedPattern>org.spark-project.guava</shadedPattern>
+              <excludes>
+                <exclude>com/google/common/base/Absent*</exclude>
+                <exclude>com/google/common/base/Function</exclude>
+                <exclude>com/google/common/base/Optional*</exclude>
+                <exclude>com/google/common/base/Present*</exclude>
+                <exclude>com/google/common/base/Supplier</exclude>
+              </excludes>
+            </relocation>
+          </relocations>
+        </configuration>
         <executions>
           <execution>
             <phase>package</phase>
             <goals>
               <goal>shade</goal>
             </goals>
-            <configuration>
-              <shadedArtifactAttached>false</shadedArtifactAttached>
-              <artifactSet>
-                <includes>
-                  <include>org.spark-project.spark:unused</include>
-                </includes>
-              </artifactSet>
-            </configuration>
           </execution>
         </executions>
@@ -1089,6 +1428,15 @@
+
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-surefire-plugin</artifactId>
+      </plugin>
+      <plugin>
+        <groupId>org.scalatest</groupId>
+        <artifactId>scalatest-maven-plugin</artifactId>
+      </plugin>
@@ -1150,8 +1498,31 @@
+
+    <profile>
+      <id>doclint-java8-disable</id>
+      <activation>
+        <jdk>[1.8,)</jdk>
+      </activation>
+
+      <build>
+        <plugins>
+          <plugin>
+            <groupId>org.apache.maven.plugins</groupId>
+            <artifactId>maven-javadoc-plugin</artifactId>
+            <configuration>
+              <additionalparam>-Xdoclint:all -Xdoclint:-missing</additionalparam>
+            </configuration>
+          </plugin>
+        </plugins>
+      </build>
+    </profile>
+
     <profile>
       <id>hadoop-0.23</id>
@@ -1171,6 +1542,7 @@
         <hadoop.version>2.2.0</hadoop.version>
         <protobuf.version>2.5.0</protobuf.version>
+        <hbase.version>0.98.7-hadoop2</hbase.version>
         <avro.mapred.classifier>hadoop2</avro.mapred.classifier>
@@ -1181,6 +1553,8 @@
         <hadoop.version>2.3.0</hadoop.version>
         <protobuf.version>2.5.0</protobuf.version>
         <jets3t.version>0.9.0</jets3t.version>
+        <hbase.version>0.98.7-hadoop2</hbase.version>
+        <commons.math3.version>3.1.1</commons.math3.version>
         <avro.mapred.classifier>hadoop2</avro.mapred.classifier>
@@ -1191,32 +1565,25 @@
         <hadoop.version>2.4.0</hadoop.version>
         <protobuf.version>2.5.0</protobuf.version>
         <jets3t.version>0.9.0</jets3t.version>
+        <hbase.version>0.98.7-hadoop2</hbase.version>
+        <commons.math3.version>3.1.1</commons.math3.version>
         <avro.mapred.classifier>hadoop2</avro.mapred.classifier>
       </properties>
     </profile>
-    <profile>
-      <id>yarn-alpha</id>
-      <modules>
-        <module>yarn</module>
-      </modules>
-    </profile>
-
     <profile>
       <id>yarn</id>
       <modules>
         <module>yarn</module>
+        <module>network/yarn</module>
       </modules>
     </profile>
     <profile>
       <id>mapr3</id>
-      <activation>
-        <activeByDefault>false</activeByDefault>
-      </activation>
       <properties>
-        <hadoop.version>1.0.3-mapr-3.0.3</hadoop.version>
-        <yarn.version>2.3.0-mapr-4.0.0-FCS</yarn.version>
+        <hadoop.version>2.4.1-mapr-1408</hadoop.version>
+        <yarn.version>2.4.1-mapr-1408</yarn.version>
         <hbase.version>0.94.17-mapr-1405</hbase.version>
         <zookeeper.version>3.4.5-mapr-1406</zookeeper.version>
@@ -1224,12 +1591,9 @@
       <id>mapr4</id>
-      <activation>
-        <activeByDefault>false</activeByDefault>
-      </activation>
       <properties>
-        <hadoop.version>2.3.0-mapr-4.0.0-FCS</hadoop.version>
-        <yarn.version>2.3.0-mapr-4.0.0-FCS</yarn.version>
+        <hadoop.version>2.4.1-mapr-1408</hadoop.version>
+        <yarn.version>2.4.1-mapr-1408</yarn.version>
         <hbase.version>0.94.17-mapr-1405-4.0.0-FCS</hbase.version>
         <zookeeper.version>3.4.5-mapr-1406</zookeeper.version>
@@ -1253,66 +1617,78 @@
-    <profile>
-      <id>hadoop-provided</id>
-      <activation>
-        <activeByDefault>false</activeByDefault>
-      </activation>
-      <dependencies>
-        <dependency>
-          <groupId>org.apache.hadoop</groupId>
-          <artifactId>hadoop-client</artifactId>
-          <scope>provided</scope>
-        </dependency>
-        <dependency>
-          <groupId>org.apache.hadoop</groupId>
-          <artifactId>hadoop-yarn-api</artifactId>
-          <scope>provided</scope>
-        </dependency>
-        <dependency>
-          <groupId>org.apache.hadoop</groupId>
-          <artifactId>hadoop-yarn-common</artifactId>
-          <scope>provided</scope>
-        </dependency>
-        <dependency>
-          <groupId>org.apache.hadoop</groupId>
-          <artifactId>hadoop-yarn-server-web-proxy</artifactId>
-          <scope>provided</scope>
-        </dependency>
-        <dependency>
-          <groupId>org.apache.hadoop</groupId>
-          <artifactId>hadoop-yarn-client</artifactId>
-          <scope>provided</scope>
-        </dependency>
-        <dependency>
-          <groupId>org.apache.avro</groupId>
-          <artifactId>avro</artifactId>
-          <scope>provided</scope>
-        </dependency>
-        <dependency>
-          <groupId>org.apache.avro</groupId>
-          <artifactId>avro-ipc</artifactId>
-          <scope>provided</scope>
-        </dependency>
-        <dependency>
-          <groupId>org.apache.zookeeper</groupId>
-          <artifactId>zookeeper</artifactId>
-          <version>${zookeeper.version}</version>
-          <scope>provided</scope>
-        </dependency>
-      </dependencies>
-    </profile>
+    <profile>
+      <id>hive-thriftserver</id>
+      <modules>
+        <module>sql/hive-thriftserver</module>
+      </modules>
+    </profile>
+
+    <profile>
+      <id>hive-0.12.0</id>
+      <properties>
+        <hive.version>0.12.0-protobuf-2.5</hive.version>
+        <hive.version.short>0.12.0</hive.version.short>
+        <derby.version>10.4.2.0</derby.version>
+      </properties>
+    </profile>
+
+    <profile>
+      <id>hive-0.13.1</id>
+      <properties>
+        <hive.version>0.13.1a</hive.version>
+        <hive.version.short>0.13.1</hive.version.short>
+        <derby.version>10.10.1.1</derby.version>
+      </properties>
+    </profile>

     <profile>
-      <id>hive</id>
+      <id>scala-2.10</id>
       <activation>
-        <activeByDefault>false</activeByDefault>
+        <property><name>!scala-2.11</name></property>
       </activation>
+      <properties>
+        <scala.version>2.10.4</scala.version>
+        <scala.binary.version>2.10</scala.binary.version>
+        <jline.version>${scala.version}</jline.version>
+        <jline.groupid>org.scala-lang</jline.groupid>
+      </properties>
       <modules>
-        <module>sql/hive-thriftserver</module>
+        <module>external/kafka</module>
+        <module>external/kafka-assembly</module>
       </modules>
     </profile>
+
+    <profile>
+      <id>scala-2.11</id>
+      <activation>
+        <property><name>scala-2.11</name></property>
+      </activation>
+      <properties>
+        <scala.version>2.11.2</scala.version>
+        <scala.binary.version>2.11</scala.binary.version>
+        <jline.version>2.12</jline.version>
+        <jline.groupid>jline</jline.groupid>
+      </properties>
+    </profile>
+
+    <profile>
+      <id>flume-provided</id>
+    </profile>
+
+    <profile>
+      <id>hadoop-provided</id>
+    </profile>
+
+    <profile>
+      <id>hbase-provided</id>
+    </profile>
+
+    <profile>
+      <id>hive-provided</id>
+    </profile>
+
+    <profile>
+      <id>parquet-provided</id>
+    </profile>
diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala
index d919b18e09855..f0cbf4e57b8c5 100644
--- a/project/MimaBuild.scala
+++ b/project/MimaBuild.scala
@@ -30,7 +30,7 @@ object MimaBuild {
def excludeMember(fullName: String) = Seq(
ProblemFilters.exclude[MissingMethodProblem](fullName),
- // Sometimes excluded methods have default arguments and
+ // Sometimes excluded methods have default arguments and
// they are translated into public methods/fields($default$) in generated
// bytecode. It is not possible to exhaustively list everything.
// But this should be okay.
@@ -91,9 +91,9 @@ object MimaBuild {
def mimaSettings(sparkHome: File, projectRef: ProjectRef) = {
val organization = "org.apache.spark"
- val previousSparkVersion = "1.1.0"
+ val previousSparkVersion = "1.2.0"
val fullId = "spark-" + projectRef.project + "_2.10"
- mimaDefaultSettings ++
+ mimaDefaultSettings ++
Seq(previousArtifact := Some(organization % fullId % previousSparkVersion),
binaryIssueFilters ++= ignoredABIProblems(sparkHome, version.value))
}
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index c58666af84f24..b17532c1d814c 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -33,6 +33,117 @@ import com.typesafe.tools.mima.core._
object MimaExcludes {
def excludes(version: String) =
version match {
+ case v if v.startsWith("1.3") =>
+ Seq(
+ MimaBuild.excludeSparkPackage("deploy"),
+ // These are needed if checking against the sbt build, since they are part of
+ // the maven-generated artifacts in the 1.2 build.
+ MimaBuild.excludeSparkPackage("unused"),
+ ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional")
+ ) ++ Seq(
+ // SPARK-2321
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.SparkStageInfoImpl.this"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.SparkStageInfo.submissionTime")
+ ) ++ Seq(
+ // SPARK-4614
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.Matrices.randn"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.Matrices.rand")
+ ) ++ Seq(
+ // SPARK-5321
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.SparseMatrix.transposeMultiply"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.Matrix.transpose"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.DenseMatrix.transposeMultiply"),
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix." +
+ "org$apache$spark$mllib$linalg$Matrix$_setter_$isTransposed_="),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.Matrix.isTransposed"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.Matrix.foreachActive")
+ ) ++ Seq(
+ // SPARK-5540
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.recommendation.ALS.solveLeastSquares"),
+ // SPARK-5536
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateFeatures"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$^dateBlock")
+ ) ++ Seq(
+ // SPARK-3325
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.streaming.api.java.JavaDStreamLike.print"),
+ // SPARK-2757
+ ProblemFilters.exclude[IncompatibleResultTypeProblem](
+ "org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler." +
+ "removeAndGetProcessor")
+ ) ++ Seq(
+ // SPARK-5123 (SparkSQL data type change) - alpha component only
+ ProblemFilters.exclude[IncompatibleResultTypeProblem](
+ "org.apache.spark.ml.feature.HashingTF.outputDataType"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem](
+ "org.apache.spark.ml.feature.Tokenizer.outputDataType"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem](
+ "org.apache.spark.ml.feature.Tokenizer.validateInputType"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem](
+ "org.apache.spark.ml.classification.LogisticRegressionModel.validateAndTransformSchema"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem](
+ "org.apache.spark.ml.classification.LogisticRegression.validateAndTransformSchema")
+ ) ++ Seq(
+ // SPARK-4014
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.TaskContext.taskAttemptId"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.TaskContext.attemptNumber")
+ ) ++ Seq(
+ // SPARK-5166 Spark SQL API stabilization
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"),
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Transformer.transform"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"),
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Estimator.fit"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Evaluator.evaluate"),
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Evaluator.evaluate"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.transform"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.fit"),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate")
+ ) ++ Seq(
+ // SPARK-5270
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.api.java.JavaRDDLike.isEmpty")
+ ) ++ Seq(
+ // SPARK-5430
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.api.java.JavaRDDLike.treeReduce"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.api.java.JavaRDDLike.treeAggregate")
+ ) ++ Seq(
+ // SPARK-5297 Java FileStream do not work with custom key/values
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.streaming.api.java.JavaStreamingContext.fileStream")
+ ) ++ Seq(
+ // SPARK-5315 Spark Streaming Java API returns Scala DStream
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow")
+ ) ++ Seq(
+ // SPARK-5461 Graph should have isCheckpointed, getCheckpointFiles methods
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.graphx.Graph.getCheckpointFiles"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.graphx.Graph.isCheckpointed")
+ )
+
case v if v.startsWith("1.2") =>
Seq(
MimaBuild.excludeSparkPackage("deploy"),
@@ -51,9 +162,16 @@ object MimaExcludes {
// MapStatus should be private[spark]
ProblemFilters.exclude[IncompatibleTemplateDefProblem](
"org.apache.spark.scheduler.MapStatus"),
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.network.netty.PathResolver"),
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.network.netty.client.BlockClientListener"),
+
// TaskContext was promoted to Abstract class
ProblemFilters.exclude[AbstractClassProblem](
- "org.apache.spark.TaskContext")
+ "org.apache.spark.TaskContext"),
+ ProblemFilters.exclude[IncompatibleTemplateDefProblem](
+ "org.apache.spark.util.collection.SortDataFormat")
) ++ Seq(
// Adding new methods to the JavaRDDLike trait:
ProblemFilters.exclude[MissingMethodProblem](
@@ -66,6 +184,22 @@ object MimaExcludes {
"org.apache.spark.api.java.JavaRDDLike.foreachAsync"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.api.java.JavaRDDLike.collectAsync")
+ ) ++ Seq(
+ // SPARK-3822
+ ProblemFilters.exclude[IncompatibleResultTypeProblem](
+ "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler")
+ ) ++ Seq(
+ // SPARK-1209
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil"),
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.hadoop.mapred.SparkHadoopMapRedUtil"),
+ ProblemFilters.exclude[MissingTypesProblem](
+ "org.apache.spark.rdd.PairRDDFunctions")
+ ) ++ Seq(
+ // SPARK-4062
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.streaming.kafka.KafkaReceiver#MessageHandler.this")
)
case v if v.startsWith("1.1") =>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 01a5b20e7c51d..93698efe84252 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -15,6 +15,8 @@
* limitations under the License.
*/
+import java.io.File
+
import scala.util.Properties
import scala.collection.JavaConversions._
@@ -22,8 +24,8 @@ import sbt._
import sbt.Classpaths.publishTask
import sbt.Keys._
import sbtunidoc.Plugin.genjavadocSettings
-import org.scalastyle.sbt.ScalastylePlugin.{Settings => ScalaStyleSettings}
-import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys}
+import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion
+import com.typesafe.sbt.pom.{loadEffectivePom, PomBuild, SbtPomKeys}
import net.virtualvoid.sbt.graph.Plugin.graphSettings
object BuildCommons {
@@ -31,18 +33,20 @@ object BuildCommons {
private val buildLocation = file(".").getAbsoluteFile.getParentFile
val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl,
- sql, streaming, streamingFlumeSink, streamingFlume, streamingKafka, streamingMqtt,
- streamingTwitter, streamingZeromq) =
+ sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka,
+ streamingMqtt, streamingTwitter, streamingZeromq) =
Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl",
- "sql", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka",
- "streaming-mqtt", "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _))
+ "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink",
+ "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter",
+ "streaming-zeromq").map(ProjectRef(buildLocation, _))
- val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) =
- Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl", "kinesis-asl")
- .map(ProjectRef(buildLocation, _))
+ val optionallyEnabledProjects@Seq(yarn, yarnStable, java8Tests, sparkGangliaLgpl,
+ sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl",
+ "kinesis-asl").map(ProjectRef(buildLocation, _))
- val assemblyProjects@Seq(assembly, examples) = Seq("assembly", "examples")
- .map(ProjectRef(buildLocation, _))
+ val assemblyProjects@Seq(assembly, examples, networkYarn, streamingKafkaAssembly) =
+ Seq("assembly", "examples", "network-yarn", "streaming-kafka-assembly")
+ .map(ProjectRef(buildLocation, _))
val tools = ProjectRef(buildLocation, "tools")
// Root project.
@@ -67,8 +71,8 @@ object SparkBuild extends PomBuild {
profiles ++= Seq("spark-ganglia-lgpl")
}
if (Properties.envOrNone("SPARK_HIVE").isDefined) {
- println("NOTE: SPARK_HIVE is deprecated, please use -Phive flag.")
- profiles ++= Seq("hive")
+ println("NOTE: SPARK_HIVE is deprecated, please use -Phive and -Phive-thriftserver flags.")
+ profiles ++= Seq("hive", "hive-thriftserver")
}
Properties.envOrNone("SPARK_HADOOP_VERSION") match {
case Some(v) =>
@@ -78,25 +82,29 @@ object SparkBuild extends PomBuild {
case None =>
}
if (Properties.envOrNone("SPARK_YARN").isDefined) {
- if(isAlphaYarn) {
- println("NOTE: SPARK_YARN is deprecated, please use -Pyarn-alpha flag.")
- profiles ++= Seq("yarn-alpha")
- }
- else {
- println("NOTE: SPARK_YARN is deprecated, please use -Pyarn flag.")
- profiles ++= Seq("yarn")
- }
+ println("NOTE: SPARK_YARN is deprecated, please use -Pyarn flag.")
+ profiles ++= Seq("yarn")
}
profiles
}
- override val profiles = Properties.envOrNone("SBT_MAVEN_PROFILES") match {
+ override val profiles = {
+ val profiles = Properties.envOrNone("SBT_MAVEN_PROFILES") match {
case None => backwardCompatibility
case Some(v) =>
if (backwardCompatibility.nonEmpty)
println("Note: We ignore environment variables, when use of profile is detected in " +
"conjunction with environment variable.")
v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq
+ }
+
+ if (System.getProperty("scala-2.11") == "") {
+ // To activate the scala-2.11 profile, replace the empty property value with a non-empty one,
+ // in the same way as Maven, which treats -Dname as -Dname=true before it executes the build.
+ // see: https://github.com/apache/maven/blob/maven-3.0.4/maven-embedder/src/main/java/org/apache/maven/cli/MavenCli.java#L1082
+ System.setProperty("scala-2.11", "true")
+ }
+ profiles
}
Properties.envOrNone("SBT_MAVEN_PROPERTIES") match {
@@ -110,12 +118,13 @@ object SparkBuild extends PomBuild {
lazy val MavenCompile = config("m2r") extend(Compile)
lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy")
- lazy val sharedSettings = graphSettings ++ ScalaStyleSettings ++ genjavadocSettings ++ Seq (
+ lazy val sharedSettings = graphSettings ++ genjavadocSettings ++ Seq (
javaHome := Properties.envOrNone("JAVA_HOME").map(file),
incOptions := incOptions.value.withNameHashing(true),
retrieveManaged := true,
retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]",
publishMavenStyle := true,
+ unidocGenjavadocVersion := "0.8",
resolvers += Resolver.mavenLocal,
otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))),
@@ -124,7 +133,12 @@ object SparkBuild extends PomBuild {
},
publishMavenStyle in MavenCompile := true,
publishLocal in MavenCompile <<= publishTask(publishLocalConfiguration in MavenCompile, deliverLocal),
- publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn
+ publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn,
+
+ javacOptions in (Compile, doc) ++= {
+ val Array(major, minor, _) = System.getProperty("java.version").split("\\.", 3)
+ if (major.toInt >= 1 && minor.toInt >= 8) Seq("-Xdoclint:all", "-Xdoclint:-missing") else Seq.empty
+ }
)
def enable(settings: Seq[Setting[_]])(projectRef: ProjectRef) = {
@@ -134,14 +148,17 @@ object SparkBuild extends PomBuild {
// Note ordering of these settings matter.
/* Enable shared settings on all projects */
- (allProjects ++ optionallyEnabledProjects ++ assemblyProjects).foreach(enable(sharedSettings))
+ (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ Seq(spark, tools))
+ .foreach(enable(sharedSettings ++ ExludedDependencies.settings))
/* Enable tests settings for all projects except examples, assembly and tools */
(allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings))
// TODO: Add Sql to mima checks
allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl,
- streamingFlumeSink).contains(x)).foreach(x => enable(MimaBuild.mimaSettings(sparkHome, x))(x))
+ networkCommon, networkShuffle, networkYarn).contains(x)).foreach {
+ x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)
+ }
/* Enable Assembly for all assembly projects */
assemblyProjects.foreach(enable(Assembly.settings))
@@ -174,6 +191,16 @@ object Flume {
lazy val settings = sbtavro.SbtAvro.avroSettings
}
+/**
+ This excludes library dependencies in sbt that are specified in Maven but are
+ not needed by the sbt build.
+ */
+object ExludedDependencies {
+ lazy val settings = Seq(
+ libraryDependencies ~= { libs => libs.filterNot(_.name == "groovy-all") }
+ )
+}
+
/**
* Following project only exists to pull previous artifacts of Spark for generating
* Mima ignores. For more information see: SPARK 2071
@@ -184,7 +211,7 @@ object OldDeps {
def versionArtifact(id: String): Option[sbt.ModuleID] = {
val fullId = id + "_2.10"
- Some("org.apache.spark" % fullId % "1.1.0")
+ Some("org.apache.spark" % fullId % "1.2.0")
}
def oldDepsSettings() = Defaults.coreDefaultSettings ++ Seq(
@@ -217,10 +244,11 @@ object SQL {
|import org.apache.spark.sql.catalyst.expressions._
|import org.apache.spark.sql.catalyst.plans.logical._
|import org.apache.spark.sql.catalyst.rules._
- |import org.apache.spark.sql.catalyst.types._
|import org.apache.spark.sql.catalyst.util._
+ |import org.apache.spark.sql.Dsl._
|import org.apache.spark.sql.execution
|import org.apache.spark.sql.test.TestSQLContext._
+ |import org.apache.spark.sql.types._
|import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin,
cleanupCommands in console := "sparkContext.stop()"
)
@@ -230,6 +258,8 @@ object Hive {
lazy val settings = Seq(
javaOptions += "-XX:MaxPermSize=1g",
+ // Specially disable assertions since some Hive tests fail them
+ javaOptions in Test := (javaOptions in Test).value.filterNot(_ == "-ea"),
// Multiple queries rely on the TestHive singleton. See comments there for more details.
parallelExecution in Test := false,
// Supporting all SerDes requires us to depend on deprecated APIs, so we turn off the warnings
@@ -245,13 +275,17 @@ object Hive {
|import org.apache.spark.sql.catalyst.expressions._
|import org.apache.spark.sql.catalyst.plans.logical._
|import org.apache.spark.sql.catalyst.rules._
- |import org.apache.spark.sql.catalyst.types._
|import org.apache.spark.sql.catalyst.util._
|import org.apache.spark.sql.execution
|import org.apache.spark.sql.hive._
|import org.apache.spark.sql.hive.test.TestHive._
+ |import org.apache.spark.sql.types._
|import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin,
- cleanupCommands in console := "sparkContext.stop()"
+ cleanupCommands in console := "sparkContext.stop()",
+ // Some of our log4j jars make it impossible to submit jobs from this JVM to Hive Map/Reduce
+ // in order to generate golden files. This is only required for developers who are adding
+ // new query tests.
+ fullClasspath in Test := (fullClasspath in Test).value.filterNot { f => f.toString.contains("jcl-over") }
)
}
@@ -260,10 +294,22 @@ object Assembly {
import sbtassembly.Plugin._
import AssemblyKeys._
+ val hadoopVersion = taskKey[String]("The version of hadoop that spark is compiled against.")
+
lazy val settings = assemblySettings ++ Seq(
test in assembly := {},
- jarName in assembly <<= (version, moduleName) map { (v, mName) => mName + "-"+v + "-hadoop" +
- Option(System.getProperty("hadoop.version")).getOrElse("1.0.4") + ".jar" },
+ hadoopVersion := {
+ sys.props.get("hadoop.version")
+ .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String])
+ },
+ jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) =>
+ if (mName.contains("streaming-kafka-assembly")) {
+ // This must match the same name used in maven (see external/kafka-assembly/pom.xml)
+ s"${mName}-${v}.jar"
+ } else {
+ s"${mName}-${v}-hadoop${hv}.jar"
+ }
+ },
mergeStrategy in assembly := {
case PathList("org", "datanucleus", xs @ _*) => MergeStrategy.discard
case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
@@ -274,7 +320,6 @@ object Assembly {
case _ => MergeStrategy.first
}
)
-
}
object Unidoc {
@@ -292,9 +337,9 @@ object Unidoc {
publish := {},
unidocProjectFilter in(ScalaUnidoc, unidoc) :=
- inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, catalyst, streamingFlumeSink, yarn, yarnAlpha),
+ inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, catalyst, streamingFlumeSink, yarn),
unidocProjectFilter in(JavaUnidoc, unidoc) :=
- inAnyProject -- inProjects(OldDeps.project, repl, bagel, graphx, examples, tools, catalyst, streamingFlumeSink, yarn, yarnAlpha),
+ inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, catalyst, streamingFlumeSink, yarn),
// Skip class names containing $ and some internal packages in Javadocs
unidocAllSources in (JavaUnidoc, unidoc) := {
@@ -322,7 +367,10 @@ object Unidoc {
"mllib.classification", "mllib.clustering", "mllib.evaluation.binary", "mllib.linalg",
"mllib.linalg.distributed", "mllib.optimization", "mllib.rdd", "mllib.recommendation",
"mllib.regression", "mllib.stat", "mllib.tree", "mllib.tree.configuration",
- "mllib.tree.impurity", "mllib.tree.model", "mllib.util"
+ "mllib.tree.impurity", "mllib.tree.model", "mllib.util",
+ "mllib.evaluation", "mllib.feature", "mllib.random", "mllib.stat.correlation",
+ "mllib.stat.test", "mllib.tree.impl", "mllib.tree.loss",
+ "ml", "ml.classification", "ml.evaluation", "ml.feature", "ml.param", "ml.tuning"
),
"-group", "Spark SQL", packageList("sql.api.java", "sql.api.java.types", "sql.hive.api.java"),
"-noqualifier", "java.lang"
@@ -340,13 +388,19 @@ object TestSettings {
javaOptions in Test += "-Dspark.testing=1",
javaOptions in Test += "-Dspark.port.maxRetries=100",
javaOptions in Test += "-Dspark.ui.enabled=false",
+ javaOptions in Test += "-Dspark.ui.showConsoleProgress=false",
+ javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true",
javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true",
javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark")
.map { case (k,v) => s"-D$k=$v" }.toSeq,
+ javaOptions in Test += "-ea",
javaOptions in Test ++= "-Xmx3g -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g"
.split(" ").toSeq,
+ // This places test scope jars on the classpath of executors during tests.
+ javaOptions in Test +=
+ "-Dspark.executor.extraClassPath=" + (fullClasspath in Test).value.files.
+ map(_.getAbsolutePath).mkString(":").stripSuffix(":"),
javaOptions += "-Xmx3g",
-
// Show full stack trace and duration in test cases.
testOptions in Test += Tests.Argument("-oDF"),
testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"),
diff --git a/project/build.properties b/project/build.properties
index 32a3aeefaf9fb..064ec843da9ea 100644
--- a/project/build.properties
+++ b/project/build.properties
@@ -14,4 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-sbt.version=0.13.6
+sbt.version=0.13.7
diff --git a/project/plugins.sbt b/project/plugins.sbt
index 9d50a50b109af..ee45b6a51905e 100644
--- a/project/plugins.sbt
+++ b/project/plugins.sbt
@@ -19,7 +19,7 @@ addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0")
addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.4")
-addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.5.0")
+addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.6.0")
addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6")
diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala
index 3ef2d5451da0d..8863f272da415 100644
--- a/project/project/SparkPluginBuild.scala
+++ b/project/project/SparkPluginBuild.scala
@@ -26,7 +26,7 @@ import sbt.Keys._
object SparkPluginDef extends Build {
lazy val root = Project("plugins", file(".")) dependsOn(sparkStyle, sbtPomReader)
lazy val sparkStyle = Project("spark-style", file("spark-style"), settings = styleSettings)
- lazy val sbtPomReader = uri("https://github.com/ScrapCodes/sbt-pom-reader.git")
+ lazy val sbtPomReader = uri("https://github.com/ScrapCodes/sbt-pom-reader.git#ignore_artifact_id")
// There is actually no need to publish this artifact.
def styleSettings = Defaults.defaultSettings ++ Seq (
diff --git a/python/docs/conf.py b/python/docs/conf.py
index e58d97ae6a746..b00dce95d65b4 100644
--- a/python/docs/conf.py
+++ b/python/docs/conf.py
@@ -55,9 +55,9 @@
# built documents.
#
# The short X.Y version.
-version = '1.2-SNAPSHOT'
+version = '1.3-SNAPSHOT'
# The full version, including alpha/beta/rc tags.
-release = '1.2-SNAPSHOT'
+release = '1.3-SNAPSHOT'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
diff --git a/python/docs/epytext.py b/python/docs/epytext.py
index 19fefbfc057a4..e884d5e6b19c7 100644
--- a/python/docs/epytext.py
+++ b/python/docs/epytext.py
@@ -1,7 +1,7 @@
import re
RULES = (
- (r"<[\w.]+>", r""),
+ (r"<(?!BLANKLINE)[\w.]+>", r""),
(r"L{([\w.()]+)}", r":class:`\1`"),
(r"[LC]{(\w+\.\w+)\(\)}", r":func:`\1`"),
(r"C{([\w.()]+)}", r":class:`\1`"),
diff --git a/python/docs/index.rst b/python/docs/index.rst
index 703bef644de28..d150de9d5c502 100644
--- a/python/docs/index.rst
+++ b/python/docs/index.rst
@@ -14,6 +14,7 @@ Contents:
pyspark
pyspark.sql
pyspark.streaming
+ pyspark.ml
pyspark.mllib
diff --git a/python/docs/make.bat b/python/docs/make.bat
index c011e82b4a35a..cc29acdc19686 100644
--- a/python/docs/make.bat
+++ b/python/docs/make.bat
@@ -1,6 +1,6 @@
-@ECHO OFF
-
-rem This is the entry point for running Sphinx documentation. To avoid polluting the
-rem environment, it just launches a new cmd to do the real work.
-
-cmd /V /E /C %~dp0make2.bat %*
+@ECHO OFF
+
+rem This is the entry point for running Sphinx documentation. To avoid polluting the
+rem environment, it just launches a new cmd to do the real work.
+
+cmd /V /E /C %~dp0make2.bat %*
diff --git a/python/docs/make2.bat b/python/docs/make2.bat
index 7bcaeafad13d7..05d22eb5cdd23 100644
--- a/python/docs/make2.bat
+++ b/python/docs/make2.bat
@@ -1,243 +1,243 @@
-@ECHO OFF
-
-REM Command file for Sphinx documentation
-
-
-if "%SPHINXBUILD%" == "" (
- set SPHINXBUILD=sphinx-build
-)
-set BUILDDIR=_build
-set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% .
-set I18NSPHINXOPTS=%SPHINXOPTS% .
-if NOT "%PAPER%" == "" (
- set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS%
- set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS%
-)
-
-if "%1" == "" goto help
-
-if "%1" == "help" (
- :help
- echo.Please use `make ^<target^>` where ^<target^> is one of
- echo. html to make standalone HTML files
- echo. dirhtml to make HTML files named index.html in directories
- echo. singlehtml to make a single large HTML file
- echo. pickle to make pickle files
- echo. json to make JSON files
- echo. htmlhelp to make HTML files and a HTML help project
- echo. qthelp to make HTML files and a qthelp project
- echo. devhelp to make HTML files and a Devhelp project
- echo. epub to make an epub
- echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter
- echo. text to make text files
- echo. man to make manual pages
- echo. texinfo to make Texinfo files
- echo. gettext to make PO message catalogs
- echo. changes to make an overview over all changed/added/deprecated items
- echo. xml to make Docutils-native XML files
- echo. pseudoxml to make pseudoxml-XML files for display purposes
- echo. linkcheck to check all external links for integrity
- echo. doctest to run all doctests embedded in the documentation if enabled
- goto end
-)
-
-if "%1" == "clean" (
- for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i
- del /q /s %BUILDDIR%\*
- goto end
-)
-
-
-%SPHINXBUILD% 2> nul
-if errorlevel 9009 (
- echo.
- echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
- echo.installed, then set the SPHINXBUILD environment variable to point
- echo.to the full path of the 'sphinx-build' executable. Alternatively you
- echo.may add the Sphinx directory to PATH.
- echo.
- echo.If you don't have Sphinx installed, grab it from
- echo.http://sphinx-doc.org/
- exit /b 1
-)
-
-if "%1" == "html" (
- %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The HTML pages are in %BUILDDIR%/html.
- goto end
-)
-
-if "%1" == "dirhtml" (
- %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml.
- goto end
-)
-
-if "%1" == "singlehtml" (
- %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml.
- goto end
-)
-
-if "%1" == "pickle" (
- %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished; now you can process the pickle files.
- goto end
-)
-
-if "%1" == "json" (
- %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished; now you can process the JSON files.
- goto end
-)
-
-if "%1" == "htmlhelp" (
- %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished; now you can run HTML Help Workshop with the ^
-.hhp project file in %BUILDDIR%/htmlhelp.
- goto end
-)
-
-if "%1" == "qthelp" (
- %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished; now you can run "qcollectiongenerator" with the ^
-.qhcp project file in %BUILDDIR%/qthelp, like this:
- echo.^> qcollectiongenerator %BUILDDIR%\qthelp\pyspark.qhcp
- echo.To view the help file:
- echo.^> assistant -collectionFile %BUILDDIR%\qthelp\pyspark.ghc
- goto end
-)
-
-if "%1" == "devhelp" (
- %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished.
- goto end
-)
-
-if "%1" == "epub" (
- %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The epub file is in %BUILDDIR%/epub.
- goto end
-)
-
-if "%1" == "latex" (
- %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished; the LaTeX files are in %BUILDDIR%/latex.
- goto end
-)
-
-if "%1" == "latexpdf" (
- %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
- cd %BUILDDIR%/latex
- make all-pdf
- cd %BUILDDIR%/..
- echo.
- echo.Build finished; the PDF files are in %BUILDDIR%/latex.
- goto end
-)
-
-if "%1" == "latexpdfja" (
- %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
- cd %BUILDDIR%/latex
- make all-pdf-ja
- cd %BUILDDIR%/..
- echo.
- echo.Build finished; the PDF files are in %BUILDDIR%/latex.
- goto end
-)
-
-if "%1" == "text" (
- %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The text files are in %BUILDDIR%/text.
- goto end
-)
-
-if "%1" == "man" (
- %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The manual pages are in %BUILDDIR%/man.
- goto end
-)
-
-if "%1" == "texinfo" (
- %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo.
- goto end
-)
-
-if "%1" == "gettext" (
- %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The message catalogs are in %BUILDDIR%/locale.
- goto end
-)
-
-if "%1" == "changes" (
- %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes
- if errorlevel 1 exit /b 1
- echo.
- echo.The overview file is in %BUILDDIR%/changes.
- goto end
-)
-
-if "%1" == "linkcheck" (
- %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck
- if errorlevel 1 exit /b 1
- echo.
- echo.Link check complete; look for any errors in the above output ^
-or in %BUILDDIR%/linkcheck/output.txt.
- goto end
-)
-
-if "%1" == "doctest" (
- %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest
- if errorlevel 1 exit /b 1
- echo.
- echo.Testing of doctests in the sources finished, look at the ^
-results in %BUILDDIR%/doctest/output.txt.
- goto end
-)
-
-if "%1" == "xml" (
- %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The XML files are in %BUILDDIR%/xml.
- goto end
-)
-
-if "%1" == "pseudoxml" (
- %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml
- if errorlevel 1 exit /b 1
- echo.
- echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml.
- goto end
-)
-
-:end
+@ECHO OFF
+
+REM Command file for Sphinx documentation
+
+
+if "%SPHINXBUILD%" == "" (
+ set SPHINXBUILD=sphinx-build
+)
+set BUILDDIR=_build
+set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% .
+set I18NSPHINXOPTS=%SPHINXOPTS% .
+if NOT "%PAPER%" == "" (
+ set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS%
+ set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS%
+)
+
+if "%1" == "" goto help
+
+if "%1" == "help" (
+ :help
+ echo.Please use `make ^<target^>` where ^<target^> is one of
+ echo. html to make standalone HTML files
+ echo. dirhtml to make HTML files named index.html in directories
+ echo. singlehtml to make a single large HTML file
+ echo. pickle to make pickle files
+ echo. json to make JSON files
+ echo. htmlhelp to make HTML files and a HTML help project
+ echo. qthelp to make HTML files and a qthelp project
+ echo. devhelp to make HTML files and a Devhelp project
+ echo. epub to make an epub
+ echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter
+ echo. text to make text files
+ echo. man to make manual pages
+ echo. texinfo to make Texinfo files
+ echo. gettext to make PO message catalogs
+ echo. changes to make an overview over all changed/added/deprecated items
+ echo. xml to make Docutils-native XML files
+ echo. pseudoxml to make pseudoxml-XML files for display purposes
+ echo. linkcheck to check all external links for integrity
+ echo. doctest to run all doctests embedded in the documentation if enabled
+ goto end
+)
+
+if "%1" == "clean" (
+ for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i
+ del /q /s %BUILDDIR%\*
+ goto end
+)
+
+
+%SPHINXBUILD% 2> nul
+if errorlevel 9009 (
+ echo.
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ echo.installed, then set the SPHINXBUILD environment variable to point
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ echo.may add the Sphinx directory to PATH.
+ echo.
+ echo.If you don't have Sphinx installed, grab it from
+ echo.http://sphinx-doc.org/
+ exit /b 1
+)
+
+if "%1" == "html" (
+ %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The HTML pages are in %BUILDDIR%/html.
+ goto end
+)
+
+if "%1" == "dirhtml" (
+ %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml.
+ goto end
+)
+
+if "%1" == "singlehtml" (
+ %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml.
+ goto end
+)
+
+if "%1" == "pickle" (
+ %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished; now you can process the pickle files.
+ goto end
+)
+
+if "%1" == "json" (
+ %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished; now you can process the JSON files.
+ goto end
+)
+
+if "%1" == "htmlhelp" (
+ %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished; now you can run HTML Help Workshop with the ^
+.hhp project file in %BUILDDIR%/htmlhelp.
+ goto end
+)
+
+if "%1" == "qthelp" (
+ %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished; now you can run "qcollectiongenerator" with the ^
+.qhcp project file in %BUILDDIR%/qthelp, like this:
+ echo.^> qcollectiongenerator %BUILDDIR%\qthelp\pyspark.qhcp
+ echo.To view the help file:
+ echo.^> assistant -collectionFile %BUILDDIR%\qthelp\pyspark.ghc
+ goto end
+)
+
+if "%1" == "devhelp" (
+ %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished.
+ goto end
+)
+
+if "%1" == "epub" (
+ %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The epub file is in %BUILDDIR%/epub.
+ goto end
+)
+
+if "%1" == "latex" (
+ %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished; the LaTeX files are in %BUILDDIR%/latex.
+ goto end
+)
+
+if "%1" == "latexpdf" (
+ %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
+ cd %BUILDDIR%/latex
+ make all-pdf
+ cd %BUILDDIR%/..
+ echo.
+ echo.Build finished; the PDF files are in %BUILDDIR%/latex.
+ goto end
+)
+
+if "%1" == "latexpdfja" (
+ %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex
+ cd %BUILDDIR%/latex
+ make all-pdf-ja
+ cd %BUILDDIR%/..
+ echo.
+ echo.Build finished; the PDF files are in %BUILDDIR%/latex.
+ goto end
+)
+
+if "%1" == "text" (
+ %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The text files are in %BUILDDIR%/text.
+ goto end
+)
+
+if "%1" == "man" (
+ %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The manual pages are in %BUILDDIR%/man.
+ goto end
+)
+
+if "%1" == "texinfo" (
+ %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo.
+ goto end
+)
+
+if "%1" == "gettext" (
+ %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The message catalogs are in %BUILDDIR%/locale.
+ goto end
+)
+
+if "%1" == "changes" (
+ %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.The overview file is in %BUILDDIR%/changes.
+ goto end
+)
+
+if "%1" == "linkcheck" (
+ %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Link check complete; look for any errors in the above output ^
+or in %BUILDDIR%/linkcheck/output.txt.
+ goto end
+)
+
+if "%1" == "doctest" (
+ %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Testing of doctests in the sources finished, look at the ^
+results in %BUILDDIR%/doctest/output.txt.
+ goto end
+)
+
+if "%1" == "xml" (
+ %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The XML files are in %BUILDDIR%/xml.
+ goto end
+)
+
+if "%1" == "pseudoxml" (
+ %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml
+ if errorlevel 1 exit /b 1
+ echo.
+ echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml.
+ goto end
+)
+
+:end
diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst
new file mode 100644
index 0000000000000..f10d1339a9a8f
--- /dev/null
+++ b/python/docs/pyspark.ml.rst
@@ -0,0 +1,29 @@
+pyspark.ml package
+=====================
+
+Submodules
+----------
+
+pyspark.ml module
+-----------------
+
+.. automodule:: pyspark.ml
+ :members:
+ :undoc-members:
+ :inherited-members:
+
+pyspark.ml.feature module
+-------------------------
+
+.. automodule:: pyspark.ml.feature
+ :members:
+ :undoc-members:
+ :inherited-members:
+
+pyspark.ml.classification module
+--------------------------------
+
+.. automodule:: pyspark.ml.classification
+ :members:
+ :undoc-members:
+ :inherited-members:
diff --git a/python/docs/pyspark.rst b/python/docs/pyspark.rst
index e81be3b6cb796..0df12c49ad033 100644
--- a/python/docs/pyspark.rst
+++ b/python/docs/pyspark.rst
@@ -9,6 +9,7 @@ Subpackages
pyspark.sql
pyspark.streaming
+ pyspark.ml
pyspark.mllib
Contents
diff --git a/python/docs/pyspark.streaming.rst b/python/docs/pyspark.streaming.rst
index 5024d694b668f..f08185627d0bc 100644
--- a/python/docs/pyspark.streaming.rst
+++ b/python/docs/pyspark.streaming.rst
@@ -1,5 +1,5 @@
pyspark.streaming module
-==================
+========================
Module contents
---------------
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index e39e6514d77a1..d3efcdf221d82 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -37,16 +37,6 @@
"""
-# The following block allows us to import python's random instead of mllib.random for scripts in
-# mllib that depend on top level pyspark packages, which transitively depend on python's random.
-# Since Python's import logic looks for modules in the current package first, we eliminate
-# mllib.random as a candidate for C{import random} by removing the first search path, the script's
-# location, in order to force the loader to look in Python's top-level modules for C{random}.
-import sys
-s = sys.path.pop(0)
-import random
-sys.path.insert(0, s)
-
from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.rdd import RDD
@@ -55,6 +45,7 @@
from pyspark.accumulators import Accumulator, AccumulatorParam
from pyspark.broadcast import Broadcast
from pyspark.serializers import MarshalSerializer, PickleSerializer
+from pyspark.profiler import Profiler, BasicProfiler
# for back compatibility
from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row
@@ -62,4 +53,5 @@
__all__ = [
"SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast",
"Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer",
+ "Profiler", "BasicProfiler",
]
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index b8cdbbe3cf2b6..ccbca67656c8d 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -215,21 +215,6 @@ def addInPlace(self, value1, value2):
COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j)
-class PStatsParam(AccumulatorParam):
- """PStatsParam is used to merge pstats.Stats"""
-
- @staticmethod
- def zero(value):
- return None
-
- @staticmethod
- def addInPlace(value1, value2):
- if value1 is None:
- return value2
- value1.add(value2)
- return value1
-
-
class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
"""
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index f124dc6c07575..6b8a8b256a891 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -15,21 +15,10 @@
# limitations under the License.
#
-"""
->>> from pyspark.context import SparkContext
->>> sc = SparkContext('local', 'test')
->>> b = sc.broadcast([1, 2, 3, 4, 5])
->>> b.value
-[1, 2, 3, 4, 5]
->>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
-[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
->>> b.unpersist()
-
->>> large_broadcast = sc.broadcast(list(range(10000)))
-"""
import os
-
-from pyspark.serializers import CompressedSerializer, PickleSerializer
+import cPickle
+import gc
+from tempfile import NamedTemporaryFile
__all__ = ['Broadcast']
@@ -49,44 +38,88 @@ def _from_id(bid):
class Broadcast(object):
"""
- A broadcast variable created with
- L{SparkContext.broadcast()}.
+ A broadcast variable created with L{SparkContext.broadcast()}.
Access its value through C{.value}.
+
+ Examples:
+
+ >>> from pyspark.context import SparkContext
+ >>> sc = SparkContext('local', 'test')
+ >>> b = sc.broadcast([1, 2, 3, 4, 5])
+ >>> b.value
+ [1, 2, 3, 4, 5]
+ >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
+ [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
+ >>> b.unpersist()
+
+ >>> large_broadcast = sc.broadcast(range(10000))
"""
- def __init__(self, bid, value, java_broadcast=None,
- pickle_registry=None, path=None):
+ def __init__(self, sc=None, value=None, pickle_registry=None, path=None):
"""
- Should not be called directly by users -- use
- L{SparkContext.broadcast()}
+ Should not be called directly by users -- use L{SparkContext.broadcast()}
instead.
"""
- self.bid = bid
- if path is None:
- self._value = value
- self._jbroadcast = java_broadcast
- self._pickle_registry = pickle_registry
- self.path = path
+ if sc is not None:
+ f = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
+ self._path = self.dump(value, f)
+ self._jbroadcast = sc._jvm.PythonRDD.readBroadcastFromFile(sc._jsc, self._path)
+ self._pickle_registry = pickle_registry
+ else:
+ self._jbroadcast = None
+ self._path = path
+
+ def dump(self, value, f):
+ if isinstance(value, basestring):
+ if isinstance(value, unicode):
+ f.write('U')
+ value = value.encode('utf8')
+ else:
+ f.write('S')
+ f.write(value)
+ else:
+ f.write('P')
+ cPickle.dump(value, f, 2)
+ f.close()
+ return f.name
+
+ def load(self, path):
+ with open(path, 'rb', 1 << 20) as f:
+ flag = f.read(1)
+ data = f.read()
+ if flag == 'P':
+ # cPickle.loads() may create lots of objects, disable GC
+ # temporary for better performance
+ gc.disable()
+ try:
+ return cPickle.loads(data)
+ finally:
+ gc.enable()
+ else:
+ return data.decode('utf8') if flag == 'U' else data
@property
def value(self):
""" Return the broadcasted value
"""
- if not hasattr(self, "_value") and self.path is not None:
- ser = CompressedSerializer(PickleSerializer())
- self._value = ser.load_stream(open(self.path)).next()
+ if not hasattr(self, "_value") and self._path is not None:
+ self._value = self.load(self._path)
return self._value
def unpersist(self, blocking=False):
"""
Delete cached copies of this broadcast on the executors.
"""
+ if self._jbroadcast is None:
+            raise Exception("Broadcast can only be unpersisted in the driver")
self._jbroadcast.unpersist(blocking)
- os.unlink(self.path)
+ os.unlink(self._path)
def __reduce__(self):
+ if self._jbroadcast is None:
+            raise Exception("Broadcast can only be serialized in the driver")
self._pickle_registry.add(self)
- return (_from_id, (self.bid, ))
+ return _from_id, (self._jbroadcast.id(),)
if __name__ == "__main__":
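
The dump()/load() pair above replaces the CompressedSerializer path with a one-byte flag format: 'U' for unicode text, 'S' for byte strings, and 'P' for pickled objects. A standalone sketch of the same format, assuming Python 2 (matching the code above) and a plain temporary file in place of Spark's broadcast directory:

    import cPickle
    from tempfile import NamedTemporaryFile

    def dump(value, f):
        if isinstance(value, basestring):
            if isinstance(value, unicode):
                f.write('U')                 # unicode text, stored as utf8
                value = value.encode('utf8')
            else:
                f.write('S')                 # plain byte string
            f.write(value)
        else:
            f.write('P')                     # anything else is pickled
            cPickle.dump(value, f, 2)
        f.close()
        return f.name

    def load(path):
        with open(path, 'rb') as f:
            flag = f.read(1)
            data = f.read()
        if flag == 'P':
            return cPickle.loads(data)
        return data.decode('utf8') if flag == 'U' else data

    path = dump({'a': [1, 2, 3]}, NamedTemporaryFile(delete=False))
    assert load(path) == {'a': [1, 2, 3]}
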
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 8d27ccb95f82c..bf1f61c8504ed 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -20,7 +20,6 @@
import sys
from threading import Lock
from tempfile import NamedTemporaryFile
-import atexit
from pyspark import accumulators
from pyspark.accumulators import Accumulator
@@ -29,10 +28,11 @@
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
- PairDeserializer, CompressedSerializer, AutoBatchedSerializer
+ PairDeserializer, AutoBatchedSerializer, NoOpSerializer
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD
from pyspark.traceback_utils import CallSite, first_spark_call
+from pyspark.profiler import ProfilerCollector, BasicProfiler
from py4j.java_collections import ListConverter
@@ -43,7 +43,6 @@
# These are special default configs for PySpark, they will overwrite
# the default ones for Spark if they are not configured by user.
DEFAULT_CONFIGS = {
- "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
"spark.serializer.objectStreamReset": 100,
"spark.rdd.compress": True,
}
@@ -64,11 +63,10 @@ class SparkContext(object):
_active_spark_context = None
_lock = Lock()
_python_includes = None # zip and egg files that need to be added to PYTHONPATH
- _default_batch_size_for_serialized_input = 10
def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
- gateway=None, jsc=None):
+ gateway=None, jsc=None, profiler_cls=BasicProfiler):
"""
Create a new SparkContext. At least the master and app name should be set,
either through the named parameters here or through C{conf}.
@@ -90,6 +88,9 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
:param conf: A L{SparkConf} object setting Spark properties.
:param gateway: Use an existing gateway and JVM, otherwise a new JVM
will be instantiated.
+ :param jsc: The JavaSparkContext instance (optional).
+        :param profiler_cls: A custom Profiler class used to do profiling
+               (default is pyspark.profiler.BasicProfiler).
>>> from pyspark.context import SparkContext
@@ -104,21 +105,19 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
SparkContext._ensure_initialized(self, gateway=gateway)
try:
self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
- conf, jsc)
+ conf, jsc, profiler_cls)
except:
# If an error occurs, clean up in order to allow future SparkContext creation:
self.stop()
raise
def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
- conf, jsc):
+ conf, jsc, profiler_cls):
self.environment = environment or {}
self._conf = conf or SparkConf(_jvm=self._jvm)
self._batchSize = batchSize # -1 represents an unlimited batch size
self._unbatched_serializer = serializer
- if batchSize == 1:
- self.serializer = self._unbatched_serializer
- elif batchSize == 0:
+ if batchSize == 0:
self.serializer = AutoBatchedSerializer(self._unbatched_serializer)
else:
self.serializer = BatchedSerializer(self._unbatched_serializer,
@@ -193,10 +192,15 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
# Create a temporary directory inside spark.local.dir:
local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
self._temp_dir = \
- self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()
+ self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir, "pyspark") \
+ .getAbsolutePath()
# profiling stats collected for each PythonRDD
- self._profile_stats = []
+ if self._conf.get("spark.python.profile", "false") == "true":
+ dump_path = self._conf.get("spark.python.profile.dump", None)
+ self.profiler_collector = ProfilerCollector(profiler_cls, dump_path)
+ else:
+ self.profiler_collector = None
def _initialize_context(self, jconf):
"""
@@ -233,6 +237,14 @@ def _ensure_initialized(cls, instance=None, gateway=None):
else:
SparkContext._active_spark_context = instance
+ def __getnewargs__(self):
+ # This method is called when attempting to pickle SparkContext, which is always an error:
+ raise Exception(
+ "It appears that you are attempting to reference SparkContext from a broadcast "
+            "variable, action, or transformation. SparkContext can only be used on the driver, "
+            "not in code that runs on workers. For more information, see SPARK-5063."
+ )
+
def __enter__(self):
"""
Enable 'with SparkContext(...) as sc: app(sc)' syntax.
@@ -293,12 +305,29 @@ def stop(self):
def parallelize(self, c, numSlices=None):
"""
- Distribute a local Python collection to form an RDD.
+        Distribute a local Python collection to form an RDD. Using xrange
+        is recommended for performance if the input represents a range.
- >>> sc.parallelize(range(5), 5).glom().collect()
- [[0], [1], [2], [3], [4]]
+ >>> sc.parallelize([0, 2, 3, 4, 6], 5).glom().collect()
+ [[0], [2], [3], [4], [6]]
+ >>> sc.parallelize(xrange(0, 6, 2), 5).glom().collect()
+ [[], [0], [], [2], [4]]
"""
- numSlices = numSlices or self.defaultParallelism
+ numSlices = int(numSlices) if numSlices is not None else self.defaultParallelism
+ if isinstance(c, xrange):
+ size = len(c)
+ if size == 0:
+ return self.parallelize([], numSlices)
+ step = c[1] - c[0] if size > 1 else 1
+ start0 = c[0]
+
+ def getStart(split):
+ return start0 + (split * size / numSlices) * step
+
+ def f(split, iterator):
+ return xrange(getStart(split), getStart(split + 1), step)
+
+ return self.parallelize([], numSlices).mapPartitionsWithIndex(f)
# Calling the Java parallelize() method with an ArrayList is too slow,
# because it sends O(n) Py4J commands. As an alternative, serialized
# objects are written to a file and loaded through textFile().
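
The xrange branch above never materializes the range on the driver: each partition recomputes its own slice from the start, step, and size. A quick standalone check of that slice arithmetic (Python 2, since it relies on integer division and xrange as in the code above):

    def slices(c, numSlices):
        size = len(c)
        step = c[1] - c[0] if size > 1 else 1
        start0 = c[0]

        def getStart(split):
            return start0 + (split * size / numSlices) * step

        # partition i covers [getStart(i), getStart(i + 1)) with the same step
        return [list(xrange(getStart(i), getStart(i + 1), step))
                for i in range(numSlices)]

    print slices(xrange(0, 6, 2), 5)   # [[], [0], [], [2], [4]], matching the doctest
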
@@ -306,12 +335,8 @@ def parallelize(self, c, numSlices=None):
# Make sure we distribute data evenly if it's smaller than self.batchSize
if "__len__" not in dir(c):
c = list(c) # Make it a list so we can compute its length
- batchSize = min(len(c) // numSlices, self._batchSize)
- if batchSize > 1:
- serializer = BatchedSerializer(self._unbatched_serializer,
- batchSize)
- else:
- serializer = self._unbatched_serializer
+ batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
+ serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
serializer.dump_stream(c, tempFile)
tempFile.close()
readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
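
The replacement above collapses the old two-branch batching logic into a single heuristic: at least one object per batch, capped by the configured batch size (or 1024 when auto-batching, i.e. batchSize == 0), and never larger than an even split of the data. A tiny sketch of the arithmetic:

    def pick_batch_size(n, numSlices, configured):
        return max(1, min(n // numSlices, configured or 1024))

    print pick_batch_size(10, 4, 0)       # 2: a small collection stays spread across slices
    print pick_batch_size(100000, 4, 0)   # 1024: capped by the auto-batching default
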
@@ -329,8 +354,7 @@ def pickleFile(self, name, minPartitions=None):
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
"""
minPartitions = minPartitions or self.defaultMinPartitions
- return RDD(self._jsc.objectFile(name, minPartitions), self,
- BatchedSerializer(PickleSerializer()))
+ return RDD(self._jsc.objectFile(name, minPartitions), self)
def textFile(self, name, minPartitions=None, use_unicode=True):
"""
@@ -397,6 +421,36 @@ def wholeTextFiles(self, path, minPartitions=None, use_unicode=True):
return RDD(self._jsc.wholeTextFiles(path, minPartitions), self,
PairDeserializer(UTF8Deserializer(use_unicode), UTF8Deserializer(use_unicode)))
+ def binaryFiles(self, path, minPartitions=None):
+ """
+ .. note:: Experimental
+
+ Read a directory of binary files from HDFS, a local file system
+ (available on all nodes), or any Hadoop-supported file system URI
+ as a byte array. Each file is read as a single record and returned
+        in a key-value pair, where the key is the path of each file and the
+        value is the content of each file.
+
+        Note: Small files are preferred; large files are also allowed, but
+        may cause poor performance.
+ """
+ minPartitions = minPartitions or self.defaultMinPartitions
+ return RDD(self._jsc.binaryFiles(path, minPartitions), self,
+ PairDeserializer(UTF8Deserializer(), NoOpSerializer()))
+
+ def binaryRecords(self, path, recordLength):
+ """
+ .. note:: Experimental
+
+ Load data from a flat binary file, assuming each record is a set of numbers
+ with the specified numerical format (see ByteBuffer), and the number of
+ bytes per record is constant.
+
+ :param path: Directory to the input data files
+ :param recordLength: The length at which to split the records
+ """
+ return RDD(self._jsc.binaryRecords(path, recordLength), self, NoOpSerializer())
+
def _dictToJavaMap(self, d):
jm = self._jvm.java.util.HashMap()
if not d:
@@ -406,7 +460,7 @@ def _dictToJavaMap(self, d):
return jm
def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None,
- valueConverter=None, minSplits=None, batchSize=None):
+ valueConverter=None, minSplits=None, batchSize=0):
"""
Read a Hadoop SequenceFile with arbitrary key and value Writable class from HDFS,
a local file system (available on all nodes), or any Hadoop-supported file system URI.
@@ -428,17 +482,15 @@ def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None,
:param minSplits: minimum splits in dataset
(default min(2, sc.defaultParallelism))
:param batchSize: The number of Python objects represented as a single
- Java object. (default sc._default_batch_size_for_serialized_input)
+               Java object. (default 0, which chooses batchSize automatically)
"""
minSplits = minSplits or min(self.defaultParallelism, 2)
- batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input)
- ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer()
jrdd = self._jvm.PythonRDD.sequenceFile(self._jsc, path, keyClass, valueClass,
keyConverter, valueConverter, minSplits, batchSize)
- return RDD(jrdd, self, ser)
+ return RDD(jrdd, self)
def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None,
- valueConverter=None, conf=None, batchSize=None):
+ valueConverter=None, conf=None, batchSize=0):
"""
Read a 'new API' Hadoop InputFormat with arbitrary key and value class from HDFS,
a local file system (available on all nodes), or any Hadoop-supported file system URI.
@@ -459,18 +511,16 @@ def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConv
:param conf: Hadoop configuration, passed in as a dict
(None by default)
:param batchSize: The number of Python objects represented as a single
- Java object. (default sc._default_batch_size_for_serialized_input)
+               Java object. (default 0, which chooses batchSize automatically)
"""
jconf = self._dictToJavaMap(conf)
- batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input)
- ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer()
jrdd = self._jvm.PythonRDD.newAPIHadoopFile(self._jsc, path, inputFormatClass, keyClass,
valueClass, keyConverter, valueConverter,
jconf, batchSize)
- return RDD(jrdd, self, ser)
+ return RDD(jrdd, self)
def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None,
- valueConverter=None, conf=None, batchSize=None):
+ valueConverter=None, conf=None, batchSize=0):
"""
Read a 'new API' Hadoop InputFormat with arbitrary key and value class, from an arbitrary
Hadoop configuration, which is passed in as a Python dict.
@@ -488,18 +538,16 @@ def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=N
:param conf: Hadoop configuration, passed in as a dict
(None by default)
:param batchSize: The number of Python objects represented as a single
- Java object. (default sc._default_batch_size_for_serialized_input)
+               Java object. (default 0, which chooses batchSize automatically)
"""
jconf = self._dictToJavaMap(conf)
- batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input)
- ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer()
jrdd = self._jvm.PythonRDD.newAPIHadoopRDD(self._jsc, inputFormatClass, keyClass,
valueClass, keyConverter, valueConverter,
jconf, batchSize)
- return RDD(jrdd, self, ser)
+ return RDD(jrdd, self)
def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None,
- valueConverter=None, conf=None, batchSize=None):
+ valueConverter=None, conf=None, batchSize=0):
"""
Read an 'old' Hadoop InputFormat with arbitrary key and value class from HDFS,
a local file system (available on all nodes), or any Hadoop-supported file system URI.
@@ -520,18 +568,16 @@ def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=
:param conf: Hadoop configuration, passed in as a dict
(None by default)
:param batchSize: The number of Python objects represented as a single
- Java object. (default sc._default_batch_size_for_serialized_input)
+               Java object. (default 0, which chooses batchSize automatically)
"""
jconf = self._dictToJavaMap(conf)
- batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input)
- ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer()
jrdd = self._jvm.PythonRDD.hadoopFile(self._jsc, path, inputFormatClass, keyClass,
valueClass, keyConverter, valueConverter,
jconf, batchSize)
- return RDD(jrdd, self, ser)
+ return RDD(jrdd, self)
def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None,
- valueConverter=None, conf=None, batchSize=None):
+ valueConverter=None, conf=None, batchSize=0):
"""
Read an 'old' Hadoop InputFormat with arbitrary key and value class, from an arbitrary
Hadoop configuration, which is passed in as a Python dict.
@@ -549,15 +595,13 @@ def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None,
:param conf: Hadoop configuration, passed in as a dict
(None by default)
:param batchSize: The number of Python objects represented as a single
- Java object. (default sc._default_batch_size_for_serialized_input)
+               Java object. (default 0, which chooses batchSize automatically)
"""
jconf = self._dictToJavaMap(conf)
- batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input)
- ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer()
jrdd = self._jvm.PythonRDD.hadoopRDD(self._jsc, inputFormatClass, keyClass,
valueClass, keyConverter, valueConverter,
jconf, batchSize)
- return RDD(jrdd, self, ser)
+ return RDD(jrdd, self)
def _checkpointFile(self, name, input_deserializer):
jrdd = self._jsc.checkpointFile(name)
@@ -596,14 +640,7 @@ def broadcast(self, value):
object for reading it in distributed functions. The variable will
be sent to each cluster only once.
"""
- ser = CompressedSerializer(PickleSerializer())
- # pass large object by py4j is very slow and need much memory
- tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
- ser.dump_stream([value], tempFile)
- tempFile.close()
- jbroadcast = self._jvm.PythonRDD.readBroadcastFromFile(self._jsc, tempFile.name)
- return Broadcast(jbroadcast.id(), None, jbroadcast,
- self._pickled_broadcast_vars, tempFile.name)
+ return Broadcast(self, value, self._pickled_broadcast_vars)
def accumulator(self, value, accum_param=None):
"""
@@ -797,39 +834,14 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
return list(mappedRDD._collect_iterator_through_file(it))
- def _add_profile(self, id, profileAcc):
- if not self._profile_stats:
- dump_path = self._conf.get("spark.python.profile.dump")
- if dump_path:
- atexit.register(self.dump_profiles, dump_path)
- else:
- atexit.register(self.show_profiles)
-
- self._profile_stats.append([id, profileAcc, False])
-
def show_profiles(self):
""" Print the profile stats to stdout """
- for i, (id, acc, showed) in enumerate(self._profile_stats):
- stats = acc.value
- if not showed and stats:
- print "=" * 60
- print "Profile of RDD" % id
- print "=" * 60
- stats.sort_stats("time", "cumulative").print_stats()
- # mark it as showed
- self._profile_stats[i][2] = True
+ self.profiler_collector.show_profiles()
def dump_profiles(self, path):
""" Dump the profile stats into directory `path`
"""
- if not os.path.exists(path):
- os.makedirs(path)
- for id, acc, _ in self._profile_stats:
- stats = acc.value
- if stats:
- p = os.path.join(path, "rdd_%d.pstats" % id)
- stats.dump_stats(p)
- self._profile_stats = []
+ self.profiler_collector.dump_profiles(path)
def _test():
@@ -837,7 +849,7 @@ def _test():
import doctest
import tempfile
globs = globals().copy()
- globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ globs['sc'] = SparkContext('local[4]', 'PythonTest')
globs['tempdir'] = tempfile.mkdtemp()
atexit.register(lambda: shutil.rmtree(globs['tempdir']))
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
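
With the changes above, profiling is driven entirely by configuration: the ProfilerCollector is created only when spark.python.profile is set, and show_profiles()/dump_profiles() delegate to it. A usage sketch (the app name and dump path are illustrative):

    from pyspark import SparkConf, SparkContext

    conf = (SparkConf()
            .setMaster("local[2]")
            .setAppName("profile-demo")
            .set("spark.python.profile", "true")
            .set("spark.python.profile.dump", "/tmp/pyspark_profiles"))
    sc = SparkContext(conf=conf)
    sc.parallelize(range(10000)).map(lambda x: x * x).count()
    sc.show_profiles()                         # print collected stats to stdout
    sc.dump_profiles("/tmp/pyspark_profiles")  # or write rdd_<id>.pstats files
    sc.stop()
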
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 64d6202acb27d..f09587f211708 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -26,7 +26,7 @@
import gc
from errno import EINTR, ECHILD, EAGAIN
from socket import AF_INET, SOCK_STREAM, SOMAXCONN
-from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
+from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT
from pyspark.worker import main as worker_main
from pyspark.serializers import read_int, write_int
@@ -46,6 +46,9 @@ def worker(sock):
signal.signal(SIGHUP, SIG_DFL)
signal.signal(SIGCHLD, SIG_DFL)
signal.signal(SIGTERM, SIG_DFL)
+    # restore the default handler for SIGINT;
+    # it's useful for debugging (shows the stacktrace before exiting)
+ signal.signal(SIGINT, signal.default_int_handler)
# Read the socket using fdopen instead of socket.makefile() because the latter
# seems to be very slow; note that we need to dup() the file descriptor because
@@ -59,8 +62,7 @@ def worker(sock):
exit_code = compute_real_exit_code(exc.code)
finally:
outfile.flush()
- if exit_code:
- os._exit(exit_code)
+ return exit_code
# Cleanup zombie children
@@ -157,10 +159,13 @@ def handle_sigterm(*args):
outfile.flush()
outfile.close()
while True:
- worker(sock)
- if not reuse:
+ code = worker(sock)
+ if not reuse or code:
# wait for closing
- while sock.recv(1024):
+ try:
+ while sock.recv(1024):
+ pass
+ except Exception:
pass
break
gc.collect()
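
Two behaviours change above: a forked worker restores the default SIGINT handler (so Ctrl-C produces a traceback instead of dying silently under the inherited SIG_IGN), and worker() now returns its exit code instead of calling os._exit(), so the daemon leaves the reuse loop whenever a run fails. A minimal sketch of that loop shape (the simulated exit codes are illustrative):

    import signal

    # what each forked worker now does before running user code
    signal.signal(signal.SIGINT, signal.default_int_handler)

    exit_codes = iter([0, 0, 1])   # two clean runs, then a failure
    reuse = True
    while True:
        code = next(exit_codes)
        if not reuse or code:      # stop reusing the process after any failure
            break
    print code                     # 1
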
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 9c70fa5c16d0c..a0a028446d5fd 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -45,7 +45,9 @@ def launch_gateway():
# Don't send ctrl-c / SIGINT to the Java gateway:
def preexec_func():
signal.signal(signal.SIGINT, signal.SIG_IGN)
- proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func)
+ env = dict(os.environ)
+ env["IS_SUBPROCESS"] = "1" # tell JVM to exit after python exits
+ proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func, env=env)
else:
# preexec_fn not supported on Windows
proc = Popen(command, stdout=PIPE, stdin=PIPE)
@@ -109,10 +111,9 @@ def run(self):
java_import(gateway.jvm, "org.apache.spark.api.java.*")
java_import(gateway.jvm, "org.apache.spark.api.python.*")
java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
- java_import(gateway.jvm, "org.apache.spark.sql.SQLContext")
- java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext")
- java_import(gateway.jvm, "org.apache.spark.sql.hive.LocalHiveContext")
- java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext")
+ # TODO(davies): move into sql
+ java_import(gateway.jvm, "org.apache.spark.sql.*")
+ java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
java_import(gateway.jvm, "scala.Tuple2")
return gateway
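
The gateway launcher above now passes a copy of the parent environment with IS_SUBPROCESS=1 so the launched JVM can notice when the Python side goes away. A standalone sketch of the same Popen pattern (assuming a Python 2 interpreter named python on the PATH; the child command is illustrative):

    import os
    from subprocess import Popen, PIPE

    env = dict(os.environ)                  # copy, so the parent env is untouched
    env["IS_SUBPROCESS"] = "1"
    proc = Popen(["python", "-c", "import os; print os.environ['IS_SUBPROCESS']"],
                 stdout=PIPE, env=env)
    print proc.communicate()[0].strip()     # prints: 1
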
diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py
new file mode 100644
index 0000000000000..47fed80f42e13
--- /dev/null
+++ b/python/pyspark/ml/__init__.py
@@ -0,0 +1,21 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.ml.param import *
+from pyspark.ml.pipeline import *
+
+__all__ = ["Param", "Params", "Transformer", "Estimator", "Pipeline"]
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
new file mode 100644
index 0000000000000..6bd2aa8e47837
--- /dev/null
+++ b/python/pyspark/ml/classification.py
@@ -0,0 +1,76 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.ml.util import inherit_doc
+from pyspark.ml.wrapper import JavaEstimator, JavaModel
+from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\
+ HasRegParam
+
+
+__all__ = ['LogisticRegression', 'LogisticRegressionModel']
+
+
+@inherit_doc
+class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
+ HasRegParam):
+ """
+ Logistic regression.
+
+ >>> from pyspark.sql import Row
+ >>> from pyspark.mllib.linalg import Vectors
+ >>> dataset = sqlCtx.inferSchema(sc.parallelize([ \
+ Row(label=1.0, features=Vectors.dense(1.0)), \
+ Row(label=0.0, features=Vectors.sparse(1, [], []))]))
+ >>> lr = LogisticRegression() \
+ .setMaxIter(5) \
+ .setRegParam(0.01)
+ >>> model = lr.fit(dataset)
+ >>> test0 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.dense(-1.0))]))
+ >>> print model.transform(test0).head().prediction
+ 0.0
+ >>> test1 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]))
+ >>> print model.transform(test1).head().prediction
+ 1.0
+ """
+ _java_class = "org.apache.spark.ml.classification.LogisticRegression"
+
+ def _create_model(self, java_model):
+ return LogisticRegressionModel(java_model)
+
+
+class LogisticRegressionModel(JavaModel):
+ """
+ Model fitted by LogisticRegression.
+ """
+
+
+if __name__ == "__main__":
+ import doctest
+ from pyspark.context import SparkContext
+ from pyspark.sql import SQLContext
+ globs = globals().copy()
+ # The small batch size here ensures that we see multiple batches,
+ # even in these small test examples:
+ sc = SparkContext("local[2]", "ml.feature tests")
+ sqlCtx = SQLContext(sc)
+ globs['sc'] = sc
+ globs['sqlCtx'] = sqlCtx
+ (failure_count, test_count) = doctest.testmod(
+ globs=globs, optionflags=doctest.ELLIPSIS)
+ sc.stop()
+ if failure_count:
+ exit(-1)
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
new file mode 100644
index 0000000000000..e088acd0ca82d
--- /dev/null
+++ b/python/pyspark/ml/feature.py
@@ -0,0 +1,82 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures
+from pyspark.ml.util import inherit_doc
+from pyspark.ml.wrapper import JavaTransformer
+
+__all__ = ['Tokenizer', 'HashingTF']
+
+
+@inherit_doc
+class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
+ """
+ A tokenizer that converts the input string to lowercase and then
+ splits it by white spaces.
+
+ >>> from pyspark.sql import Row
+ >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(text="a b c")]))
+ >>> tokenizer = Tokenizer() \
+ .setInputCol("text") \
+ .setOutputCol("words")
+ >>> print tokenizer.transform(dataset).head()
+ Row(text=u'a b c', words=[u'a', u'b', u'c'])
+ >>> print tokenizer.transform(dataset, {tokenizer.outputCol: "tokens"}).head()
+ Row(text=u'a b c', tokens=[u'a', u'b', u'c'])
+ """
+
+ _java_class = "org.apache.spark.ml.feature.Tokenizer"
+
+
+@inherit_doc
+class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
+ """
+ Maps a sequence of terms to their term frequencies using the
+ hashing trick.
+
+ >>> from pyspark.sql import Row
+ >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(words=["a", "b", "c"])]))
+ >>> hashingTF = HashingTF() \
+ .setNumFeatures(10) \
+ .setInputCol("words") \
+ .setOutputCol("features")
+ >>> print hashingTF.transform(dataset).head().features
+ (10,[7,8,9],[1.0,1.0,1.0])
+ >>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"}
+ >>> print hashingTF.transform(dataset, params).head().vector
+ (5,[2,3,4],[1.0,1.0,1.0])
+ """
+
+ _java_class = "org.apache.spark.ml.feature.HashingTF"
+
+
+if __name__ == "__main__":
+ import doctest
+ from pyspark.context import SparkContext
+ from pyspark.sql import SQLContext
+ globs = globals().copy()
+ # The small batch size here ensures that we see multiple batches,
+ # even in these small test examples:
+ sc = SparkContext("local[2]", "ml.feature tests")
+ sqlCtx = SQLContext(sc)
+ globs['sc'] = sc
+ globs['sqlCtx'] = sqlCtx
+ (failure_count, test_count) = doctest.testmod(
+ globs=globs, optionflags=doctest.ELLIPSIS)
+ sc.stop()
+ if failure_count:
+ exit(-1)
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
new file mode 100644
index 0000000000000..5566792cead48
--- /dev/null
+++ b/python/pyspark/ml/param/__init__.py
@@ -0,0 +1,82 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABCMeta
+
+from pyspark.ml.util import Identifiable
+
+
+__all__ = ['Param', 'Params']
+
+
+class Param(object):
+ """
+ A param with self-contained documentation and optionally default value.
+ """
+
+ def __init__(self, parent, name, doc, defaultValue=None):
+ if not isinstance(parent, Identifiable):
+ raise ValueError("Parent must be identifiable but got type %s." % type(parent).__name__)
+ self.parent = parent
+ self.name = str(name)
+ self.doc = str(doc)
+ self.defaultValue = defaultValue
+
+ def __str__(self):
+ return str(self.parent) + "-" + self.name
+
+ def __repr__(self):
+ return "Param(parent=%r, name=%r, doc=%r, defaultValue=%r)" % \
+ (self.parent, self.name, self.doc, self.defaultValue)
+
+
+class Params(Identifiable):
+ """
+ Components that take parameters. This also provides an internal
+ param map to store parameter values attached to the instance.
+ """
+
+ __metaclass__ = ABCMeta
+
+ def __init__(self):
+ super(Params, self).__init__()
+ #: embedded param map
+ self.paramMap = {}
+
+ @property
+ def params(self):
+ """
+ Returns all params. The default implementation uses
+ :py:func:`dir` to get all attributes of type
+ :py:class:`Param`.
+ """
+ return filter(lambda attr: isinstance(attr, Param),
+ [getattr(self, x) for x in dir(self) if x != "params"])
+
+ def _merge_params(self, params):
+ paramMap = self.paramMap.copy()
+ paramMap.update(params)
+ return paramMap
+
+ @staticmethod
+ def _dummy():
+ """
+ Returns a dummy Params instance used as a placeholder to generate docs.
+ """
+ dummy = Params()
+ dummy.uid = "undefined"
+ return dummy
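
The params property above discovers Param attributes via dir(), while paramMap holds explicitly set values. A small sketch of how a Params subclass behaves (MyParams and alpha are illustrative, not part of this change):

    from pyspark.ml.param import Param, Params

    class MyParams(Params):
        def __init__(self):
            super(MyParams, self).__init__()
            self.alpha = Param(self, "alpha", "step size", 0.1)

    p = MyParams()
    print [param.name for param in p.params]   # ['alpha']
    p.paramMap[p.alpha] = 0.5
    print p._merge_params({})[p.alpha]          # 0.5
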
diff --git a/python/pyspark/ml/param/_gen_shared_params.py b/python/pyspark/ml/param/_gen_shared_params.py
new file mode 100644
index 0000000000000..5eb81106f116c
--- /dev/null
+++ b/python/pyspark/ml/param/_gen_shared_params.py
@@ -0,0 +1,98 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+header = """#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#"""
+
+
+def _gen_param_code(name, doc, defaultValue):
+ """
+ Generates Python code for a shared param class.
+
+ :param name: param name
+ :param doc: param doc
+ :param defaultValue: string representation of the param
+ :return: code string
+ """
+ # TODO: How to correctly inherit instance attributes?
+ template = '''class Has$Name(Params):
+ """
+ Params with $name.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ $name = Param(Params._dummy(), "$name", "$doc", $defaultValue)
+
+ def __init__(self):
+ super(Has$Name, self).__init__()
+ #: param for $doc
+ self.$name = Param(self, "$name", "$doc", $defaultValue)
+
+ def set$Name(self, value):
+ """
+ Sets the value of :py:attr:`$name`.
+ """
+ self.paramMap[self.$name] = value
+ return self
+
+ def get$Name(self):
+ """
+ Gets the value of $name or its default value.
+ """
+ if self.$name in self.paramMap:
+ return self.paramMap[self.$name]
+ else:
+ return self.$name.defaultValue'''
+
+ upperCamelName = name[0].upper() + name[1:]
+ return template \
+ .replace("$name", name) \
+ .replace("$Name", upperCamelName) \
+ .replace("$doc", doc) \
+ .replace("$defaultValue", defaultValue)
+
+if __name__ == "__main__":
+ print header
+ print "\n# DO NOT MODIFY. The code is generated by _gen_shared_params.py.\n"
+ print "from pyspark.ml.param import Param, Params\n\n"
+ shared = [
+ ("maxIter", "max number of iterations", "100"),
+ ("regParam", "regularization constant", "0.1"),
+ ("featuresCol", "features column name", "'features'"),
+ ("labelCol", "label column name", "'label'"),
+ ("predictionCol", "prediction column name", "'prediction'"),
+ ("inputCol", "input column name", "'input'"),
+ ("outputCol", "output column name", "'output'"),
+ ("numFeatures", "number of features", "1 << 18")]
+ code = []
+ for name, doc, defaultValue in shared:
+ code.append(_gen_param_code(name, doc, defaultValue))
+ print "\n\n\n".join(code)
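
The script above is presumably run by hand to regenerate shared.py; the helper can also be exercised directly for a hypothetical new shared param (the seed param below is illustrative):

    from pyspark.ml.param._gen_shared_params import _gen_param_code

    # emits a HasSeed(Params) class with a seed Param, setSeed() and getSeed(),
    # following the same template as the classes in shared.py below
    print _gen_param_code("seed", "random seed", "42")
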
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
new file mode 100644
index 0000000000000..586822f2de423
--- /dev/null
+++ b/python/pyspark/ml/param/shared.py
@@ -0,0 +1,260 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# DO NOT MODIFY. The code is generated by _gen_shared_params.py.
+
+from pyspark.ml.param import Param, Params
+
+
+class HasMaxIter(Params):
+ """
+ Params with maxIter.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ maxIter = Param(Params._dummy(), "maxIter", "max number of iterations", 100)
+
+ def __init__(self):
+ super(HasMaxIter, self).__init__()
+ #: param for max number of iterations
+ self.maxIter = Param(self, "maxIter", "max number of iterations", 100)
+
+ def setMaxIter(self, value):
+ """
+ Sets the value of :py:attr:`maxIter`.
+ """
+ self.paramMap[self.maxIter] = value
+ return self
+
+ def getMaxIter(self):
+ """
+ Gets the value of maxIter or its default value.
+ """
+ if self.maxIter in self.paramMap:
+ return self.paramMap[self.maxIter]
+ else:
+ return self.maxIter.defaultValue
+
+
+class HasRegParam(Params):
+ """
+ Params with regParam.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ regParam = Param(Params._dummy(), "regParam", "regularization constant", 0.1)
+
+ def __init__(self):
+ super(HasRegParam, self).__init__()
+ #: param for regularization constant
+ self.regParam = Param(self, "regParam", "regularization constant", 0.1)
+
+ def setRegParam(self, value):
+ """
+ Sets the value of :py:attr:`regParam`.
+ """
+ self.paramMap[self.regParam] = value
+ return self
+
+ def getRegParam(self):
+ """
+ Gets the value of regParam or its default value.
+ """
+ if self.regParam in self.paramMap:
+ return self.paramMap[self.regParam]
+ else:
+ return self.regParam.defaultValue
+
+
+class HasFeaturesCol(Params):
+ """
+ Params with featuresCol.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ featuresCol = Param(Params._dummy(), "featuresCol", "features column name", 'features')
+
+ def __init__(self):
+ super(HasFeaturesCol, self).__init__()
+ #: param for features column name
+ self.featuresCol = Param(self, "featuresCol", "features column name", 'features')
+
+ def setFeaturesCol(self, value):
+ """
+ Sets the value of :py:attr:`featuresCol`.
+ """
+ self.paramMap[self.featuresCol] = value
+ return self
+
+ def getFeaturesCol(self):
+ """
+ Gets the value of featuresCol or its default value.
+ """
+ if self.featuresCol in self.paramMap:
+ return self.paramMap[self.featuresCol]
+ else:
+ return self.featuresCol.defaultValue
+
+
+class HasLabelCol(Params):
+ """
+ Params with labelCol.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ labelCol = Param(Params._dummy(), "labelCol", "label column name", 'label')
+
+ def __init__(self):
+ super(HasLabelCol, self).__init__()
+ #: param for label column name
+ self.labelCol = Param(self, "labelCol", "label column name", 'label')
+
+ def setLabelCol(self, value):
+ """
+ Sets the value of :py:attr:`labelCol`.
+ """
+ self.paramMap[self.labelCol] = value
+ return self
+
+ def getLabelCol(self):
+ """
+ Gets the value of labelCol or its default value.
+ """
+ if self.labelCol in self.paramMap:
+ return self.paramMap[self.labelCol]
+ else:
+ return self.labelCol.defaultValue
+
+
+class HasPredictionCol(Params):
+ """
+ Params with predictionCol.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name", 'prediction')
+
+ def __init__(self):
+ super(HasPredictionCol, self).__init__()
+ #: param for prediction column name
+ self.predictionCol = Param(self, "predictionCol", "prediction column name", 'prediction')
+
+ def setPredictionCol(self, value):
+ """
+ Sets the value of :py:attr:`predictionCol`.
+ """
+ self.paramMap[self.predictionCol] = value
+ return self
+
+ def getPredictionCol(self):
+ """
+ Gets the value of predictionCol or its default value.
+ """
+ if self.predictionCol in self.paramMap:
+ return self.paramMap[self.predictionCol]
+ else:
+ return self.predictionCol.defaultValue
+
+
+class HasInputCol(Params):
+ """
+ Params with inputCol.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ inputCol = Param(Params._dummy(), "inputCol", "input column name", 'input')
+
+ def __init__(self):
+ super(HasInputCol, self).__init__()
+ #: param for input column name
+ self.inputCol = Param(self, "inputCol", "input column name", 'input')
+
+ def setInputCol(self, value):
+ """
+ Sets the value of :py:attr:`inputCol`.
+ """
+ self.paramMap[self.inputCol] = value
+ return self
+
+ def getInputCol(self):
+ """
+ Gets the value of inputCol or its default value.
+ """
+ if self.inputCol in self.paramMap:
+ return self.paramMap[self.inputCol]
+ else:
+ return self.inputCol.defaultValue
+
+
+class HasOutputCol(Params):
+ """
+ Params with outputCol.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ outputCol = Param(Params._dummy(), "outputCol", "output column name", 'output')
+
+ def __init__(self):
+ super(HasOutputCol, self).__init__()
+ #: param for output column name
+ self.outputCol = Param(self, "outputCol", "output column name", 'output')
+
+ def setOutputCol(self, value):
+ """
+ Sets the value of :py:attr:`outputCol`.
+ """
+ self.paramMap[self.outputCol] = value
+ return self
+
+ def getOutputCol(self):
+ """
+ Gets the value of outputCol or its default value.
+ """
+ if self.outputCol in self.paramMap:
+ return self.paramMap[self.outputCol]
+ else:
+ return self.outputCol.defaultValue
+
+
+class HasNumFeatures(Params):
+ """
+ Params with numFeatures.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ numFeatures = Param(Params._dummy(), "numFeatures", "number of features", 1 << 18)
+
+ def __init__(self):
+ super(HasNumFeatures, self).__init__()
+ #: param for number of features
+ self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18)
+
+ def setNumFeatures(self, value):
+ """
+ Sets the value of :py:attr:`numFeatures`.
+ """
+ self.paramMap[self.numFeatures] = value
+ return self
+
+ def getNumFeatures(self):
+ """
+ Gets the value of numFeatures or its default value.
+ """
+ if self.numFeatures in self.paramMap:
+ return self.paramMap[self.numFeatures]
+ else:
+ return self.numFeatures.defaultValue
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
new file mode 100644
index 0000000000000..2d239f8c802a0
--- /dev/null
+++ b/python/pyspark/ml/pipeline.py
@@ -0,0 +1,154 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABCMeta, abstractmethod
+
+from pyspark.ml.param import Param, Params
+from pyspark.ml.util import inherit_doc
+
+
+__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel']
+
+
+@inherit_doc
+class Estimator(Params):
+ """
+ Abstract class for estimators that fit models to data.
+ """
+
+ __metaclass__ = ABCMeta
+
+ @abstractmethod
+ def fit(self, dataset, params={}):
+ """
+ Fits a model to the input dataset with optional parameters.
+
+ :param dataset: input dataset, which is an instance of
+ :py:class:`pyspark.sql.SchemaRDD`
+ :param params: an optional param map that overwrites embedded
+ params
+ :returns: fitted model
+ """
+ raise NotImplementedError()
+
+
+@inherit_doc
+class Transformer(Params):
+ """
+ Abstract class for transformers that transform one dataset into
+ another.
+ """
+
+ __metaclass__ = ABCMeta
+
+ @abstractmethod
+ def transform(self, dataset, params={}):
+ """
+ Transforms the input dataset with optional parameters.
+
+ :param dataset: input dataset, which is an instance of
+ :py:class:`pyspark.sql.SchemaRDD`
+ :param params: an optional param map that overwrites embedded
+ params
+ :returns: transformed dataset
+ """
+ raise NotImplementedError()
+
+
+@inherit_doc
+class Pipeline(Estimator):
+ """
+ A simple pipeline, which acts as an estimator. A Pipeline consists
+ of a sequence of stages, each of which is either an
+ :py:class:`Estimator` or a :py:class:`Transformer`. When
+ :py:meth:`Pipeline.fit` is called, the stages are executed in
+ order. If a stage is an :py:class:`Estimator`, its
+ :py:meth:`Estimator.fit` method will be called on the input
+ dataset to fit a model. Then the model, which is a transformer,
+ will be used to transform the dataset as the input to the next
+ stage. If a stage is a :py:class:`Transformer`, its
+ :py:meth:`Transformer.transform` method will be called to produce
+ the dataset for the next stage. The fitted model from a
+    :py:class:`Pipeline` is a :py:class:`PipelineModel`, which
+ consists of fitted models and transformers, corresponding to the
+ pipeline stages. If there are no stages, the pipeline acts as an
+ identity transformer.
+ """
+
+ def __init__(self):
+ super(Pipeline, self).__init__()
+ #: Param for pipeline stages.
+ self.stages = Param(self, "stages", "pipeline stages")
+
+ def setStages(self, value):
+ """
+ Set pipeline stages.
+ :param value: a list of transformers or estimators
+ :return: the pipeline instance
+ """
+ self.paramMap[self.stages] = value
+ return self
+
+ def getStages(self):
+ """
+ Get pipeline stages.
+ """
+ if self.stages in self.paramMap:
+ return self.paramMap[self.stages]
+
+ def fit(self, dataset, params={}):
+ paramMap = self._merge_params(params)
+ stages = paramMap[self.stages]
+ for stage in stages:
+ if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)):
+ raise ValueError(
+ "Cannot recognize a pipeline stage of type %s." % type(stage).__name__)
+ indexOfLastEstimator = -1
+ for i, stage in enumerate(stages):
+ if isinstance(stage, Estimator):
+ indexOfLastEstimator = i
+ transformers = []
+ for i, stage in enumerate(stages):
+ if i <= indexOfLastEstimator:
+ if isinstance(stage, Transformer):
+ transformers.append(stage)
+ dataset = stage.transform(dataset, paramMap)
+ else: # must be an Estimator
+ model = stage.fit(dataset, paramMap)
+ transformers.append(model)
+ if i < indexOfLastEstimator:
+ dataset = model.transform(dataset, paramMap)
+ else:
+ transformers.append(stage)
+ return PipelineModel(transformers)
+
+
+@inherit_doc
+class PipelineModel(Transformer):
+ """
+ Represents a compiled pipeline with transformers and fitted models.
+ """
+
+ def __init__(self, transformers):
+ super(PipelineModel, self).__init__()
+ self.transformers = transformers
+
+ def transform(self, dataset, params={}):
+ paramMap = self._merge_params(params)
+ for t in self.transformers:
+ dataset = t.transform(dataset, paramMap)
+ return dataset
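
A hedged end-to-end sketch tying the new pieces together: Tokenizer -> HashingTF -> LogisticRegression chained by a Pipeline, using the default column names established above (the training data and app name are illustrative):

    from pyspark import SparkContext
    from pyspark.sql import SQLContext, Row
    from pyspark.ml import Pipeline
    from pyspark.ml.feature import Tokenizer, HashingTF
    from pyspark.ml.classification import LogisticRegression

    sc = SparkContext("local[2]", "pipeline sketch")
    sqlCtx = SQLContext(sc)
    training = sqlCtx.inferSchema(sc.parallelize([
        Row(text="spark is great", label=1.0),
        Row(text="hadoop map reduce", label=0.0)]))

    tokenizer = Tokenizer().setInputCol("text").setOutputCol("words")
    hashingTF = HashingTF().setInputCol("words").setOutputCol("features")
    lr = LogisticRegression().setMaxIter(10).setRegParam(0.01)
    pipeline = Pipeline().setStages([tokenizer, hashingTF, lr])

    model = pipeline.fit(training)                 # fits the stages in order
    test = sqlCtx.inferSchema(sc.parallelize([Row(text="spark map reduce")]))
    print model.transform(test).head().prediction  # runs all fitted stages
    sc.stop()
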
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
new file mode 100644
index 0000000000000..b627c2b4e930b
--- /dev/null
+++ b/python/pyspark/ml/tests.py
@@ -0,0 +1,115 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Unit tests for Spark ML Python APIs.
+"""
+
+import sys
+
+if sys.version_info[:2] <= (2, 6):
+ try:
+ import unittest2 as unittest
+ except ImportError:
+ sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
+ sys.exit(1)
+else:
+ import unittest
+
+from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
+from pyspark.sql import DataFrame
+from pyspark.ml.param import Param
+from pyspark.ml.pipeline import Transformer, Estimator, Pipeline
+
+
+class MockDataset(DataFrame):
+
+ def __init__(self):
+ self.index = 0
+
+
+class MockTransformer(Transformer):
+
+ def __init__(self):
+ super(MockTransformer, self).__init__()
+ self.fake = Param(self, "fake", "fake", None)
+ self.dataset_index = None
+ self.fake_param_value = None
+
+ def transform(self, dataset, params={}):
+ self.dataset_index = dataset.index
+ if self.fake in params:
+ self.fake_param_value = params[self.fake]
+ dataset.index += 1
+ return dataset
+
+
+class MockEstimator(Estimator):
+
+ def __init__(self):
+ super(MockEstimator, self).__init__()
+ self.fake = Param(self, "fake", "fake", None)
+ self.dataset_index = None
+ self.fake_param_value = None
+ self.model = None
+
+ def fit(self, dataset, params={}):
+ self.dataset_index = dataset.index
+ if self.fake in params:
+ self.fake_param_value = params[self.fake]
+ model = MockModel()
+ self.model = model
+ return model
+
+
+class MockModel(MockTransformer, Transformer):
+
+ def __init__(self):
+ super(MockModel, self).__init__()
+
+
+class PipelineTests(PySparkTestCase):
+
+ def test_pipeline(self):
+ dataset = MockDataset()
+ estimator0 = MockEstimator()
+ transformer1 = MockTransformer()
+ estimator2 = MockEstimator()
+ transformer3 = MockTransformer()
+ pipeline = Pipeline() \
+ .setStages([estimator0, transformer1, estimator2, transformer3])
+ pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1})
+ self.assertEqual(0, estimator0.dataset_index)
+ self.assertEqual(0, estimator0.fake_param_value)
+ model0 = estimator0.model
+ self.assertEqual(0, model0.dataset_index)
+ self.assertEqual(1, transformer1.dataset_index)
+ self.assertEqual(1, transformer1.fake_param_value)
+ self.assertEqual(2, estimator2.dataset_index)
+ model2 = estimator2.model
+ self.assertIsNone(model2.dataset_index, "The model produced by the last estimator should "
+ "not be called during fit.")
+ dataset = pipeline_model.transform(dataset)
+ self.assertEqual(2, model0.dataset_index)
+ self.assertEqual(3, transformer1.dataset_index)
+ self.assertEqual(4, model2.dataset_index)
+ self.assertEqual(5, transformer3.dataset_index)
+ self.assertEqual(6, dataset.index)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
new file mode 100644
index 0000000000000..b1caa84b6306a
--- /dev/null
+++ b/python/pyspark/ml/util.py
@@ -0,0 +1,46 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import uuid
+
+
+def inherit_doc(cls):
+ for name, func in vars(cls).items():
+ # only inherit docstring for public functions
+ if name.startswith("_"):
+ continue
+ if not func.__doc__:
+ for parent in cls.__bases__:
+ parent_func = getattr(parent, name, None)
+ if parent_func and getattr(parent_func, "__doc__", None):
+ func.__doc__ = parent_func.__doc__
+ break
+ return cls
+
+
+class Identifiable(object):
+ """
+ Object with a unique ID.
+ """
+
+ def __init__(self):
+ #: A unique id for the object. The default implementation
+ #: concatenates the class name, "-", and 8 random hex chars.
+ self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8]
+
+ def __repr__(self):
+ return self.uid
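
A small illustration of inherit_doc above: a public method without a docstring picks up its parent's docstring when the class is decorated (Base and Child are illustrative names):

    from pyspark.ml.util import inherit_doc

    class Base(object):
        def transform(self, dataset):
            """Transforms the input dataset."""
            return dataset

    @inherit_doc
    class Child(Base):
        def transform(self, dataset):      # no docstring of its own
            return dataset

    print Child.transform.__doc__          # Transforms the input dataset.
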
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
new file mode 100644
index 0000000000000..9e12ddc3d9b8f
--- /dev/null
+++ b/python/pyspark/ml/wrapper.py
@@ -0,0 +1,149 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABCMeta
+
+from pyspark import SparkContext
+from pyspark.sql import DataFrame
+from pyspark.ml.param import Params
+from pyspark.ml.pipeline import Estimator, Transformer
+from pyspark.ml.util import inherit_doc
+
+
+def _jvm():
+ """
+ Returns the JVM view associated with SparkContext. Must be called
+ after SparkContext is initialized.
+ """
+ jvm = SparkContext._jvm
+ if jvm:
+ return jvm
+ else:
+ raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
+
+
+@inherit_doc
+class JavaWrapper(Params):
+ """
+ Utility class to help create wrapper classes from Java/Scala
+ implementations of pipeline components.
+ """
+
+ __metaclass__ = ABCMeta
+
+ #: Fully-qualified class name of the wrapped Java component.
+ _java_class = None
+
+ def _java_obj(self):
+ """
+ Returns or creates a Java object.
+ """
+ java_obj = _jvm()
+ for name in self._java_class.split("."):
+ java_obj = getattr(java_obj, name)
+ return java_obj()
+
+ def _transfer_params_to_java(self, params, java_obj):
+ """
+ Transforms the embedded params and additional params to the
+ input Java object.
+ :param params: additional params (overwriting embedded values)
+ :param java_obj: Java object to receive the params
+ """
+ paramMap = self._merge_params(params)
+ for param in self.params:
+ if param in paramMap:
+ java_obj.set(param.name, paramMap[param])
+
+ def _empty_java_param_map(self):
+ """
+ Returns an empty Java ParamMap reference.
+ """
+ return _jvm().org.apache.spark.ml.param.ParamMap()
+
+ def _create_java_param_map(self, params, java_obj):
+ paramMap = self._empty_java_param_map()
+ for param, value in params.items():
+ if param.parent is self:
+ paramMap.put(java_obj.getParam(param.name), value)
+ return paramMap
+
+
+@inherit_doc
+class JavaEstimator(Estimator, JavaWrapper):
+ """
+ Base class for :py:class:`Estimator`s that wrap Java/Scala
+ implementations.
+ """
+
+ __metaclass__ = ABCMeta
+
+ def _create_model(self, java_model):
+ """
+ Creates a model from the input Java model reference.
+ """
+ return JavaModel(java_model)
+
+ def _fit_java(self, dataset, params={}):
+ """
+ Fits a Java model to the input dataset.
+ :param dataset: input dataset, which is an instance of
+ :py:class:`pyspark.sql.SchemaRDD`
+ :param params: additional params (overwriting embedded values)
+ :return: fitted Java model
+ """
+ java_obj = self._java_obj()
+ self._transfer_params_to_java(params, java_obj)
+ return java_obj.fit(dataset._jdf, self._empty_java_param_map())
+
+ def fit(self, dataset, params={}):
+ java_model = self._fit_java(dataset, params)
+ return self._create_model(java_model)
+
+
+@inherit_doc
+class JavaTransformer(Transformer, JavaWrapper):
+ """
+ Base class for :py:class:`Transformer`s that wrap Java/Scala
+ implementations.
+ """
+
+ __metaclass__ = ABCMeta
+
+ def transform(self, dataset, params={}):
+ java_obj = self._java_obj()
+ self._transfer_params_to_java({}, java_obj)
+ java_param_map = self._create_java_param_map(params, java_obj)
+ return DataFrame(java_obj.transform(dataset._jdf, java_param_map),
+ dataset.sql_ctx)
+
+
+@inherit_doc
+class JavaModel(JavaTransformer):
+ """
+ Base class for :py:class:`Model`s that wrap Java/Scala
+ implementations.
+ """
+
+ __metaclass__ = ABCMeta
+
+ def __init__(self, java_model):
+ super(JavaTransformer, self).__init__()
+ self._java_model = java_model
+
+ def _java_obj(self):
+ return self._java_model
diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py
index 4149f54931d1f..c3217620e3c4e 100644
--- a/python/pyspark/mllib/__init__.py
+++ b/python/pyspark/mllib/__init__.py
@@ -24,3 +24,12 @@
import numpy
if numpy.version.version < '1.4':
raise Exception("MLlib requires NumPy 1.4+")
+
+__all__ = ['classification', 'clustering', 'feature', 'linalg', 'random',
+ 'recommendation', 'regression', 'stat', 'tree', 'util']
+
+import sys
+import rand as random
+random.__name__ = 'random'
+random.RandomRDDs.__module__ = __name__ + '.random'
+sys.modules[__name__ + '.random'] = random
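
The rand/random aliasing above registers the module under its documented dotted name, so `from pyspark.mllib.random import RandomRDDs` resolves without the file itself shadowing the stdlib random module. A standalone sketch of the same sys.modules trick with toy module names (toypkg is illustrative):

    import sys
    import types

    pkg = types.ModuleType("toypkg")
    impl = types.ModuleType("toypkg.rand")      # the "real" implementation module
    impl.answer = 42
    sys.modules["toypkg"] = pkg
    sys.modules["toypkg.random"] = impl         # alias under the nicer name

    from toypkg.random import answer
    print answer                                # 42
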
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index e295c9d0954d9..00e2e76711e84 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -20,99 +20,200 @@
import numpy
from numpy import array
-from pyspark import SparkContext, PickleSerializer
-from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd
+from pyspark import RDD
+from pyspark.mllib.common import callMLlibFunc
+from pyspark.mllib.linalg import SparseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
-__all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'SVMModel',
- 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes']
+__all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegressionWithLBFGS',
+ 'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes']
-class LogisticRegressionModel(LinearModel):
+class LinearBinaryClassificationModel(LinearModel):
+ """
+ Represents a linear binary classification model that predicts whether an
+ example is positive (1.0) or negative (0.0).
+ """
+ def __init__(self, weights, intercept):
+ super(LinearBinaryClassificationModel, self).__init__(weights, intercept)
+ self._threshold = None
+
+ def setThreshold(self, value):
+ """
+ .. note:: Experimental
+
+ Sets the threshold that separates positive predictions from negative
+ predictions. An example with prediction score greater than or equal
+ to this threshold is identified as positive, and negative otherwise.
+ """
+ self._threshold = value
+
+ def clearThreshold(self):
+ """
+ .. note:: Experimental
+
+ Clears the threshold so that `predict` will output raw prediction scores.
+ """
+ self._threshold = None
+
+ def predict(self, test):
+ """
+ Predict values for a single data point or an RDD of points using
+ the model trained.
+ """
+ raise NotImplementedError
+
+
+class LogisticRegressionModel(LinearBinaryClassificationModel):
"""A linear binary classification model derived from logistic regression.
>>> data = [
- ... LabeledPoint(0.0, [0.0]),
- ... LabeledPoint(1.0, [1.0]),
- ... LabeledPoint(1.0, [2.0]),
- ... LabeledPoint(1.0, [3.0])
+ ... LabeledPoint(0.0, [0.0, 1.0]),
+ ... LabeledPoint(1.0, [1.0, 0.0]),
... ]
>>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(data))
- >>> lrm.predict(array([1.0])) > 0
- True
- >>> lrm.predict(array([0.0])) <= 0
- True
+ >>> lrm.predict([1.0, 0.0])
+ 1
+ >>> lrm.predict([0.0, 1.0])
+ 0
+ >>> lrm.predict(sc.parallelize([[1.0, 0.0], [0.0, 1.0]])).collect()
+ [1, 0]
+ >>> lrm.clearThreshold()
+ >>> lrm.predict([0.0, 1.0])
+ 0.123...
+
>>> sparse_data = [
... LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
... LabeledPoint(1.0, SparseVector(2, {1: 1.0})),
- ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
+ ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})),
... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
... ]
>>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(sparse_data))
- >>> lrm.predict(array([0.0, 1.0])) > 0
- True
- >>> lrm.predict(array([0.0, 0.0])) <= 0
- True
- >>> lrm.predict(SparseVector(2, {1: 1.0})) > 0
- True
- >>> lrm.predict(SparseVector(2, {1: 0.0})) <= 0
- True
+ >>> lrm.predict(array([0.0, 1.0]))
+ 1
+ >>> lrm.predict(array([1.0, 0.0]))
+ 0
+ >>> lrm.predict(SparseVector(2, {1: 1.0}))
+ 1
+ >>> lrm.predict(SparseVector(2, {0: 1.0}))
+ 0
"""
+ def __init__(self, weights, intercept):
+ super(LogisticRegressionModel, self).__init__(weights, intercept)
+ self._threshold = 0.5
def predict(self, x):
+ """
+ Predict values for a single data point or an RDD of points using
+ the model trained.
+ """
+ if isinstance(x, RDD):
+ return x.map(lambda v: self.predict(v))
+
+ x = _convert_to_vector(x)
margin = self.weights.dot(x) + self._intercept
if margin > 0:
prob = 1 / (1 + exp(-margin))
else:
exp_margin = exp(margin)
prob = exp_margin / (1 + exp_margin)
- return 1 if prob > 0.5 else 0
+ if self._threshold is None:
+ return prob
+ else:
+ return 1 if prob > self._threshold else 0
class LogisticRegressionWithSGD(object):
@classmethod
def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
- initialWeights=None, regParam=1.0, regType="none", intercept=False):
+ initialWeights=None, regParam=0.01, regType="l2", intercept=False):
"""
Train a logistic regression model on the given data.
- :param data: The training data.
+ :param data: The training data, an RDD of LabeledPoint.
:param iterations: The number of iterations (default: 100).
:param step: The step parameter used in SGD
(default: 1.0).
:param miniBatchFraction: Fraction of data to be used for each SGD
iteration.
:param initialWeights: The initial weights (default: None).
- :param regParam: The regularizer parameter (default: 1.0).
+ :param regParam: The regularizer parameter (default: 0.01).
:param regType: The type of regularizer used for training
our model.
:Allowed values:
- - "l1" for using L1Updater
- - "l2" for using SquaredL2Updater
- - "none" for no regularizer
+ - "l1" for using L1 regularization
+ - "l2" for using L2 regularization
+ - None for no regularization
- (default: "none")
+ (default: "l2")
- @param intercept: Boolean parameter which indicates the use
+ :param intercept: Boolean parameter which indicates the use
or not of the augmented representation for
training data (i.e. whether bias features
are activated or not).
"""
- sc = data.context
+ def train(rdd, i):
+ return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations),
+ float(step), float(miniBatchFraction), i, float(regParam), regType,
+ bool(intercept))
- def train(jdata, i):
- return sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD(
- jdata, iterations, step, miniBatchFraction, i, regParam, regType, intercept)
+ return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights)
- return _regression_train_wrapper(sc, train, LogisticRegressionModel, data,
- initialWeights)
+class LogisticRegressionWithLBFGS(object):
-class SVMModel(LinearModel):
+ @classmethod
+ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2",
+ intercept=False, corrections=10, tolerance=1e-4):
+ """
+ Train a logistic regression model on the given data.
+
+ :param data: The training data, an RDD of LabeledPoint.
+ :param iterations: The number of iterations (default: 100).
+ :param initialWeights: The initial weights (default: None).
+ :param regParam: The regularizer parameter (default: 0.01).
+ :param regType: The type of regularizer used for training
+ our model.
+
+ :Allowed values:
+ - "l1" for using L1 regularization
+ - "l2" for using L2 regularization
+ - None for no regularization
+
+ (default: "l2")
+
+ :param intercept: Boolean parameter which indicates the use
+ or not of the augmented representation for
+ training data (i.e. whether bias features
+ are activated or not).
+ :param corrections: The number of corrections used in the LBFGS
+ update (default: 10).
+ :param tolerance: The convergence tolerance of iterations for
+ L-BFGS (default: 1e-4).
+
+ >>> data = [
+ ... LabeledPoint(0.0, [0.0, 1.0]),
+ ... LabeledPoint(1.0, [1.0, 0.0]),
+ ... ]
+ >>> lrm = LogisticRegressionWithLBFGS.train(sc.parallelize(data))
+ >>> lrm.predict([1.0, 0.0])
+ 1
+ >>> lrm.predict([0.0, 1.0])
+ 0
+ """
+ def train(rdd, i):
+ return callMLlibFunc("trainLogisticRegressionModelWithLBFGS", rdd, int(iterations), i,
+ float(regParam), str(regType), bool(intercept), int(corrections),
+ float(tolerance))
+
+ return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights)
+
+
+class SVMModel(LinearBinaryClassificationModel):
"""A support vector machine.
@@ -123,8 +224,14 @@ class SVMModel(LinearModel):
... LabeledPoint(1.0, [3.0])
... ]
>>> svm = SVMWithSGD.train(sc.parallelize(data))
- >>> svm.predict(array([1.0])) > 0
- True
+ >>> svm.predict([1.0])
+ 1
+ >>> svm.predict(sc.parallelize([[1.0]])).collect()
+ [1]
+ >>> svm.clearThreshold()
+ >>> svm.predict(array([1.0]))
+ 1.25...
+
>>> sparse_data = [
... LabeledPoint(0.0, SparseVector(2, {0: -1.0})),
... LabeledPoint(1.0, SparseVector(2, {1: 1.0})),
@@ -132,30 +239,44 @@ class SVMModel(LinearModel):
... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
... ]
>>> svm = SVMWithSGD.train(sc.parallelize(sparse_data))
- >>> svm.predict(SparseVector(2, {1: 1.0})) > 0
- True
- >>> svm.predict(SparseVector(2, {0: -1.0})) <= 0
- True
+ >>> svm.predict(SparseVector(2, {1: 1.0}))
+ 1
+ >>> svm.predict(SparseVector(2, {0: -1.0}))
+ 0
"""
+ def __init__(self, weights, intercept):
+ super(SVMModel, self).__init__(weights, intercept)
+ self._threshold = 0.0
def predict(self, x):
+ """
+ Predict values for a single data point or an RDD of points using
+ the model trained.
+ """
+ if isinstance(x, RDD):
+ return x.map(lambda v: self.predict(v))
+
+ x = _convert_to_vector(x)
margin = self.weights.dot(x) + self.intercept
- return 1 if margin >= 0 else 0
+ if self._threshold is None:
+ return margin
+ else:
+ return 1 if margin > self._threshold else 0
class SVMWithSGD(object):
@classmethod
- def train(cls, data, iterations=100, step=1.0, regParam=1.0,
- miniBatchFraction=1.0, initialWeights=None, regType="none", intercept=False):
+ def train(cls, data, iterations=100, step=1.0, regParam=0.01,
+ miniBatchFraction=1.0, initialWeights=None, regType="l2", intercept=False):
"""
Train a support vector machine on the given data.
- :param data: The training data.
+ :param data: The training data, an RDD of LabeledPoint.
:param iterations: The number of iterations (default: 100).
:param step: The step parameter used in SGD
(default: 1.0).
- :param regParam: The regularizer parameter (default: 1.0).
+ :param regParam: The regularizer parameter (default: 0.01).
:param miniBatchFraction: Fraction of data to be used for each SGD
iteration.
:param initialWeights: The initial weights (default: None).
@@ -163,24 +284,23 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0,
our model.
:Allowed values:
- - "l1" for using L1Updater
- - "l2" for using SquaredL2Updater,
- - "none" for no regularizer.
+ - "l1" for using L1 regularization
+ - "l2" for using L2 regularization
+ - None for no regularization
- (default: "none")
+ (default: "l2")
- @param intercept: Boolean parameter which indicates the use
+ :param intercept: Boolean parameter which indicates the use
or not of the augmented representation for
training data (i.e. whether bias features
are activated or not).
"""
- sc = data.context
-
- def train(jrdd, i):
- return sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD(
- jrdd, iterations, step, regParam, miniBatchFraction, i, regType, intercept)
+ def train(rdd, i):
+ return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step),
+ float(regParam), float(miniBatchFraction), i, regType,
+ bool(intercept))
- return _regression_train_wrapper(sc, train, SVMModel, data, initialWeights)
+ return _regression_train_wrapper(train, SVMModel, data, initialWeights)
class NaiveBayesModel(object):
@@ -202,6 +322,8 @@ class NaiveBayesModel(object):
0.0
>>> model.predict(array([1.0, 0.0]))
1.0
+ >>> model.predict(sc.parallelize([[1.0, 0.0]])).collect()
+ [1.0]
>>> sparse_data = [
... LabeledPoint(0.0, SparseVector(2, {1: 0.0})),
... LabeledPoint(0.0, SparseVector(2, {1: 1.0})),
@@ -220,7 +342,9 @@ def __init__(self, labels, pi, theta):
self.theta = theta
def predict(self, x):
- """Return the most likely class for a data vector x"""
+ """Return the most likely class for a data vector or an RDD of vectors"""
+ if isinstance(x, RDD):
+ return x.map(lambda v: self.predict(v))
x = _convert_to_vector(x)
return self.labels[numpy.argmax(self.pi + x.dot(self.theta.transpose()))]
@@ -238,20 +362,21 @@ def train(cls, data, lambda_=1.0):
classification. By making every vector a 0-1 vector, it can also be
used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}).
- :param data: RDD of NumPy vectors, one per element, where the first
- coordinate is the label and the rest is the feature vector
- (e.g. a count vector).
+ :param data: RDD of LabeledPoint.
:param lambda_: The smoothing parameter
"""
- sc = data.context
- jlist = sc._jvm.PythonMLLibAPI().trainNaiveBayes(_to_java_object_rdd(data), lambda_)
- labels, pi, theta = PickleSerializer().loads(str(sc._jvm.SerDe.dumps(jlist)))
+ first = data.first()
+ if not isinstance(first, LabeledPoint):
+ raise ValueError("`data` should be an RDD of LabeledPoint")
+ labels, pi, theta = callMLlibFunc("trainNaiveBayes", data, lambda_)
return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta))
def _test():
import doctest
- globs = globals().copy()
+ from pyspark import SparkContext
+ import pyspark.mllib.classification
+ globs = pyspark.mllib.classification.__dict__.copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
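A short usage sketch of the new prediction behaviour introduced in this file (the toy data and the `sc` SparkContext are assumed; return values are indicative, not exact):

    from pyspark.mllib.classification import LogisticRegressionWithLBFGS
    from pyspark.mllib.regression import LabeledPoint

    data = sc.parallelize([LabeledPoint(0.0, [0.0, 1.0]),
                           LabeledPoint(1.0, [1.0, 0.0])])
    model = LogisticRegressionWithLBFGS.train(data, iterations=10)

    model.predict([1.0, 0.0])                        # 0/1 label (default threshold 0.5)
    model.predict(data.map(lambda lp: lp.features))  # predict() now accepts an RDD

    model.clearThreshold()
    model.predict([1.0, 0.0])                        # raw probability instead of a label
    model.setThreshold(0.3)                          # classify again with a custom cut-off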
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index 5ee7997104d21..f6b97abb1723c 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -15,19 +15,22 @@
# limitations under the License.
#
+from numpy import array
+
+from pyspark import RDD
from pyspark import SparkContext
-from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
-from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd
+from pyspark.mllib.common import callMLlibFunc, callJavaFunc
+from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector
+from pyspark.mllib.stat.distribution import MultivariateGaussian
-__all__ = ['KMeansModel', 'KMeans']
+__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture']
class KMeansModel(object):
"""A clustering model derived from the k-means method.
- >>> from numpy import array
- >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2)
+ >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4, 2)
>>> model = KMeans.train(
... sc.parallelize(data), 2, maxIterations=10, runs=30, initializationMode="random")
>>> model.predict(array([0.0, 0.0])) == model.predict(array([1.0, 1.0]))
@@ -78,19 +81,95 @@ def predict(self, x):
class KMeans(object):
@classmethod
- def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"):
+ def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||", seed=None):
"""Train a k-means clustering model."""
- sc = rdd.context
- ser = PickleSerializer()
- # cache serialized data to avoid objects over head in JVM
- cached = rdd.map(_convert_to_vector)._reserialize(AutoBatchedSerializer(ser)).cache()
- model = sc._jvm.PythonMLLibAPI().trainKMeansModel(
- _to_java_object_rdd(cached), k, maxIterations, runs, initializationMode)
- bytes = sc._jvm.SerDe.dumps(model.clusterCenters())
- centers = ser.loads(str(bytes))
+ model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations,
+ runs, initializationMode, seed)
+ centers = callJavaFunc(rdd.context, model.clusterCenters)
return KMeansModel([c.toArray() for c in centers])
+class GaussianMixtureModel(object):
+
+ """A clustering model derived from the Gaussian Mixture Model method.
+
+ >>> clusterdata_1 = sc.parallelize(array([-0.1,-0.05,-0.01,-0.1,
+ ... 0.9,0.8,0.75,0.935,
+ ... -0.83,-0.68,-0.91,-0.76 ]).reshape(6, 2))
+ >>> model = GaussianMixture.train(clusterdata_1, 3, convergenceTol=0.0001,
+ ... maxIterations=50, seed=10)
+ >>> labels = model.predict(clusterdata_1).collect()
+ >>> labels[0]==labels[1]
+ False
+ >>> labels[1]==labels[2]
+ True
+ >>> labels[4]==labels[5]
+ True
+ >>> clusterdata_2 = sc.parallelize(array([-5.1971, -2.5359, -3.8220,
+ ... -5.2211, -5.0602, 4.7118,
+ ... 6.8989, 3.4592, 4.6322,
+ ... 5.7048, 4.6567, 5.5026,
+ ... 4.5605, 5.2043, 6.2734]).reshape(5, 3))
+ >>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001,
+ ... maxIterations=150, seed=10)
+ >>> labels = model.predict(clusterdata_2).collect()
+ >>> labels[0]==labels[1]==labels[2]
+ True
+ >>> labels[3]==labels[4]
+ True
+ """
+
+ def __init__(self, weights, gaussians):
+ self.weights = weights
+ self.gaussians = gaussians
+ self.k = len(self.weights)
+
+ def predict(self, x):
+ """
+ Find the cluster to which the points in 'x' have maximum membership
+ in this model.
+
+ :param x: RDD of data points.
+ :return: cluster_labels. RDD of cluster labels.
+ """
+ if isinstance(x, RDD):
+ cluster_labels = self.predictSoft(x).map(lambda z: z.index(max(z)))
+ return cluster_labels
+
+ def predictSoft(self, x):
+ """
+ Find the membership of each point in 'x' to all mixture components.
+
+ :param x: RDD of data points.
+ :return: membership_matrix. RDD of array of double values.
+ """
+ if isinstance(x, RDD):
+ means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians])
+ membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector),
+ self.weights, means, sigmas)
+ return membership_matrix
+
+
+class GaussianMixture(object):
+ """
+ Estimate model parameters with the expectation-maximization algorithm.
+
+ :param data: RDD of data points
+ :param k: Number of components
+ :param convergenceTol: Threshold value to check the convergence criteria. Defaults to 1e-3
+ :param maxIterations: Number of iterations. Defaults to 100
+ :param seed: Random seed
+ """
+ @classmethod
+ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None):
+ """Train a Gaussian Mixture clustering model."""
+ weight, mu, sigma = callMLlibFunc("trainGaussianMixture",
+ rdd.map(_convert_to_vector), k,
+ convergenceTol, maxIterations, seed)
+ mvg_obj = [MultivariateGaussian(mu[i], sigma[i]) for i in range(k)]
+ return GaussianMixtureModel(weight, mvg_obj)
+
+
def _test():
import doctest
globs = globals().copy()
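To complement the doctests above, a compact sketch of the new GaussianMixture API (the input data and `sc` are assumed):

    from numpy import array
    from pyspark.mllib.clustering import GaussianMixture

    points = sc.parallelize(array([-0.1, -0.05, 0.9, 0.8]).reshape(2, 2))
    gmm = GaussianMixture.train(points, 2, convergenceTol=1e-3, seed=10)

    gmm.k                                     # number of mixture components
    gmm.weights                               # mixing weights
    labels = gmm.predict(points).collect()    # hard cluster assignments
    soft = gmm.predictSoft(points).collect()  # per-component membership values
    mu0 = gmm.gaussians[0].mu                 # mean of the first component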
diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py
new file mode 100644
index 0000000000000..3c5ee66cd8b64
--- /dev/null
+++ b/python/pyspark/mllib/common.py
@@ -0,0 +1,136 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import py4j.protocol
+from py4j.protocol import Py4JJavaError
+from py4j.java_gateway import JavaObject
+from py4j.java_collections import ListConverter, JavaArray, JavaList
+
+from pyspark import RDD, SparkContext
+from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+
+
+# Hack for support float('inf') in Py4j
+_old_smart_decode = py4j.protocol.smart_decode
+
+_float_str_mapping = {
+ 'nan': 'NaN',
+ 'inf': 'Infinity',
+ '-inf': '-Infinity',
+}
+
+
+def _new_smart_decode(obj):
+ if isinstance(obj, float):
+ s = unicode(obj)
+ return _float_str_mapping.get(s, s)
+ return _old_smart_decode(obj)
+
+py4j.protocol.smart_decode = _new_smart_decode
+
+
+_picklable_classes = [
+ 'LinkedList',
+ 'SparseVector',
+ 'DenseVector',
+ 'DenseMatrix',
+ 'Rating',
+ 'LabeledPoint',
+]
+
+
+# this will call the MLlib version of pythonToJava()
+def _to_java_object_rdd(rdd):
+ """ Return an JavaRDD of Object by unpickling
+
+ It will convert each Python object into Java object by Pyrolite, whenever the
+ RDD is serialized in batch or not.
+ """
+ rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
+ return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True)
+
+
+def _py2java(sc, obj):
+ """ Convert Python object into Java """
+ if isinstance(obj, RDD):
+ obj = _to_java_object_rdd(obj)
+ elif isinstance(obj, SparkContext):
+ obj = obj._jsc
+ elif isinstance(obj, list) and (obj or isinstance(obj[0], JavaObject)):
+ obj = ListConverter().convert(obj, sc._gateway._gateway_client)
+ elif isinstance(obj, JavaObject):
+ pass
+ elif isinstance(obj, (int, long, float, bool, basestring)):
+ pass
+ else:
+ bytes = bytearray(PickleSerializer().dumps(obj))
+ obj = sc._jvm.SerDe.loads(bytes)
+ return obj
+
+
+def _java2py(sc, r):
+ if isinstance(r, JavaObject):
+ clsName = r.getClass().getSimpleName()
+ # convert RDD into JavaRDD
+ if clsName != 'JavaRDD' and clsName.endswith("RDD"):
+ r = r.toJavaRDD()
+ clsName = 'JavaRDD'
+
+ if clsName == 'JavaRDD':
+ jrdd = sc._jvm.SerDe.javaToPython(r)
+ return RDD(jrdd, sc)
+
+ if clsName in _picklable_classes:
+ r = sc._jvm.SerDe.dumps(r)
+ elif isinstance(r, (JavaArray, JavaList)):
+ try:
+ r = sc._jvm.SerDe.dumps(r)
+ except Py4JJavaError:
+ pass  # not picklable
+
+ if isinstance(r, bytearray):
+ r = PickleSerializer().loads(str(r))
+ return r
+
+
+def callJavaFunc(sc, func, *args):
+ """ Call Java Function """
+ args = [_py2java(sc, a) for a in args]
+ return _java2py(sc, func(*args))
+
+
+def callMLlibFunc(name, *args):
+ """ Call API in PythonMLLibAPI """
+ sc = SparkContext._active_spark_context
+ api = getattr(sc._jvm.PythonMLLibAPI(), name)
+ return callJavaFunc(sc, api, *args)
+
+
+class JavaModelWrapper(object):
+ """
+ Wrapper for the model in JVM
+ """
+ def __init__(self, java_model):
+ self._sc = SparkContext._active_spark_context
+ self._java_model = java_model
+
+ def __del__(self):
+ self._sc._gateway.detach(self._java_model)
+
+ def call(self, name, *a):
+ """Call method of java_model"""
+ return callJavaFunc(self._sc, getattr(self._java_model, name), *a)
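A minimal sketch of how the helpers in this new module are meant to be used from the other MLlib modules touched by this patch; `fitStandardScaler` is one of the PythonMLLibAPI entry points called in feature.py, and `dataset` is assumed to be an RDD of vectors:

    from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper

    # callMLlibFunc converts every Python argument with _py2java (RDDs are
    # re-serialized and handed to the JVM as JavaRDD[Object]), invokes the named
    # method on a PythonMLLibAPI instance, and converts the result back
    # with _java2py.
    jmodel = callMLlibFunc("fitStandardScaler", True, True, dataset)

    # JavaModelWrapper holds a reference to a JVM model and routes method
    # calls through the same conversion machinery.
    class MyModel(JavaModelWrapper):
        def transform(self, vector):
            return self.call("transform", vector)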
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index b5a3f22c6907e..10df6288065b8 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -18,59 +18,302 @@
"""
Python package for feature in MLlib.
"""
-from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
-from pyspark.mllib.linalg import _convert_to_vector, _to_java_object_rdd
+from __future__ import absolute_import
-__all__ = ['Word2Vec', 'Word2VecModel']
+import sys
+import warnings
+import random
+from py4j.protocol import Py4JJavaError
-class Word2VecModel(object):
+from pyspark import RDD, SparkContext
+from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
+from pyspark.mllib.linalg import Vectors, Vector, _convert_to_vector
+
+__all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
+ 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel']
+
+
+class VectorTransformer(object):
"""
- class for Word2Vec model
+ .. note:: DeveloperApi
+
+ Base class for transformations of a vector or an RDD of vectors
"""
- def __init__(self, sc, java_model):
+ def transform(self, vector):
"""
- :param sc: Spark context
- :param java_model: Handle to Java model object
+ Applies transformation on a vector.
+
+ :param vector: vector to be transformed.
"""
- self._sc = sc
- self._java_model = java_model
+ raise NotImplementedError
- def __del__(self):
- self._sc._gateway.detach(self._java_model)
- def transform(self, word):
+class Normalizer(VectorTransformer):
+ """
+ .. note:: Experimental
+
+ Normalizes samples individually to unit L\ :sup:`p`\ norm
+
+ For any 1 <= `p` < float('inf'), normalizes samples using
+ sum(abs(vector) :sup:`p`) :sup:`(1/p)` as norm.
+
+ For `p` = float('inf'), max(abs(vector)) will be used as norm for normalization.
+
+ >>> v = Vectors.dense(range(3))
+ >>> nor = Normalizer(1)
+ >>> nor.transform(v)
+ DenseVector([0.0, 0.3333, 0.6667])
+
+ >>> rdd = sc.parallelize([v])
+ >>> nor.transform(rdd).collect()
+ [DenseVector([0.0, 0.3333, 0.6667])]
+
+ >>> nor2 = Normalizer(float("inf"))
+ >>> nor2.transform(v)
+ DenseVector([0.0, 0.5, 1.0])
+ """
+ def __init__(self, p=2.0):
"""
- :param word: a word
- :return: vector representation of word
+ :param p: Normalization in L^p^ space, p = 2 by default.
+ """
+ assert p >= 1.0, "p should be greater than or equal to 1.0"
+ self.p = float(p)
+
+ def transform(self, vector):
+ """
+ Applies unit length normalization on a vector.
+
+ :param vector: vector or RDD of vector to be normalized.
+ :return: normalized vector. If the norm of the input is zero, it
+ will return the input vector.
+ """
+ sc = SparkContext._active_spark_context
+ assert sc is not None, "SparkContext should be initialized first"
+ if isinstance(vector, RDD):
+ vector = vector.map(_convert_to_vector)
+ else:
+ vector = _convert_to_vector(vector)
+ return callMLlibFunc("normalizeVector", self.p, vector)
+
+
+class JavaVectorTransformer(JavaModelWrapper, VectorTransformer):
+ """
+ Wrapper for the model in JVM
+ """
+
+ def transform(self, vector):
+ if isinstance(vector, RDD):
+ vector = vector.map(_convert_to_vector)
+ else:
+ vector = _convert_to_vector(vector)
+ return self.call("transform", vector)
+
+
+class StandardScalerModel(JavaVectorTransformer):
+ """
+ .. note:: Experimental
+
+ Represents a StandardScaler model that can transform vectors.
+ """
+ def transform(self, vector):
+ """
+ Applies standardization transformation on a vector.
+
+ :param vector: Vector or RDD of Vector to be standardized.
+ :return: Standardized vector. If the variance of a column is zero,
+ it will return `0.0` for that column.
+ """
+ return JavaVectorTransformer.transform(self, vector)
+
+
+class StandardScaler(object):
+ """
+ .. note:: Experimental
+
+ Standardizes features by removing the mean and scaling to unit
+ variance using column summary statistics on the samples in the
+ training set.
+
+ >>> vs = [Vectors.dense([-2.0, 2.3, 0]), Vectors.dense([3.8, 0.0, 1.9])]
+ >>> dataset = sc.parallelize(vs)
+ >>> standardizer = StandardScaler(True, True)
+ >>> model = standardizer.fit(dataset)
+ >>> result = model.transform(dataset)
+ >>> for r in result.collect(): r
+ DenseVector([-0.7071, 0.7071, -0.7071])
+ DenseVector([0.7071, -0.7071, 0.7071])
+ """
+ def __init__(self, withMean=False, withStd=True):
+ """
+ :param withMean: False by default. Centers the data with mean
+ before scaling. It will build a dense output, so this
+ does not work on sparse input and will raise an exception.
+ :param withStd: True by default. Scales the data to unit standard
+ deviation.
+ """
+ if not (withMean or withStd):
+ warnings.warn("Both withMean and withStd are false. The model does nothing.")
+ self.withMean = withMean
+ self.withStd = withStd
+
+ def fit(self, dataset):
+ """
+ Computes the mean and variance and stores as a model to be used for later scaling.
+
+ :param dataset: The data used to compute the mean and variance to build
+ the transformation model.
+ :return: a StandardScalerModel
+ """
+ dataset = dataset.map(_convert_to_vector)
+ jmodel = callMLlibFunc("fitStandardScaler", self.withMean, self.withStd, dataset)
+ return StandardScalerModel(jmodel)
+
+
+class HashingTF(object):
+ """
+ .. note:: Experimental
+ Maps a sequence of terms to their term frequencies using the hashing trick.
+
+ Note: the terms must be hashable (cannot be dict/set/list, ...).
+
+ >>> htf = HashingTF(100)
+ >>> doc = "a a b b c d".split(" ")
+ >>> htf.transform(doc)
+ SparseVector(100, {1: 1.0, 14: 1.0, 31: 2.0, 44: 2.0})
+ """
+ def __init__(self, numFeatures=1 << 20):
+ """
+ :param numFeatures: number of features (default: 2^20)
+ """
+ self.numFeatures = numFeatures
+
+ def indexOf(self, term):
+ """ Returns the index of the input term. """
+ return hash(term) % self.numFeatures
+
+ def transform(self, document):
+ """
+ Transforms the input document (list of terms) to a term frequency vector,
+ or transforms an RDD of documents to an RDD of term frequency vectors.
+ """
+ if isinstance(document, RDD):
+ return document.map(self.transform)
+
+ freq = {}
+ for term in document:
+ i = self.indexOf(term)
+ freq[i] = freq.get(i, 0) + 1.0
+ return Vectors.sparse(self.numFeatures, freq.items())
+
+
+class IDFModel(JavaVectorTransformer):
+ """
+ Represents an IDF model that can transform term frequency vectors.
+ """
+ def transform(self, x):
+ """
+ Transforms term frequency (TF) vectors to TF-IDF vectors.
+
+ If `minDocFreq` was set for the IDF calculation,
+ the terms which occur in fewer than `minDocFreq`
+ documents will have an entry of 0.
+
+ :param x: an RDD of term frequency vectors or a term frequency vector
+ :return: an RDD of TF-IDF vectors or a TF-IDF vector
+ """
+ if isinstance(x, RDD):
+ return JavaVectorTransformer.transform(self, x)
+
+ x = _convert_to_vector(x)
+ return JavaVectorTransformer.transform(self, x)
+
+
+class IDF(object):
+ """
+ .. note:: Experimental
+
+ Inverse document frequency (IDF).
+
+ The standard formulation is used: `idf = log((m + 1) / (d(t) + 1))`,
+ where `m` is the total number of documents and `d(t)` is the number
+ of documents that contain term `t`.
+
+ This implementation supports filtering out terms which do not appear
+ in a minimum number of documents (controlled by the variable `minDocFreq`).
+ For terms that are not in at least `minDocFreq` documents, the IDF is
+ set to 0, resulting in TF-IDF values of 0.
+
+ >>> n = 4
+ >>> freqs = [Vectors.sparse(n, (1, 3), (1.0, 2.0)),
+ ... Vectors.dense([0.0, 1.0, 2.0, 3.0]),
+ ... Vectors.sparse(n, [1], [1.0])]
+ >>> data = sc.parallelize(freqs)
+ >>> idf = IDF()
+ >>> model = idf.fit(data)
+ >>> tfidf = model.transform(data)
+ >>> for r in tfidf.collect(): r
+ SparseVector(4, {1: 0.0, 3: 0.5754})
+ DenseVector([0.0, 0.0, 1.3863, 0.863])
+ SparseVector(4, {1: 0.0})
+ >>> model.transform(Vectors.dense([0.0, 1.0, 2.0, 3.0]))
+ DenseVector([0.0, 0.0, 1.3863, 0.863])
+ >>> model.transform([0.0, 1.0, 2.0, 3.0])
+ DenseVector([0.0, 0.0, 1.3863, 0.863])
+ >>> model.transform(Vectors.sparse(n, (1, 3), (1.0, 2.0)))
+ SparseVector(4, {1: 0.0, 3: 0.5754})
+ """
+ def __init__(self, minDocFreq=0):
+ """
+ :param minDocFreq: minimum number of documents in which a term
+ must appear to receive a non-zero IDF
+ """
+ self.minDocFreq = minDocFreq
+
+ def fit(self, dataset):
+ """
+ Computes the inverse document frequency.
+
+ :param dataset: an RDD of term frequency vectors
+ """
+ if not isinstance(dataset, RDD):
+ raise TypeError("dataset should be an RDD of term frequency vectors")
+ jmodel = callMLlibFunc("fitIDF", self.minDocFreq, dataset.map(_convert_to_vector))
+ return IDFModel(jmodel)
+
+
+class Word2VecModel(JavaVectorTransformer):
+ """
+ class for Word2Vec model
+ """
+ def transform(self, word):
+ """
Transforms a word to its vector representation
Note: local use only
+
+ :param word: a word
+ :return: vector representation of the word
"""
- # TODO: make transform usable in RDD operations from python side
- result = self._java_model.transform(word)
- return PickleSerializer().loads(str(self._sc._jvm.SerDe.dumps(result)))
+ try:
+ return self.call("transform", word)
+ except Py4JJavaError:
+ raise ValueError("%s not found" % word)
- def findSynonyms(self, x, num):
+ def findSynonyms(self, word, num):
"""
- :param x: a word or a vector representation of word
+ Find synonyms of a word
+
+ :param word: a word or a vector representation of word
:param num: number of synonyms to find
:return: array of (word, cosineSimilarity)
- Find synonyms of a word
-
Note: local use only
"""
- # TODO: make findSynonyms usable in RDD operations from python side
- ser = PickleSerializer()
- if type(x) == str:
- jlist = self._java_model.findSynonyms(x, num)
- else:
- bytes = bytearray(ser.dumps(_convert_to_vector(x)))
- vec = self._sc._jvm.SerDe.loads(bytes)
- jlist = self._java_model.findSynonyms(vec, num)
- words, similarity = ser.loads(str(self._sc._jvm.SerDe.dumps(jlist)))
+ if not isinstance(word, basestring):
+ word = _convert_to_vector(word)
+ words, similarity = self.call("findSynonyms", word, num)
return zip(words, similarity)
@@ -85,6 +328,7 @@ class Word2Vec(object):
We used skip-gram model in our implementation and hierarchical softmax
method to train the model. The variable names in the implementation
matches the original C implementation.
+
For original C implementation, see https://code.google.com/p/word2vec/
For research papers, see
Efficient Estimation of Word Representations in Vector Space
@@ -95,23 +339,14 @@ class Word2Vec(object):
>>> localDoc = [sentence, sentence]
>>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" "))
>>> model = Word2Vec().setVectorSize(10).setSeed(42L).fit(doc)
+
>>> syms = model.findSynonyms("a", 2)
- >>> str(syms[0][0])
- 'b'
- >>> str(syms[1][0])
- 'c'
- >>> len(syms)
- 2
+ >>> [s[0] for s in syms]
+ [u'b', u'c']
>>> vec = model.transform("a")
- >>> len(vec)
- 10
>>> syms = model.findSynonyms(vec, 2)
- >>> str(syms[0][0])
- 'b'
- >>> str(syms[1][0])
- 'c'
- >>> len(syms)
- 2
+ >>> [s[0] for s in syms]
+ [u'b', u'c']
"""
def __init__(self):
"""
@@ -121,7 +356,7 @@ def __init__(self):
self.learningRate = 0.025
self.numPartitions = 1
self.numIterations = 1
- self.seed = 42L
+ self.seed = random.randint(0, sys.maxint)
def setVectorSize(self, vectorSize):
"""
@@ -163,21 +398,15 @@ def fit(self, data):
"""
Computes the vector representation of each word in vocabulary.
- :param data: training data. RDD of subtype of Iterable[String]
- :return: python Word2VecModel instance
+ :param data: training data. RDD of list of string
+ :return: Word2VecModel instance
"""
- sc = data.context
- ser = PickleSerializer()
- vectorSize = self.vectorSize
- learningRate = self.learningRate
- numPartitions = self.numPartitions
- numIterations = self.numIterations
- seed = self.seed
-
- model = sc._jvm.PythonMLLibAPI().trainWord2Vec(
- _to_java_object_rdd(data), vectorSize,
- learningRate, numPartitions, numIterations, seed)
- return Word2VecModel(sc, model)
+ if not isinstance(data, RDD):
+ raise TypeError("data should be an RDD of list of string")
+ jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize),
+ float(self.learningRate), int(self.numPartitions),
+ int(self.numIterations), long(self.seed))
+ return Word2VecModel(jmodel)
def _test():
@@ -191,4 +420,5 @@ def _test():
exit(-1)
if __name__ == "__main__":
+ sys.path.pop(0)
_test()
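Putting the new transformers together, a small end-to-end sketch of a TF-IDF pipeline with normalization (the documents and `sc` are assumed):

    from pyspark.mllib.feature import HashingTF, IDF, Normalizer

    docs = sc.parallelize([["spark", "mllib", "feature"],
                           ["spark", "python"]])

    htf = HashingTF(1 << 10)           # hash terms into a 1024-dimensional space
    tf = htf.transform(docs)           # RDD of sparse term-frequency vectors

    idf_model = IDF(minDocFreq=1).fit(tf)
    tfidf = idf_model.transform(tf)    # RDD of TF-IDF vectors

    normalized = Normalizer(p=2.0).transform(tfidf)   # unit L2 norm per document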
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index 773d8d393805d..7f21190ed8c25 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -29,9 +29,11 @@
import numpy as np
-from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
+from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \
+ IntegerType, ByteType
-__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors']
+
+__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', 'DenseMatrix', 'Matrices']
if sys.version_info[:2] == (2, 7):
@@ -52,17 +54,6 @@ def fast_pickle_array(ar):
_have_scipy = False
-# this will call the MLlib version of pythonToJava()
-def _to_java_object_rdd(rdd):
- """ Return an JavaRDD of Object by unpickling
-
- It will convert each Python object into Java object by Pyrolite, whenever the
- RDD is serialized in batch or not.
- """
- rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
- return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True)
-
-
def _convert_to_vector(l):
if isinstance(l, Vector):
return l
@@ -111,7 +102,61 @@ def _vector_size(v):
raise TypeError("Cannot treat type %s as a vector" % type(v))
+def _format_float(f, digits=4):
+ s = str(round(f, digits))
+ if '.' in s:
+ s = s[:s.index('.') + 1 + digits]
+ return s
+
+
+class VectorUDT(UserDefinedType):
+ """
+ SQL user-defined type (UDT) for Vector.
+ """
+
+ @classmethod
+ def sqlType(cls):
+ return StructType([
+ StructField("type", ByteType(), False),
+ StructField("size", IntegerType(), True),
+ StructField("indices", ArrayType(IntegerType(), False), True),
+ StructField("values", ArrayType(DoubleType(), False), True)])
+
+ @classmethod
+ def module(cls):
+ return "pyspark.mllib.linalg"
+
+ @classmethod
+ def scalaUDT(cls):
+ return "org.apache.spark.mllib.linalg.VectorUDT"
+
+ def serialize(self, obj):
+ if isinstance(obj, SparseVector):
+ indices = [int(i) for i in obj.indices]
+ values = [float(v) for v in obj.values]
+ return (0, obj.size, indices, values)
+ elif isinstance(obj, DenseVector):
+ values = [float(v) for v in obj]
+ return (1, None, None, values)
+ else:
+ raise ValueError("cannot serialize %r of type %r" % (obj, type(obj)))
+
+ def deserialize(self, datum):
+ assert len(datum) == 4, \
+ "VectorUDT.deserialize given row with length %d but requires 4" % len(datum)
+ tpe = datum[0]
+ if tpe == 0:
+ return SparseVector(datum[1], datum[2], datum[3])
+ elif tpe == 1:
+ return DenseVector(datum[3])
+ else:
+ raise ValueError("do not recognize type %r" % tpe)
+
+
class Vector(object):
+
+ __UDT__ = VectorUDT()
+
"""
Abstract class for DenseVector and SparseVector
"""
@@ -128,12 +173,16 @@ class DenseVector(Vector):
A dense vector represented by a value array.
"""
def __init__(self, ar):
- if not isinstance(ar, array.array):
- ar = array.array('d', ar)
+ if isinstance(ar, basestring):
+ ar = np.frombuffer(ar, dtype=np.float64)
+ elif not isinstance(ar, np.ndarray):
+ ar = np.array(ar, dtype=np.float64)
+ if ar.dtype != np.float64:
+ ar = ar.astype(np.float64)
self.array = ar
def __reduce__(self):
- return DenseVector, (self.array,)
+ return DenseVector, (self.array.tostring(),)
def dot(self, other):
"""
@@ -162,9 +211,10 @@ def dot(self, other):
...
AssertionError: dimension mismatch
"""
- if type(other) == np.ndarray and other.ndim > 1:
- assert len(self) == other.shape[0], "dimension mismatch"
- return np.dot(self.toArray(), other)
+ if type(other) == np.ndarray:
+ if other.ndim > 1:
+ assert len(self) == other.shape[0], "dimension mismatch"
+ return np.dot(self.array, other)
elif _have_scipy and scipy.sparse.issparse(other):
assert len(self) == other.shape[0], "dimension mismatch"
return other.transpose().dot(self.toArray())
@@ -216,7 +266,7 @@ def squared_distance(self, other):
return np.dot(diff, diff)
def toArray(self):
- return np.array(self.array)
+ return self.array
def __getitem__(self, item):
return self.array[item]
@@ -228,10 +278,10 @@ def __str__(self):
return "[" + ",".join([str(v) for v in self.array]) + "]"
def __repr__(self):
- return "DenseVector(%r)" % self.array
+ return "DenseVector([%s])" % (', '.join(_format_float(i) for i in self.array))
def __eq__(self, other):
- return isinstance(other, DenseVector) and self.array == other.array
+ return isinstance(other, DenseVector) and np.array_equal(self.array, other.array)
def __ne__(self, other):
return not self == other
@@ -269,18 +319,28 @@ def __init__(self, size, *args):
if type(pairs) == dict:
pairs = pairs.items()
pairs = sorted(pairs)
- self.indices = array.array('i', [p[0] for p in pairs])
- self.values = array.array('d', [p[1] for p in pairs])
+ self.indices = np.array([p[0] for p in pairs], dtype=np.int32)
+ self.values = np.array([p[1] for p in pairs], dtype=np.float64)
else:
- assert len(args[0]) == len(args[1]), "index and value arrays not same length"
- self.indices = array.array('i', args[0])
- self.values = array.array('d', args[1])
+ if isinstance(args[0], basestring):
+ assert isinstance(args[1], str), "values should be string too"
+ if args[0]:
+ self.indices = np.frombuffer(args[0], np.int32)
+ self.values = np.frombuffer(args[1], np.float64)
+ else:
+ # np.frombuffer() doesn't work well with empty string in older versions
+ self.indices = np.array([], dtype=np.int32)
+ self.values = np.array([], dtype=np.float64)
+ else:
+ self.indices = np.array(args[0], dtype=np.int32)
+ self.values = np.array(args[1], dtype=np.float64)
+ assert len(self.indices) == len(self.values), "index and value arrays not same length"
for i in xrange(len(self.indices) - 1):
if self.indices[i] >= self.indices[i + 1]:
raise TypeError("indices array must be sorted")
def __reduce__(self):
- return (SparseVector, (self.size, self.indices, self.values))
+ return (SparseVector, (self.size, self.indices.tostring(), self.values.tostring()))
def dot(self, other):
"""
@@ -416,8 +476,7 @@ def toArray(self):
Returns a copy of this SparseVector as a 1-dimensional NumPy array.
"""
arr = np.zeros((self.size,), dtype=np.float64)
- for i in xrange(self.indices.size):
- arr[self.indices[i]] = self.values[i]
+ arr[self.indices] = self.values
return arr
def __len__(self):
@@ -431,7 +490,8 @@ def __str__(self):
def __repr__(self):
inds = self.indices
vals = self.values
- entries = ", ".join(["{0}: {1}".format(inds[i], vals[i]) for i in xrange(len(inds))])
+ entries = ", ".join(["{0}: {1}".format(inds[i], _format_float(vals[i]))
+ for i in xrange(len(inds))])
return "SparseVector({0}, {{{1}}})".format(self.size, entries)
def __eq__(self, other):
@@ -447,8 +507,25 @@ def __eq__(self, other):
"""
return (isinstance(other, self.__class__)
and other.size == self.size
- and other.indices == self.indices
- and other.values == self.values)
+ and np.array_equal(other.indices, self.indices)
+ and np.array_equal(other.values, self.values))
+
+ def __getitem__(self, index):
+ inds = self.indices
+ vals = self.values
+ if not isinstance(index, int):
+ raise ValueError(
+ "Indices must be of type integer, got type %s" % type(index))
+ if index < 0:
+ index += self.size
+ if index >= self.size or index < 0:
+ raise ValueError("Index %d out of bounds." % index)
+
+ insert_index = np.searchsorted(inds, index)
+ # guard against an index past the last stored entry (searchsorted == len(inds))
+ if insert_index < len(inds) and inds[insert_index] == index:
+ return vals[insert_index]
+ return 0.
def __ne__(self, other):
return not self.__eq__(other)
@@ -491,7 +568,7 @@ def dense(elements):
returns a NumPy array.
>>> Vectors.dense([1, 2, 3])
- DenseVector(array('d', [1.0, 2.0, 3.0]))
+ DenseVector([1.0, 2.0, 3.0])
"""
return DenseVector(elements)
@@ -531,23 +608,43 @@ class DenseMatrix(Matrix):
"""
def __init__(self, numRows, numCols, values):
Matrix.__init__(self, numRows, numCols)
+ if isinstance(values, basestring):
+ values = np.frombuffer(values, dtype=np.float64)
+ elif not isinstance(values, np.ndarray):
+ values = np.array(values, dtype=np.float64)
assert len(values) == numRows * numCols
+ if values.dtype != np.float64:
+ values = values.astype(np.float64)
self.values = values
def __reduce__(self):
- return DenseMatrix, (self.numRows, self.numCols, self.values)
+ return DenseMatrix, (self.numRows, self.numCols, self.values.tostring())
def toArray(self):
"""
Return an numpy.ndarray
- >>> arr = array.array('d', [float(i) for i in range(4)])
- >>> m = DenseMatrix(2, 2, arr)
+ >>> m = DenseMatrix(2, 2, range(4))
>>> m.toArray()
array([[ 0., 2.],
[ 1., 3.]])
"""
- return np.reshape(self.values, (self.numRows, self.numCols), order='F')
+ return self.values.reshape((self.numRows, self.numCols), order='F')
+
+ def __eq__(self, other):
+ return (isinstance(other, DenseMatrix) and
+ self.numRows == other.numRows and
+ self.numCols == other.numCols and
+ all(self.values == other.values))
+
+
+class Matrices(object):
+ @staticmethod
+ def dense(numRows, numCols, values):
+ """
+ Create a DenseMatrix
+ """
+ return DenseMatrix(numRows, numCols, values)
def _test():
@@ -557,8 +654,4 @@ def _test():
exit(-1)
if __name__ == "__main__":
- # remove current path from list of search paths to avoid importing mllib.random
- # for C{import random}, which is done in an external dependency of pyspark during doctests.
- import sys
- sys.path.pop(0)
_test()
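A brief sketch of the user-visible changes in this file: vectors are now backed by NumPy arrays, SparseVector supports indexing, and a DenseMatrix factory is added. Values shown are illustrative:

    from pyspark.mllib.linalg import Vectors, Matrices

    dv = Vectors.dense([1.0, 2.0, 3.0])
    sv = Vectors.sparse(4, {1: 1.0, 3: 5.5})

    dv.toArray()    # returns the backing numpy array directly (no copy)
    sv[3]           # 5.5 -- the new SparseVector.__getitem__
    sv[0]           # 0.0 -- entries not stored read as zero

    m = Matrices.dense(2, 2, range(4))   # column-major layout, matching Scala
    m.toArray()     # array([[0., 2.], [1., 3.]])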
diff --git a/python/pyspark/mllib/rand.py b/python/pyspark/mllib/rand.py
new file mode 100644
index 0000000000000..20ee9d78bf5b0
--- /dev/null
+++ b/python/pyspark/mllib/rand.py
@@ -0,0 +1,410 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Python package for random data generation.
+"""
+
+from functools import wraps
+
+from pyspark.mllib.common import callMLlibFunc
+
+
+__all__ = ['RandomRDDs', ]
+
+
+def toArray(f):
+ @wraps(f)
+ def func(sc, *a, **kw):
+ rdd = f(sc, *a, **kw)
+ return rdd.map(lambda vec: vec.toArray())
+ return func
+
+
+class RandomRDDs(object):
+ """
+ Generator methods for creating RDDs comprised of i.i.d. samples from
+ some distribution.
+ """
+
+ @staticmethod
+ def uniformRDD(sc, size, numPartitions=None, seed=None):
+ """
+ Generates an RDD comprised of i.i.d. samples from the
+ uniform distribution U(0.0, 1.0).
+
+ To transform the distribution in the generated RDD from U(0.0, 1.0)
+ to U(a, b), use
+ C{RandomRDDs.uniformRDD(sc, n, p, seed)\
+ .map(lambda v: a + (b - a) * v)}
+
+ :param sc: SparkContext used to create the RDD.
+ :param size: Size of the RDD.
+ :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
+ :param seed: Random seed (default: a random long integer).
+ :return: RDD of float comprised of i.i.d. samples ~ `U(0.0, 1.0)`.
+
+ >>> x = RandomRDDs.uniformRDD(sc, 100).collect()
+ >>> len(x)
+ 100
+ >>> max(x) <= 1.0 and min(x) >= 0.0
+ True
+ >>> RandomRDDs.uniformRDD(sc, 100, 4).getNumPartitions()
+ 4
+ >>> parts = RandomRDDs.uniformRDD(sc, 100, seed=4).getNumPartitions()
+ >>> parts == sc.defaultParallelism
+ True
+ """
+ return callMLlibFunc("uniformRDD", sc._jsc, size, numPartitions, seed)
+
+ @staticmethod
+ def normalRDD(sc, size, numPartitions=None, seed=None):
+ """
+ Generates an RDD comprised of i.i.d. samples from the standard normal
+ distribution.
+
+ To transform the distribution in the generated RDD from standard normal
+ to some other normal N(mean, sigma^2), use
+ C{RandomRDDs.normalRDD(sc, n, p, seed)\
+ .map(lambda v: mean + sigma * v)}
+
+ :param sc: SparkContext used to create the RDD.
+ :param size: Size of the RDD.
+ :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
+ :param seed: Random seed (default: a random long integer).
+ :return: RDD of float comprised of i.i.d. samples ~ N(0.0, 1.0).
+
+ >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1L)
+ >>> stats = x.stats()
+ >>> stats.count()
+ 1000L
+ >>> abs(stats.mean() - 0.0) < 0.1
+ True
+ >>> abs(stats.stdev() - 1.0) < 0.1
+ True
+ """
+ return callMLlibFunc("normalRDD", sc._jsc, size, numPartitions, seed)
+
+ @staticmethod
+ def logNormalRDD(sc, mean, std, size, numPartitions=None, seed=None):
+ """
+ Generates an RDD comprised of i.i.d. samples from the log normal
+ distribution with the input mean and standard deviation.
+
+ :param sc: SparkContext used to create the RDD.
+ :param mean: mean for the log Normal distribution
+ :param std: std for the log Normal distribution
+ :param size: Size of the RDD.
+ :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
+ :param seed: Random seed (default: a random long integer).
+ :return: RDD of float comprised of i.i.d. samples ~ log N(mean, std).
+
+ >>> from math import sqrt, exp
+ >>> mean = 0.0
+ >>> std = 1.0
+ >>> expMean = exp(mean + 0.5 * std * std)
+ >>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std))
+ >>> x = RandomRDDs.logNormalRDD(sc, mean, std, 1000, seed=2L)
+ >>> stats = x.stats()
+ >>> stats.count()
+ 1000L
+ >>> abs(stats.mean() - expMean) < 0.5
+ True
+ >>> from math import sqrt
+ >>> abs(stats.stdev() - expStd) < 0.5
+ True
+ """
+ return callMLlibFunc("logNormalRDD", sc._jsc, float(mean), float(std),
+ size, numPartitions, seed)
+
+ @staticmethod
+ def poissonRDD(sc, mean, size, numPartitions=None, seed=None):
+ """
+ Generates an RDD comprised of i.i.d. samples from the Poisson
+ distribution with the input mean.
+
+ :param sc: SparkContext used to create the RDD.
+ :param mean: Mean, or lambda, for the Poisson distribution.
+ :param size: Size of the RDD.
+ :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
+ :param seed: Random seed (default: a random long integer).
+ :return: RDD of float comprised of i.i.d. samples ~ Pois(mean).
+
+ >>> mean = 100.0
+ >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2L)
+ >>> stats = x.stats()
+ >>> stats.count()
+ 1000L
+ >>> abs(stats.mean() - mean) < 0.5
+ True
+ >>> from math import sqrt
+ >>> abs(stats.stdev() - sqrt(mean)) < 0.5
+ True
+ """
+ return callMLlibFunc("poissonRDD", sc._jsc, float(mean), size, numPartitions, seed)
+
+ @staticmethod
+ def exponentialRDD(sc, mean, size, numPartitions=None, seed=None):
+ """
+ Generates an RDD comprised of i.i.d. samples from the Exponential
+ distribution with the input mean.
+
+ :param sc: SparkContext used to create the RDD.
+ :param mean: Mean, or 1 / lambda, for the Exponential distribution.
+ :param size: Size of the RDD.
+ :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
+ :param seed: Random seed (default: a random long integer).
+ :return: RDD of float comprised of i.i.d. samples ~ Exp(mean).
+
+ >>> mean = 2.0
+ >>> x = RandomRDDs.exponentialRDD(sc, mean, 1000, seed=2L)
+ >>> stats = x.stats()
+ >>> stats.count()
+ 1000L
+ >>> abs(stats.mean() - mean) < 0.5
+ True
+ >>> from math import sqrt
+ >>> abs(stats.stdev() - sqrt(mean)) < 0.5
+ True
+ """
+ return callMLlibFunc("exponentialRDD", sc._jsc, float(mean), size, numPartitions, seed)
+
+ @staticmethod
+ def gammaRDD(sc, shape, scale, size, numPartitions=None, seed=None):
+ """
+ Generates an RDD comprised of i.i.d. samples from the Gamma
+ distribution with the input shape and scale.
+
+ :param sc: SparkContext used to create the RDD.
+ :param shape: shape (> 0) parameter for the Gamma distribution
+ :param scale: scale (> 0) parameter for the Gamma distribution
+ :param size: Size of the RDD.
+ :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
+ :param seed: Random seed (default: a random long integer).
+ :return: RDD of float comprised of i.i.d. samples ~ Gamma(shape, scale).
+
+ >>> from math import sqrt
+ >>> shape = 1.0
+ >>> scale = 2.0
+ >>> expMean = shape * scale
+ >>> expStd = sqrt(shape * scale * scale)
+ >>> x = RandomRDDs.gammaRDD(sc, shape, scale, 1000, seed=2L)
+ >>> stats = x.stats()
+ >>> stats.count()
+ 1000L
+ >>> abs(stats.mean() - expMean) < 0.5
+ True
+ >>> abs(stats.stdev() - expStd) < 0.5
+ True
+ """
+ return callMLlibFunc("gammaRDD", sc._jsc, float(shape),
+ float(scale), size, numPartitions, seed)
+
+ @staticmethod
+ @toArray
+ def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None):
+ """
+ Generates an RDD comprised of vectors containing i.i.d. samples drawn
+ from the uniform distribution U(0.0, 1.0).
+
+ :param sc: SparkContext used to create the RDD.
+ :param numRows: Number of Vectors in the RDD.
+ :param numCols: Number of elements in each Vector.
+ :param numPartitions: Number of partitions in the RDD.
+ :param seed: Seed for the RNG that generates the seed for the generator in each partition.
+ :return: RDD of Vector with vectors containing i.i.d samples ~ `U(0.0, 1.0)`.
+
+ >>> import numpy as np
+ >>> mat = np.matrix(RandomRDDs.uniformVectorRDD(sc, 10, 10).collect())
+ >>> mat.shape
+ (10, 10)
+ >>> mat.max() <= 1.0 and mat.min() >= 0.0
+ True
+ >>> RandomRDDs.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions()
+ 4
+ """
+ return callMLlibFunc("uniformVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed)
+
+ @staticmethod
+ @toArray
+ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None):
+ """
+ Generates an RDD comprised of vectors containing i.i.d. samples drawn
+ from the standard normal distribution.
+
+ :param sc: SparkContext used to create the RDD.
+ :param numRows: Number of Vectors in the RDD.
+ :param numCols: Number of elements in each Vector.
+ :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
+ :param seed: Random seed (default: a random long integer).
+ :return: RDD of Vector with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`.
+
+ >>> import numpy as np
+ >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1L).collect())
+ >>> mat.shape
+ (100, 100)
+ >>> abs(mat.mean() - 0.0) < 0.1
+ True
+ >>> abs(mat.std() - 1.0) < 0.1
+ True
+ """
+ return callMLlibFunc("normalVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed)
+
+ @staticmethod
+ @toArray
+ def logNormalVectorRDD(sc, mean, std, numRows, numCols, numPartitions=None, seed=None):
+ """
+ Generates an RDD comprised of vectors containing i.i.d. samples drawn
+ from the log normal distribution.
+
+ :param sc: SparkContext used to create the RDD.
+ :param mean: Mean of the log normal distribution
+ :param std: Standard Deviation of the log normal distribution
+ :param numRows: Number of Vectors in the RDD.
+ :param numCols: Number of elements in each Vector.
+ :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
+ :param seed: Random seed (default: a random long integer).
+ :return: RDD of Vector with vectors containing i.i.d. samples ~ log `N(mean, std)`.
+
+ >>> import numpy as np
+ >>> from math import sqrt, exp
+ >>> mean = 0.0
+ >>> std = 1.0
+ >>> expMean = exp(mean + 0.5 * std * std)
+ >>> expStd = sqrt((exp(std * std) - 1.0) * exp(2.0 * mean + std * std))
+ >>> mat = np.matrix(RandomRDDs.logNormalVectorRDD(sc, mean, std, \
+ 100, 100, seed=1L).collect())
+ >>> mat.shape
+ (100, 100)
+ >>> abs(mat.mean() - expMean) < 0.1
+ True
+ >>> abs(mat.std() - expStd) < 0.1
+ True
+ """
+ return callMLlibFunc("logNormalVectorRDD", sc._jsc, float(mean), float(std),
+ numRows, numCols, numPartitions, seed)
+
+ @staticmethod
+ @toArray
+ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None):
+ """
+ Generates an RDD comprised of vectors containing i.i.d. samples drawn
+ from the Poisson distribution with the input mean.
+
+ :param sc: SparkContext used to create the RDD.
+ :param mean: Mean, or lambda, for the Poisson distribution.
+ :param numRows: Number of Vectors in the RDD.
+ :param numCols: Number of elements in each Vector.
+ :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`)
+ :param seed: Random seed (default: a random long integer).
+ :return: RDD of Vector with vectors containing i.i.d. samples ~ Pois(mean).
+
+ >>> import numpy as np
+ >>> mean = 100.0
+ >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1L)
+ >>> mat = np.mat(rdd.collect())
+ >>> mat.shape
+ (100, 100)
+ >>> abs(mat.mean() - mean) < 0.5
+ True
+ >>> from math import sqrt
+ >>> abs(mat.std() - sqrt(mean)) < 0.5
+ True
+ """
+ return callMLlibFunc("poissonVectorRDD", sc._jsc, float(mean), numRows, numCols,
+ numPartitions, seed)
+
+ @staticmethod
+ @toArray
+ def exponentialVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None):
+ """
+ Generates an RDD comprised of vectors containing i.i.d. samples drawn
+ from the Exponential distribution with the input mean.
+
+ :param sc: SparkContext used to create the RDD.
+ :param mean: Mean, or 1 / lambda, for the Exponential distribution.
+ :param numRows: Number of Vectors in the RDD.
+ :param numCols: Number of elements in each Vector.
+ :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`)
+ :param seed: Random seed (default: a random long integer).
+ :return: RDD of Vector with vectors containing i.i.d. samples ~ Exp(mean).
+
+ >>> import numpy as np
+ >>> mean = 0.5
+ >>> rdd = RandomRDDs.exponentialVectorRDD(sc, mean, 100, 100, seed=1L)
+ >>> mat = np.mat(rdd.collect())
+ >>> mat.shape
+ (100, 100)
+ >>> abs(mat.mean() - mean) < 0.5
+ True
+ >>> from math import sqrt
+ >>> abs(mat.std() - sqrt(mean)) < 0.5
+ True
+ """
+ return callMLlibFunc("exponentialVectorRDD", sc._jsc, float(mean), numRows, numCols,
+ numPartitions, seed)
+
+ @staticmethod
+ @toArray
+ def gammaVectorRDD(sc, shape, scale, numRows, numCols, numPartitions=None, seed=None):
+ """
+ Generates an RDD comprised of vectors containing i.i.d. samples drawn
+ from the Gamma distribution.
+
+ :param sc: SparkContext used to create the RDD.
+ :param shape: Shape (> 0) of the Gamma distribution
+ :param scale: Scale (> 0) of the Gamma distribution
+ :param numRows: Number of Vectors in the RDD.
+ :param numCols: Number of elements in each Vector.
+ :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
+ :param seed: Random seed (default: a random long integer).
+ :return: RDD of Vector with vectors containing i.i.d. samples ~ Gamma(shape, scale).
+
+ >>> import numpy as np
+ >>> from math import sqrt
+ >>> shape = 1.0
+ >>> scale = 2.0
+ >>> expMean = shape * scale
+ >>> expStd = sqrt(shape * scale * scale)
+ >>> mat = np.matrix(RandomRDDs.gammaVectorRDD(sc, shape, scale, \
+ 100, 100, seed=1L).collect())
+ >>> mat.shape
+ (100, 100)
+ >>> abs(mat.mean() - expMean) < 0.1
+ True
+ >>> abs(mat.std() - expStd) < 0.1
+ True
+ """
+ return callMLlibFunc("gammaVectorRDD", sc._jsc, float(shape), float(scale),
+ numRows, numCols, numPartitions, seed)
+
+
+def _test():
+ import doctest
+ from pyspark.context import SparkContext
+ globs = globals().copy()
+ # The small batch size here ensures that we see multiple batches,
+ # even in these small test examples:
+ globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2)
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+
+if __name__ == "__main__":
+ _test()
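As a quick sanity check of the vector generators added in this file, a minimal driver script might look like the following. This is only a sketch: it assumes a local SparkContext, uses the gammaVectorRDD signature shown above, and relies on the @toArray decorator returning each row as a NumPy array.

    from pyspark import SparkContext
    from pyspark.mllib.random import RandomRDDs

    sc = SparkContext("local[2]", "RandomRDDsExample")

    # 1000 rows of 10 i.i.d. Gamma(shape=2.0, scale=2.0) samples each
    rows = RandomRDDs.gammaVectorRDD(sc, 2.0, 2.0, numRows=1000, numCols=10, seed=42L)
    print rows.first()                          # a NumPy array of length 10
    print rows.map(lambda v: v.sum()).mean()    # roughly numCols * shape * scale = 40.0

    sc.stop()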
diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py
deleted file mode 100644
index 73baba4ace5f6..0000000000000
--- a/python/pyspark/mllib/random.py
+++ /dev/null
@@ -1,200 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""
-Python package for random data generation.
-"""
-
-from functools import wraps
-
-from pyspark.rdd import RDD
-from pyspark.serializers import BatchedSerializer, PickleSerializer
-
-
-__all__ = ['RandomRDDs', ]
-
-
-def serialize(f):
- @wraps(f)
- def func(sc, *a, **kw):
- jrdd = f(sc, *a, **kw)
- return RDD(sc._jvm.SerDe.javaToPython(jrdd), sc,
- BatchedSerializer(PickleSerializer(), 1024))
- return func
-
-
-def toArray(f):
- @wraps(f)
- def func(sc, *a, **kw):
- rdd = f(sc, *a, **kw)
- return rdd.map(lambda vec: vec.toArray())
- return func
-
-
-class RandomRDDs(object):
- """
- Generator methods for creating RDDs comprised of i.i.d samples from
- some distribution.
- """
-
- @staticmethod
- @serialize
- def uniformRDD(sc, size, numPartitions=None, seed=None):
- """
- Generates an RDD comprised of i.i.d. samples from the
- uniform distribution U(0.0, 1.0).
-
- To transform the distribution in the generated RDD from U(0.0, 1.0)
- to U(a, b), use
- C{RandomRDDs.uniformRDD(sc, n, p, seed)\
- .map(lambda v: a + (b - a) * v)}
-
- >>> x = RandomRDDs.uniformRDD(sc, 100).collect()
- >>> len(x)
- 100
- >>> max(x) <= 1.0 and min(x) >= 0.0
- True
- >>> RandomRDDs.uniformRDD(sc, 100, 4).getNumPartitions()
- 4
- >>> parts = RandomRDDs.uniformRDD(sc, 100, seed=4).getNumPartitions()
- >>> parts == sc.defaultParallelism
- True
- """
- return sc._jvm.PythonMLLibAPI().uniformRDD(sc._jsc, size, numPartitions, seed)
-
- @staticmethod
- @serialize
- def normalRDD(sc, size, numPartitions=None, seed=None):
- """
- Generates an RDD comprised of i.i.d. samples from the standard normal
- distribution.
-
- To transform the distribution in the generated RDD from standard normal
- to some other normal N(mean, sigma^2), use
- C{RandomRDDs.normal(sc, n, p, seed)\
- .map(lambda v: mean + sigma * v)}
-
- >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1L)
- >>> stats = x.stats()
- >>> stats.count()
- 1000L
- >>> abs(stats.mean() - 0.0) < 0.1
- True
- >>> abs(stats.stdev() - 1.0) < 0.1
- True
- """
- return sc._jvm.PythonMLLibAPI().normalRDD(sc._jsc, size, numPartitions, seed)
-
- @staticmethod
- @serialize
- def poissonRDD(sc, mean, size, numPartitions=None, seed=None):
- """
- Generates an RDD comprised of i.i.d. samples from the Poisson
- distribution with the input mean.
-
- >>> mean = 100.0
- >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=1L)
- >>> stats = x.stats()
- >>> stats.count()
- 1000L
- >>> abs(stats.mean() - mean) < 0.5
- True
- >>> from math import sqrt
- >>> abs(stats.stdev() - sqrt(mean)) < 0.5
- True
- """
- return sc._jvm.PythonMLLibAPI().poissonRDD(sc._jsc, mean, size, numPartitions, seed)
-
- @staticmethod
- @toArray
- @serialize
- def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None):
- """
- Generates an RDD comprised of vectors containing i.i.d. samples drawn
- from the uniform distribution U(0.0, 1.0).
-
- >>> import numpy as np
- >>> mat = np.matrix(RandomRDDs.uniformVectorRDD(sc, 10, 10).collect())
- >>> mat.shape
- (10, 10)
- >>> mat.max() <= 1.0 and mat.min() >= 0.0
- True
- >>> RandomRDDs.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions()
- 4
- """
- return sc._jvm.PythonMLLibAPI() \
- .uniformVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed)
-
- @staticmethod
- @toArray
- @serialize
- def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None):
- """
- Generates an RDD comprised of vectors containing i.i.d. samples drawn
- from the standard normal distribution.
-
- >>> import numpy as np
- >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1L).collect())
- >>> mat.shape
- (100, 100)
- >>> abs(mat.mean() - 0.0) < 0.1
- True
- >>> abs(mat.std() - 1.0) < 0.1
- True
- """
- return sc._jvm.PythonMLLibAPI() \
- .normalVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed)
-
- @staticmethod
- @toArray
- @serialize
- def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None):
- """
- Generates an RDD comprised of vectors containing i.i.d. samples drawn
- from the Poisson distribution with the input mean.
-
- >>> import numpy as np
- >>> mean = 100.0
- >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1L)
- >>> mat = np.mat(rdd.collect())
- >>> mat.shape
- (100, 100)
- >>> abs(mat.mean() - mean) < 0.5
- True
- >>> from math import sqrt
- >>> abs(mat.std() - sqrt(mean)) < 0.5
- True
- """
- return sc._jvm.PythonMLLibAPI() \
- .poissonVectorRDD(sc._jsc, mean, numRows, numCols, numPartitions, seed)
-
-
-def _test():
- import doctest
- from pyspark.context import SparkContext
- globs = globals().copy()
- # The small batch size here ensures that we see multiple batches,
- # even in these small test examples:
- globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2)
- (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
- globs['sc'].stop()
- if failure_count:
- exit(-1)
-
-
-if __name__ == "__main__":
- _test()
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 22872dbbe3b55..0d99e6dedfad9 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -15,28 +15,31 @@
# limitations under the License.
#
+from collections import namedtuple
+
from pyspark import SparkContext
-from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
from pyspark.rdd import RDD
-from pyspark.mllib.linalg import _to_java_object_rdd
+from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
-__all__ = ['MatrixFactorizationModel', 'ALS']
+__all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']
-class Rating(object):
- def __init__(self, user, product, rating):
- self.user = int(user)
- self.product = int(product)
- self.rating = float(rating)
+class Rating(namedtuple("Rating", ["user", "product", "rating"])):
+ """
+ Represents a (user, product, rating) tuple.
- def __reduce__(self):
- return Rating, (self.user, self.product, self.rating)
+ >>> r = Rating(1, 2, 5.0)
+ >>> (r.user, r.product, r.rating)
+ (1, 2, 5.0)
+ >>> (r[0], r[1], r[2])
+ (1, 2, 5.0)
+ """
- def __repr__(self):
- return "Rating(%d, %d, %d)" % (self.user, self.product, self.rating)
+ def __reduce__(self):
+ return Rating, (int(self.user), int(self.product), float(self.rating))
-class MatrixFactorizationModel(object):
+class MatrixFactorizationModel(JavaModelWrapper):
"""A matrix factorisation model trained by regularized alternating
least-squares.
@@ -45,74 +48,55 @@ class MatrixFactorizationModel(object):
>>> r2 = (1, 2, 2.0)
>>> r3 = (2, 1, 2.0)
>>> ratings = sc.parallelize([r1, r2, r3])
- >>> model = ALS.trainImplicit(ratings, 1)
- >>> model.predict(2,2) is not None
- True
+ >>> model = ALS.trainImplicit(ratings, 1, seed=10)
+ >>> model.predict(2, 2)
+ 0.43...
>>> testset = sc.parallelize([(1, 2), (1, 1)])
- >>> model = ALS.train(ratings, 1)
- >>> model.predictAll(testset).count() == 2
- True
+ >>> model = ALS.train(ratings, 2, seed=0)
+ >>> model.predictAll(testset).collect()
+ [Rating(user=1, product=1, rating=1.0...), Rating(user=1, product=2, rating=1.9...)]
- >>> model = ALS.train(ratings, 4)
- >>> model.userFeatures().count() == 2
- True
+ >>> model = ALS.train(ratings, 4, seed=10)
+ >>> model.userFeatures().collect()
+ [(1, array('d', [...])), (2, array('d', [...]))]
>>> first_user = model.userFeatures().take(1)[0]
>>> latents = first_user[1]
>>> len(latents) == 4
True
- >>> model.productFeatures().count() == 2
- True
+ >>> model.productFeatures().collect()
+ [(1, array('d', [...])), (2, array('d', [...]))]
>>> first_product = model.productFeatures().take(1)[0]
>>> latents = first_product[1]
>>> len(latents) == 4
True
- """
-
- def __init__(self, sc, java_model):
- self._context = sc
- self._java_model = java_model
- def __del__(self):
- self._context._gateway.detach(self._java_model)
+ >>> model = ALS.train(ratings, 1, nonnegative=True, seed=10)
+ >>> model.predict(2,2)
+ 3.8...
+ >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10)
+ >>> model.predict(2,2)
+ 0.43...
+ """
def predict(self, user, product):
- return self._java_model.predict(user, product)
+ return self._java_model.predict(int(user), int(product))
def predictAll(self, user_product):
assert isinstance(user_product, RDD), "user_product should be RDD of (user, product)"
first = user_product.first()
- if isinstance(first, list):
- user_product = user_product.map(tuple)
- first = tuple(first)
- assert type(first) is tuple and len(first) == 2, \
- "user_product should be RDD of (user, product)"
- if any(isinstance(x, str) for x in first):
- user_product = user_product.map(lambda (u, p): (int(x), int(p)))
- first = tuple(map(int, first))
- assert all(type(x) is int for x in first), "user and product in user_product shoul be int"
- sc = self._context
- tuplerdd = sc._jvm.SerDe.asTupleRDD(_to_java_object_rdd(user_product).rdd())
- jresult = self._java_model.predict(tuplerdd).toJavaRDD()
- return RDD(sc._jvm.SerDe.javaToPython(jresult), sc,
- AutoBatchedSerializer(PickleSerializer()))
+ assert len(first) == 2, "user_product should be RDD of (user, product)"
+ user_product = user_product.map(lambda (u, p): (int(u), int(p)))
+ return self.call("predict", user_product)
def userFeatures(self):
- sc = self._context
- juf = self._java_model.userFeatures()
- juf = sc._jvm.SerDe.fromTuple2RDD(juf).toJavaRDD()
- return RDD(sc._jvm.PythonRDD.javaToPython(juf), sc,
- AutoBatchedSerializer(PickleSerializer()))
+ return self.call("getUserFeatures")
def productFeatures(self):
- sc = self._context
- jpf = self._java_model.productFeatures()
- jpf = sc._jvm.SerDe.fromTuple2RDD(jpf).toJavaRDD()
- return RDD(sc._jvm.PythonRDD.javaToPython(jpf), sc,
- AutoBatchedSerializer(PickleSerializer()))
+ return self.call("getProductFeatures")
class ALS(object):
@@ -126,32 +110,28 @@ def _prepare(cls, ratings):
ratings = ratings.map(lambda x: Rating(*x))
else:
raise ValueError("rating should be RDD of Rating or tuple/list")
- # serialize them by AutoBatchedSerializer before cache to reduce the
- # objects overhead in JVM
- cached = ratings._reserialize(AutoBatchedSerializer(PickleSerializer())).cache()
- return _to_java_object_rdd(cached)
+ return ratings
@classmethod
- def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
- sc = ratings.context
- jrating = cls._prepare(ratings)
- mod = sc._jvm.PythonMLLibAPI().trainALSModel(jrating, rank, iterations, lambda_, blocks)
- return MatrixFactorizationModel(sc, mod)
+ def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative=False,
+ seed=None):
+ model = callMLlibFunc("trainALSModel", cls._prepare(ratings), rank, iterations,
+ lambda_, blocks, nonnegative, seed)
+ return MatrixFactorizationModel(model)
@classmethod
- def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01):
- sc = ratings.context
- jrating = cls._prepare(ratings)
- mod = sc._jvm.PythonMLLibAPI().trainImplicitALSModel(
- jrating, rank, iterations, lambda_, blocks, alpha)
- return MatrixFactorizationModel(sc, mod)
+ def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01,
+ nonnegative=False, seed=None):
+ model = callMLlibFunc("trainImplicitALSModel", cls._prepare(ratings), rank,
+ iterations, lambda_, blocks, alpha, nonnegative, seed)
+ return MatrixFactorizationModel(model)
def _test():
import doctest
import pyspark.mllib.recommendation
globs = pyspark.mllib.recommendation.__dict__.copy()
- globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ globs['sc'] = SparkContext('local[4]', 'PythonTest')
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
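To illustrate the reworked ALS wrapper, here is a hedged usage sketch. It assumes an already-running SparkContext named sc; the keyword arguments follow the new train signature introduced in this patch.

    from pyspark.mllib.recommendation import ALS, Rating

    ratings = sc.parallelize([Rating(1, 1, 5.0), Rating(1, 2, 1.0), Rating(2, 1, 2.0)])

    # seed and nonnegative are the keyword arguments newly exposed by this change
    model = ALS.train(ratings, rank=4, iterations=10, nonnegative=True, seed=10)
    print model.predict(2, 2)
    print model.predictAll(sc.parallelize([(1, 2), (2, 2)])).collect()
    print model.userFeatures().collect()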
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 93e17faf5cd51..210060140fd91 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -18,9 +18,8 @@
import numpy as np
from numpy import array
-from pyspark import SparkContext
-from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
-from pyspark.mllib.linalg import SparseVector, _convert_to_vector, _to_java_object_rdd
+from pyspark.mllib.common import callMLlibFunc
+from pyspark.mllib.linalg import SparseVector, _convert_to_vector
__all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel',
'LinearRegressionWithSGD', 'LassoWithSGD', 'RidgeRegressionWithSGD']
@@ -37,7 +36,7 @@ class LabeledPoint(object):
"""
def __init__(self, label, features):
- self.label = label
+ self.label = float(label)
self.features = _convert_to_vector(features)
def __reduce__(self):
@@ -47,7 +46,7 @@ def __str__(self):
return "(" + ",".join((str(self.label), str(self.features))) + ")"
def __repr__(self):
- return "LabeledPoint(" + ",".join((repr(self.label), repr(self.features))) + ")"
+ return "LabeledPoint(%s, %s)" % (self.label, self.features)
class LinearModel(object):
@@ -56,7 +55,7 @@ class LinearModel(object):
def __init__(self, weights, intercept):
self._coeff = _convert_to_vector(weights)
- self._intercept = intercept
+ self._intercept = float(intercept)
@property
def weights(self):
@@ -67,7 +66,7 @@ def intercept(self):
return self._intercept
def __repr__(self):
- return "(weights=%s, intercept=%s)" % (self._coeff, self._intercept)
+ return "(weights=%s, intercept=%r)" % (self._coeff, self._intercept)
class LinearRegressionModelBase(LinearModel):
@@ -86,6 +85,7 @@ def predict(self, x):
Predict the value of the dependent variable given a vector x
containing values for the independent variables.
"""
+ x = _convert_to_vector(x)
return self.weights.dot(x) + self.intercept
@@ -124,24 +124,20 @@ class LinearRegressionModel(LinearRegressionModelBase):
# train_func should take two parameters, namely data and initial_weights, and
# return the result of a call to the appropriate JVM stub.
# _regression_train_wrapper is responsible for setup and error checking.
-def _regression_train_wrapper(sc, train_func, modelClass, data, initial_weights):
+def _regression_train_wrapper(train_func, modelClass, data, initial_weights):
+ first = data.first()
+ if not isinstance(first, LabeledPoint):
+ raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first)
initial_weights = initial_weights or [0.0] * len(data.first().features)
- ser = PickleSerializer()
- initial_bytes = bytearray(ser.dumps(_convert_to_vector(initial_weights)))
- # use AutoBatchedSerializer before cache to reduce the memory
- # overhead in JVM
- cached = data._reserialize(AutoBatchedSerializer(ser)).cache()
- ans = train_func(_to_java_object_rdd(cached), initial_bytes)
- assert len(ans) == 2, "JVM call result had unexpected length"
- weights = ser.loads(str(ans[0]))
- return modelClass(weights, ans[1])
+ weights, intercept = train_func(data, _convert_to_vector(initial_weights))
+ return modelClass(weights, intercept)
class LinearRegressionWithSGD(object):
@classmethod
def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
- initialWeights=None, regParam=1.0, regType="none", intercept=False):
+ initialWeights=None, regParam=0.0, regType=None, intercept=False):
"""
Train a linear regression model on the given data.
@@ -152,29 +148,28 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
:param miniBatchFraction: Fraction of data to be used for each SGD
iteration.
:param initialWeights: The initial weights (default: None).
- :param regParam: The regularizer parameter (default: 1.0).
+ :param regParam: The regularizer parameter (default: 0.0).
:param regType: The type of regularizer used for training
our model.
:Allowed values:
- - "l1" for using L1Updater,
- - "l2" for using SquaredL2Updater,
- - "none" for no regularizer.
+ - "l1" for using L1 regularization (lasso),
+ - "l2" for using L2 regularization (ridge),
+ - None for no regularization
- (default: "none")
+ (default: None)
@param intercept: Boolean parameter which indicates the use
or not of the augmented representation for
training data (i.e. whether bias features
are activated or not).
"""
- sc = data.context
+ def train(rdd, i):
+ return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations),
+ float(step), float(miniBatchFraction), i, float(regParam),
+ regType, bool(intercept))
- def train(jrdd, i):
- return sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD(
- jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept)
-
- return _regression_train_wrapper(sc, train, LinearRegressionModel, data, initialWeights)
+ return _regression_train_wrapper(train, LinearRegressionModel, data, initialWeights)
class LassoModel(LinearRegressionModelBase):
@@ -213,15 +208,14 @@ class LassoModel(LinearRegressionModelBase):
class LassoWithSGD(object):
@classmethod
- def train(cls, data, iterations=100, step=1.0, regParam=1.0,
+ def train(cls, data, iterations=100, step=1.0, regParam=0.01,
miniBatchFraction=1.0, initialWeights=None):
"""Train a Lasso regression model on the given data."""
- sc = data.context
+ def train(rdd, i):
+ return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step),
+ float(regParam), float(miniBatchFraction), i)
- def train(jrdd, i):
- return sc._jvm.PythonMLLibAPI().trainLassoModelWithSGD(
- jrdd, iterations, step, regParam, miniBatchFraction, i)
- return _regression_train_wrapper(sc, train, LassoModel, data, initialWeights)
+ return _regression_train_wrapper(train, LassoModel, data, initialWeights)
class RidgeRegressionModel(LinearRegressionModelBase):
@@ -260,21 +254,21 @@ class RidgeRegressionModel(LinearRegressionModelBase):
class RidgeRegressionWithSGD(object):
@classmethod
- def train(cls, data, iterations=100, step=1.0, regParam=1.0,
+ def train(cls, data, iterations=100, step=1.0, regParam=0.01,
miniBatchFraction=1.0, initialWeights=None):
"""Train a ridge regression model on the given data."""
- sc = data.context
-
- def train(jrdd, i):
- return sc._jvm.PythonMLLibAPI().trainRidgeModelWithSGD(
- jrdd, iterations, step, regParam, miniBatchFraction, i)
+ def train(rdd, i):
+ return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step),
+ float(regParam), float(miniBatchFraction), i)
- return _regression_train_wrapper(sc, train, RidgeRegressionModel, data, initialWeights)
+ return _regression_train_wrapper(train, RidgeRegressionModel, data, initialWeights)
def _test():
import doctest
- globs = globals().copy()
+ from pyspark import SparkContext
+ import pyspark.mllib.regression
+ globs = pyspark.mllib.regression.__dict__.copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
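Because this change turns regularization off by default (regParam=0.0, regType=None), callers who want the old ridge-style behaviour now have to ask for it explicitly. A minimal sketch, assuming a live SparkContext named sc:

    from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD

    data = sc.parallelize([LabeledPoint(0.0, [0.0]), LabeledPoint(1.0, [1.0]),
                           LabeledPoint(2.0, [2.0]), LabeledPoint(3.0, [3.0])])

    # opt back in to L2 regularization; with the new defaults none is applied
    model = LinearRegressionWithSGD.train(data, iterations=100, step=0.1,
                                          regType="l2", regParam=0.01, intercept=True)
    print model.weights, model.intercept
    print model.predict([1.5])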
diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py
deleted file mode 100644
index 84baf12b906df..0000000000000
--- a/python/pyspark/mllib/stat.py
+++ /dev/null
@@ -1,190 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""
-Python package for statistical functions in MLlib.
-"""
-
-from functools import wraps
-
-from pyspark import PickleSerializer
-from pyspark.mllib.linalg import _convert_to_vector, _to_java_object_rdd
-
-
-__all__ = ['MultivariateStatisticalSummary', 'Statistics']
-
-
-def serialize(f):
- ser = PickleSerializer()
-
- @wraps(f)
- def func(self):
- jvec = f(self)
- bytes = self._sc._jvm.SerDe.dumps(jvec)
- return ser.loads(str(bytes)).toArray()
-
- return func
-
-
-class MultivariateStatisticalSummary(object):
-
- """
- Trait for multivariate statistical summary of a data matrix.
- """
-
- def __init__(self, sc, java_summary):
- """
- :param sc: Spark context
- :param java_summary: Handle to Java summary object
- """
- self._sc = sc
- self._java_summary = java_summary
-
- def __del__(self):
- self._sc._gateway.detach(self._java_summary)
-
- @serialize
- def mean(self):
- return self._java_summary.mean()
-
- @serialize
- def variance(self):
- return self._java_summary.variance()
-
- def count(self):
- return self._java_summary.count()
-
- @serialize
- def numNonzeros(self):
- return self._java_summary.numNonzeros()
-
- @serialize
- def max(self):
- return self._java_summary.max()
-
- @serialize
- def min(self):
- return self._java_summary.min()
-
-
-class Statistics(object):
-
- @staticmethod
- def colStats(rdd):
- """
- Computes column-wise summary statistics for the input RDD[Vector].
-
- >>> from pyspark.mllib.linalg import Vectors
- >>> rdd = sc.parallelize([Vectors.dense([2, 0, 0, -2]),
- ... Vectors.dense([4, 5, 0, 3]),
- ... Vectors.dense([6, 7, 0, 8])])
- >>> cStats = Statistics.colStats(rdd)
- >>> cStats.mean()
- array([ 4., 4., 0., 3.])
- >>> cStats.variance()
- array([ 4., 13., 0., 25.])
- >>> cStats.count()
- 3L
- >>> cStats.numNonzeros()
- array([ 3., 2., 0., 3.])
- >>> cStats.max()
- array([ 6., 7., 0., 8.])
- >>> cStats.min()
- array([ 2., 0., 0., -2.])
- """
- sc = rdd.ctx
- jrdd = _to_java_object_rdd(rdd.map(_convert_to_vector))
- cStats = sc._jvm.PythonMLLibAPI().colStats(jrdd)
- return MultivariateStatisticalSummary(sc, cStats)
-
- @staticmethod
- def corr(x, y=None, method=None):
- """
- Compute the correlation (matrix) for the input RDD(s) using the
- specified method.
- Methods currently supported: I{pearson (default), spearman}.
-
- If a single RDD of Vectors is passed in, a correlation matrix
- comparing the columns in the input RDD is returned. Use C{method=}
- to specify the method to be used for single RDD inout.
- If two RDDs of floats are passed in, a single float is returned.
-
- >>> x = sc.parallelize([1.0, 0.0, -2.0], 2)
- >>> y = sc.parallelize([4.0, 5.0, 3.0], 2)
- >>> zeros = sc.parallelize([0.0, 0.0, 0.0], 2)
- >>> abs(Statistics.corr(x, y) - 0.6546537) < 1e-7
- True
- >>> Statistics.corr(x, y) == Statistics.corr(x, y, "pearson")
- True
- >>> Statistics.corr(x, y, "spearman")
- 0.5
- >>> from math import isnan
- >>> isnan(Statistics.corr(x, zeros))
- True
- >>> from pyspark.mllib.linalg import Vectors
- >>> rdd = sc.parallelize([Vectors.dense([1, 0, 0, -2]), Vectors.dense([4, 5, 0, 3]),
- ... Vectors.dense([6, 7, 0, 8]), Vectors.dense([9, 0, 0, 1])])
- >>> pearsonCorr = Statistics.corr(rdd)
- >>> print str(pearsonCorr).replace('nan', 'NaN')
- [[ 1. 0.05564149 NaN 0.40047142]
- [ 0.05564149 1. NaN 0.91359586]
- [ NaN NaN 1. NaN]
- [ 0.40047142 0.91359586 NaN 1. ]]
- >>> spearmanCorr = Statistics.corr(rdd, method="spearman")
- >>> print str(spearmanCorr).replace('nan', 'NaN')
- [[ 1. 0.10540926 NaN 0.4 ]
- [ 0.10540926 1. NaN 0.9486833 ]
- [ NaN NaN 1. NaN]
- [ 0.4 0.9486833 NaN 1. ]]
- >>> try:
- ... Statistics.corr(rdd, "spearman")
- ... print "Method name as second argument without 'method=' shouldn't be allowed."
- ... except TypeError:
- ... pass
- """
- sc = x.ctx
- # Check inputs to determine whether a single value or a matrix is needed for output.
- # Since it's legal for users to use the method name as the second argument, we need to
- # check if y is used to specify the method name instead.
- if type(y) == str:
- raise TypeError("Use 'method=' to specify method name.")
-
- if not y:
- jx = _to_java_object_rdd(x.map(_convert_to_vector))
- resultMat = sc._jvm.PythonMLLibAPI().corr(jx, method)
- bytes = sc._jvm.SerDe.dumps(resultMat)
- ser = PickleSerializer()
- return ser.loads(str(bytes)).toArray()
- else:
- jx = _to_java_object_rdd(x.map(float))
- jy = _to_java_object_rdd(y.map(float))
- return sc._jvm.PythonMLLibAPI().corr(jx, jy, method)
-
-
-def _test():
- import doctest
- from pyspark import SparkContext
- globs = globals().copy()
- globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
- (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
- globs['sc'].stop()
- if failure_count:
- exit(-1)
-
-
-if __name__ == "__main__":
- _test()
diff --git a/python/pyspark/mllib/stat/__init__.py b/python/pyspark/mllib/stat/__init__.py
new file mode 100644
index 0000000000000..b686d955a0080
--- /dev/null
+++ b/python/pyspark/mllib/stat/__init__.py
@@ -0,0 +1,25 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Python package for statistical functions in MLlib.
+"""
+
+from pyspark.mllib.stat._statistics import *
+from pyspark.mllib.stat.distribution import MultivariateGaussian
+
+__all__ = ["Statistics", "MultivariateStatisticalSummary", "MultivariateGaussian"]
diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py
new file mode 100644
index 0000000000000..218ac148ca992
--- /dev/null
+++ b/python/pyspark/mllib/stat/_statistics.py
@@ -0,0 +1,247 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark import RDD
+from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
+from pyspark.mllib.linalg import Matrix, _convert_to_vector
+from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.stat.test import ChiSqTestResult
+
+
+__all__ = ['MultivariateStatisticalSummary', 'Statistics']
+
+
+class MultivariateStatisticalSummary(JavaModelWrapper):
+
+ """
+ Trait for multivariate statistical summary of a data matrix.
+ """
+
+ def mean(self):
+ return self.call("mean").toArray()
+
+ def variance(self):
+ return self.call("variance").toArray()
+
+ def count(self):
+ return self.call("count")
+
+ def numNonzeros(self):
+ return self.call("numNonzeros").toArray()
+
+ def max(self):
+ return self.call("max").toArray()
+
+ def min(self):
+ return self.call("min").toArray()
+
+
+class Statistics(object):
+
+ @staticmethod
+ def colStats(rdd):
+ """
+ Computes column-wise summary statistics for the input RDD[Vector].
+
+ :param rdd: an RDD[Vector] for which column-wise summary statistics
+ are to be computed.
+ :return: :class:`MultivariateStatisticalSummary` object containing
+ column-wise summary statistics.
+
+ >>> from pyspark.mllib.linalg import Vectors
+ >>> rdd = sc.parallelize([Vectors.dense([2, 0, 0, -2]),
+ ... Vectors.dense([4, 5, 0, 3]),
+ ... Vectors.dense([6, 7, 0, 8])])
+ >>> cStats = Statistics.colStats(rdd)
+ >>> cStats.mean()
+ array([ 4., 4., 0., 3.])
+ >>> cStats.variance()
+ array([ 4., 13., 0., 25.])
+ >>> cStats.count()
+ 3L
+ >>> cStats.numNonzeros()
+ array([ 3., 2., 0., 3.])
+ >>> cStats.max()
+ array([ 6., 7., 0., 8.])
+ >>> cStats.min()
+ array([ 2., 0., 0., -2.])
+ """
+ cStats = callMLlibFunc("colStats", rdd.map(_convert_to_vector))
+ return MultivariateStatisticalSummary(cStats)
+
+ @staticmethod
+ def corr(x, y=None, method=None):
+ """
+ Compute the correlation (matrix) for the input RDD(s) using the
+ specified method.
+ Methods currently supported: I{pearson (default), spearman}.
+
+ If a single RDD of Vectors is passed in, a correlation matrix
+ comparing the columns in the input RDD is returned. Use C{method=}
+        to specify the method to be used for single RDD input.
+ If two RDDs of floats are passed in, a single float is returned.
+
+ :param x: an RDD of vector for which the correlation matrix is to be computed,
+ or an RDD of float of the same cardinality as y when y is specified.
+ :param y: an RDD of float of the same cardinality as x.
+ :param method: String specifying the method to use for computing correlation.
+ Supported: `pearson` (default), `spearman`
+ :return: Correlation matrix comparing columns in x.
+
+ >>> x = sc.parallelize([1.0, 0.0, -2.0], 2)
+ >>> y = sc.parallelize([4.0, 5.0, 3.0], 2)
+ >>> zeros = sc.parallelize([0.0, 0.0, 0.0], 2)
+ >>> abs(Statistics.corr(x, y) - 0.6546537) < 1e-7
+ True
+ >>> Statistics.corr(x, y) == Statistics.corr(x, y, "pearson")
+ True
+ >>> Statistics.corr(x, y, "spearman")
+ 0.5
+ >>> from math import isnan
+ >>> isnan(Statistics.corr(x, zeros))
+ True
+ >>> from pyspark.mllib.linalg import Vectors
+ >>> rdd = sc.parallelize([Vectors.dense([1, 0, 0, -2]), Vectors.dense([4, 5, 0, 3]),
+ ... Vectors.dense([6, 7, 0, 8]), Vectors.dense([9, 0, 0, 1])])
+ >>> pearsonCorr = Statistics.corr(rdd)
+ >>> print str(pearsonCorr).replace('nan', 'NaN')
+ [[ 1. 0.05564149 NaN 0.40047142]
+ [ 0.05564149 1. NaN 0.91359586]
+ [ NaN NaN 1. NaN]
+ [ 0.40047142 0.91359586 NaN 1. ]]
+ >>> spearmanCorr = Statistics.corr(rdd, method="spearman")
+ >>> print str(spearmanCorr).replace('nan', 'NaN')
+ [[ 1. 0.10540926 NaN 0.4 ]
+ [ 0.10540926 1. NaN 0.9486833 ]
+ [ NaN NaN 1. NaN]
+ [ 0.4 0.9486833 NaN 1. ]]
+ >>> try:
+ ... Statistics.corr(rdd, "spearman")
+ ... print "Method name as second argument without 'method=' shouldn't be allowed."
+ ... except TypeError:
+ ... pass
+ """
+ # Check inputs to determine whether a single value or a matrix is needed for output.
+ # Since it's legal for users to use the method name as the second argument, we need to
+ # check if y is used to specify the method name instead.
+ if type(y) == str:
+ raise TypeError("Use 'method=' to specify method name.")
+
+ if not y:
+ return callMLlibFunc("corr", x.map(_convert_to_vector), method).toArray()
+ else:
+ return callMLlibFunc("corr", x.map(float), y.map(float), method)
+
+ @staticmethod
+ def chiSqTest(observed, expected=None):
+ """
+ .. note:: Experimental
+
+        If `observed` is a Vector, conduct Pearson's chi-squared goodness
+        of fit test of the observed data against the expected distribution,
+        or against the uniform distribution (by default), with each category
+ having an expected frequency of `1 / len(observed)`.
+ (Note: `observed` cannot contain negative values)
+
+        If `observed` is a matrix, conduct Pearson's independence test on the
+ input contingency matrix, which cannot contain negative entries or
+ columns or rows that sum up to 0.
+
+ If `observed` is an RDD of LabeledPoint, conduct Pearson's independence
+ test for every feature against the label across the input RDD.
+ For each feature, the (feature, label) pairs are converted into a
+ contingency matrix for which the chi-squared statistic is computed.
+ All label and feature values must be categorical.
+
+ :param observed: it could be a vector containing the observed categorical
+ counts/relative frequencies, or the contingency matrix
+ (containing either counts or relative frequencies),
+ or an RDD of LabeledPoint containing the labeled dataset
+ with categorical features. Real-valued features will be
+ treated as categorical for each distinct value.
+ :param expected: Vector containing the expected categorical counts/relative
+ frequencies. `expected` is rescaled if the `expected` sum
+ differs from the `observed` sum.
+ :return: ChiSquaredTest object containing the test statistic, degrees
+ of freedom, p-value, the method used, and the null hypothesis.
+
+ >>> from pyspark.mllib.linalg import Vectors, Matrices
+ >>> observed = Vectors.dense([4, 6, 5])
+ >>> pearson = Statistics.chiSqTest(observed)
+ >>> print pearson.statistic
+ 0.4
+ >>> pearson.degreesOfFreedom
+ 2
+ >>> print round(pearson.pValue, 4)
+ 0.8187
+ >>> pearson.method
+ u'pearson'
+ >>> pearson.nullHypothesis
+ u'observed follows the same distribution as expected.'
+
+ >>> observed = Vectors.dense([21, 38, 43, 80])
+ >>> expected = Vectors.dense([3, 5, 7, 20])
+ >>> pearson = Statistics.chiSqTest(observed, expected)
+ >>> print round(pearson.pValue, 4)
+ 0.0027
+
+ >>> data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0]
+ >>> chi = Statistics.chiSqTest(Matrices.dense(3, 4, data))
+ >>> print round(chi.statistic, 4)
+ 21.9958
+
+ >>> data = [LabeledPoint(0.0, Vectors.dense([0.5, 10.0])),
+ ... LabeledPoint(0.0, Vectors.dense([1.5, 20.0])),
+ ... LabeledPoint(1.0, Vectors.dense([1.5, 30.0])),
+ ... LabeledPoint(0.0, Vectors.dense([3.5, 30.0])),
+ ... LabeledPoint(0.0, Vectors.dense([3.5, 40.0])),
+ ... LabeledPoint(1.0, Vectors.dense([3.5, 40.0])),]
+ >>> rdd = sc.parallelize(data, 4)
+ >>> chi = Statistics.chiSqTest(rdd)
+ >>> print chi[0].statistic
+ 0.75
+ >>> print chi[1].statistic
+ 1.5
+ """
+ if isinstance(observed, RDD):
+ if not isinstance(observed.first(), LabeledPoint):
+ raise ValueError("observed should be an RDD of LabeledPoint")
+ jmodels = callMLlibFunc("chiSqTest", observed)
+ return [ChiSqTestResult(m) for m in jmodels]
+
+ if isinstance(observed, Matrix):
+ jmodel = callMLlibFunc("chiSqTest", observed)
+ else:
+ if expected and len(expected) != len(observed):
+ raise ValueError("`expected` should have same length with `observed`")
+ jmodel = callMLlibFunc("chiSqTest", _convert_to_vector(observed), expected)
+ return ChiSqTestResult(jmodel)
+
+
+def _test():
+ import doctest
+ from pyspark import SparkContext
+ globs = globals().copy()
+ globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+
+if __name__ == "__main__":
+ _test()
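A compact, hedged example of the two entry points defined in this file (corr and chiSqTest), assuming a SparkContext named sc and the linalg types referenced above:

    from pyspark.mllib.stat import Statistics
    from pyspark.mllib.linalg import Vectors

    # correlation between two RDDs of doubles
    x = sc.parallelize([1.0, 2.0, 3.0, 4.0])
    y = sc.parallelize([2.0, 4.0, 6.0, 8.0])
    print Statistics.corr(x, y, method="pearson")    # close to 1.0 for perfectly correlated data

    # goodness-of-fit test against the (default) uniform expected distribution
    result = Statistics.chiSqTest(Vectors.dense([4, 6, 5]))
    print result.statistic, result.degreesOfFreedom, round(result.pValue, 4)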
diff --git a/python/pyspark/mllib/stat/distribution.py b/python/pyspark/mllib/stat/distribution.py
new file mode 100644
index 0000000000000..07792e1532046
--- /dev/null
+++ b/python/pyspark/mllib/stat/distribution.py
@@ -0,0 +1,31 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from collections import namedtuple
+
+__all__ = ['MultivariateGaussian']
+
+
+class MultivariateGaussian(namedtuple('MultivariateGaussian', ['mu', 'sigma'])):
+
+ """ Represents a (mu, sigma) tuple
+ >>> m = MultivariateGaussian(Vectors.dense([11,12]),DenseMatrix(2, 2, (1.0, 3.0, 5.0, 2.0)))
+ >>> (m.mu, m.sigma.toArray())
+ (DenseVector([11.0, 12.0]), array([[ 1., 5.],[ 3., 2.]]))
+ >>> (m[0], m[1])
+ (DenseVector([11.0, 12.0]), array([[ 1., 5.],[ 3., 2.]]))
+ """
diff --git a/python/pyspark/mllib/stat/test.py b/python/pyspark/mllib/stat/test.py
new file mode 100644
index 0000000000000..762506e952b43
--- /dev/null
+++ b/python/pyspark/mllib/stat/test.py
@@ -0,0 +1,69 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.mllib.common import JavaModelWrapper
+
+
+__all__ = ["ChiSqTestResult"]
+
+
+class ChiSqTestResult(JavaModelWrapper):
+ """
+ .. note:: Experimental
+
+ Object containing the test results for the chi-squared hypothesis test.
+ """
+ @property
+ def method(self):
+ """
+ Name of the test method
+ """
+ return self._java_model.method()
+
+ @property
+ def pValue(self):
+ """
+ The probability of obtaining a test statistic result at least as
+ extreme as the one that was actually observed, assuming that the
+ null hypothesis is true.
+ """
+ return self._java_model.pValue()
+
+ @property
+ def degreesOfFreedom(self):
+ """
+ Returns the degree(s) of freedom of the hypothesis test.
+        Return type should be Number (e.g. Int, Double) or tuples of Numbers.
+ """
+ return self._java_model.degreesOfFreedom()
+
+ @property
+ def statistic(self):
+ """
+ Test statistic.
+ """
+ return self._java_model.statistic()
+
+ @property
+ def nullHypothesis(self):
+ """
+ Null hypothesis of the test.
+ """
+ return self._java_model.nullHypothesis()
+
+ def __str__(self):
+ return self._java_model.toString()
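Since ChiSqTestResult is a thin JavaModelWrapper, its values are read through the properties above. A hedged sketch of the per-feature independence test (assuming a SparkContext named sc):

    from pyspark.mllib.stat import Statistics
    from pyspark.mllib.regression import LabeledPoint
    from pyspark.mllib.linalg import Vectors

    data = sc.parallelize([LabeledPoint(0.0, Vectors.dense([0.5, 10.0])),
                           LabeledPoint(0.0, Vectors.dense([1.5, 20.0])),
                           LabeledPoint(1.0, Vectors.dense([1.5, 30.0])),
                           LabeledPoint(1.0, Vectors.dense([3.5, 40.0]))])

    results = Statistics.chiSqTest(data)     # one ChiSqTestResult per feature
    for r in results:
        print r.method, r.statistic, r.degreesOfFreedom, round(r.pValue, 4)
    print str(results[0])                    # __str__ delegates to the JVM toString()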
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index d6fb87b378b4a..49e5c9d58e5db 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -23,6 +23,7 @@
import array as pyarray
from numpy import array, array_equal
+from py4j.protocol import Py4JJavaError
if sys.version_info[:2] <= (2, 6):
try:
@@ -33,14 +34,15 @@
else:
import unittest
-from pyspark.serializers import PickleSerializer
-from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector
+from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
+ DenseMatrix, Vectors, Matrices
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
+from pyspark.serializers import PickleSerializer
+from pyspark.sql import SQLContext
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
-
_have_scipy = False
try:
import scipy.sparse
@@ -62,6 +64,7 @@ def _squared_distance(a, b):
class VectorTests(PySparkTestCase):
def _test_serialize(self, v):
+ self.assertEqual(v, ser.loads(ser.dumps(v)))
jvec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v)))
nv = ser.loads(str(self.sc._jvm.SerDe.dumps(jvec)))
self.assertEqual(v, nv)
@@ -75,6 +78,8 @@ def test_serialize(self):
self._test_serialize(DenseVector(array([1., 2., 3., 4.])))
self._test_serialize(DenseVector(pyarray.array('d', range(10))))
self._test_serialize(SparseVector(4, {1: 1, 3: 2}))
+ self._test_serialize(SparseVector(3, {}))
+ self._test_serialize(DenseMatrix(2, 3, range(6)))
def test_dot(self):
sv = SparseVector(4, {1: 1, 3: 2})
@@ -105,6 +110,28 @@ def test_squared_distance(self):
self.assertEquals(0.0, _squared_distance(dv, dv))
self.assertEquals(0.0, _squared_distance(lst, lst))
+ def test_conversion(self):
+ # numpy arrays should be automatically upcast to float64
+ # tests for fix of [SPARK-5089]
+ v = array([1, 2, 3, 4], dtype='float64')
+ dv = DenseVector(v)
+ self.assertTrue(dv.array.dtype == 'float64')
+ v = array([1, 2, 3, 4], dtype='float32')
+ dv = DenseVector(v)
+ self.assertTrue(dv.array.dtype == 'float64')
+
+ def test_sparse_vector_indexing(self):
+ sv = SparseVector(4, {1: 1, 3: 2})
+ self.assertEquals(sv[0], 0.)
+ self.assertEquals(sv[3], 2.)
+ self.assertEquals(sv[1], 1.)
+ self.assertEquals(sv[2], 0.)
+ self.assertEquals(sv[-1], 2)
+ self.assertEquals(sv[-2], 0)
+ self.assertEquals(sv[-4], 0)
+ for ind in [4, -5, 7.8]:
+ self.assertRaises(ValueError, sv.__getitem__, ind)
+
class ListTests(PySparkTestCase):
@@ -113,7 +140,7 @@ class ListTests(PySparkTestCase):
as NumPy arrays.
"""
- def test_clustering(self):
+ def test_kmeans(self):
from pyspark.mllib.clustering import KMeans
data = [
[0, 1.1],
@@ -125,9 +152,50 @@ def test_clustering(self):
self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1]))
self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3]))
+ def test_kmeans_deterministic(self):
+ from pyspark.mllib.clustering import KMeans
+ X = range(0, 100, 10)
+ Y = range(0, 100, 10)
+ data = [[x, y] for x, y in zip(X, Y)]
+ clusters1 = KMeans.train(self.sc.parallelize(data),
+ 3, initializationMode="k-means||", seed=42)
+ clusters2 = KMeans.train(self.sc.parallelize(data),
+ 3, initializationMode="k-means||", seed=42)
+ centers1 = clusters1.centers
+ centers2 = clusters2.centers
+ for c1, c2 in zip(centers1, centers2):
+ # TODO: Allow small numeric difference.
+ self.assertTrue(array_equal(c1, c2))
+
+ def test_gmm(self):
+ from pyspark.mllib.clustering import GaussianMixture
+ data = self.sc.parallelize([
+ [1, 2],
+ [8, 9],
+ [-4, -3],
+ [-6, -7],
+ ])
+ clusters = GaussianMixture.train(data, 2, convergenceTol=0.001,
+ maxIterations=100, seed=56)
+ labels = clusters.predict(data).collect()
+ self.assertEquals(labels[0], labels[1])
+ self.assertEquals(labels[2], labels[3])
+
+ def test_gmm_deterministic(self):
+ from pyspark.mllib.clustering import GaussianMixture
+ x = range(0, 100, 10)
+ y = range(0, 100, 10)
+ data = self.sc.parallelize([[a, b] for a, b in zip(x, y)])
+ clusters1 = GaussianMixture.train(data, 5, convergenceTol=0.001,
+ maxIterations=100, seed=63)
+ clusters2 = GaussianMixture.train(data, 5, convergenceTol=0.001,
+ maxIterations=100, seed=63)
+ for c1, c2 in zip(clusters1.weights, clusters2.weights):
+ self.assertEquals(round(c1, 7), round(c2, 7))
+
def test_classification(self):
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
- from pyspark.mllib.tree import DecisionTree
+ from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees
data = [
LabeledPoint(0.0, [1, 0, 0]),
LabeledPoint(1.0, [0, 1, 1]),
@@ -156,18 +224,31 @@ def test_classification(self):
self.assertTrue(nb_model.predict(features[3]) > 0)
categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories
- dt_model = \
- DecisionTree.trainClassifier(rdd, numClasses=2,
- categoricalFeaturesInfo=categoricalFeaturesInfo)
+ dt_model = DecisionTree.trainClassifier(
+ rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo)
self.assertTrue(dt_model.predict(features[0]) <= 0)
self.assertTrue(dt_model.predict(features[1]) > 0)
self.assertTrue(dt_model.predict(features[2]) <= 0)
self.assertTrue(dt_model.predict(features[3]) > 0)
+ rf_model = RandomForest.trainClassifier(
+ rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100)
+ self.assertTrue(rf_model.predict(features[0]) <= 0)
+ self.assertTrue(rf_model.predict(features[1]) > 0)
+ self.assertTrue(rf_model.predict(features[2]) <= 0)
+ self.assertTrue(rf_model.predict(features[3]) > 0)
+
+ gbt_model = GradientBoostedTrees.trainClassifier(
+ rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
+ self.assertTrue(gbt_model.predict(features[0]) <= 0)
+ self.assertTrue(gbt_model.predict(features[1]) > 0)
+ self.assertTrue(gbt_model.predict(features[2]) <= 0)
+ self.assertTrue(gbt_model.predict(features[3]) > 0)
+
def test_regression(self):
from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \
RidgeRegressionWithSGD
- from pyspark.mllib.tree import DecisionTree
+ from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees
data = [
LabeledPoint(-1.0, [0, -1]),
LabeledPoint(1.0, [0, 1]),
@@ -196,13 +277,27 @@ def test_regression(self):
self.assertTrue(rr_model.predict(features[3]) > 0)
categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories
- dt_model = \
- DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
+ dt_model = DecisionTree.trainRegressor(
+ rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
self.assertTrue(dt_model.predict(features[0]) <= 0)
self.assertTrue(dt_model.predict(features[1]) > 0)
self.assertTrue(dt_model.predict(features[2]) <= 0)
self.assertTrue(dt_model.predict(features[3]) > 0)
+ rf_model = RandomForest.trainRegressor(
+ rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100, seed=1)
+ self.assertTrue(rf_model.predict(features[0]) <= 0)
+ self.assertTrue(rf_model.predict(features[1]) > 0)
+ self.assertTrue(rf_model.predict(features[2]) <= 0)
+ self.assertTrue(rf_model.predict(features[3]) > 0)
+
+ gbt_model = GradientBoostedTrees.trainRegressor(
+ rdd, categoricalFeaturesInfo=categoricalFeaturesInfo)
+ self.assertTrue(gbt_model.predict(features[0]) <= 0)
+ self.assertTrue(gbt_model.predict(features[1]) > 0)
+ self.assertTrue(gbt_model.predict(features[2]) <= 0)
+ self.assertTrue(gbt_model.predict(features[3]) > 0)
+
class StatTests(PySparkTestCase):
# SPARK-4023
@@ -221,6 +316,39 @@ def test_col_with_different_rdds(self):
self.assertEqual(10, summary.count())
+class VectorUDTTests(PySparkTestCase):
+
+ dv0 = DenseVector([])
+ dv1 = DenseVector([1.0, 2.0])
+ sv0 = SparseVector(2, [], [])
+ sv1 = SparseVector(2, [1], [2.0])
+ udt = VectorUDT()
+
+ def test_json_schema(self):
+ self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt)
+
+ def test_serialization(self):
+ for v in [self.dv0, self.dv1, self.sv0, self.sv1]:
+ self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v)))
+
+ def test_infer_schema(self):
+ sqlCtx = SQLContext(self.sc)
+ rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)])
+ srdd = sqlCtx.inferSchema(rdd)
+ schema = srdd.schema()
+ field = [f for f in schema.fields if f.name == "features"][0]
+ self.assertEqual(field.dataType, self.udt)
+ vectors = srdd.map(lambda p: p.features).collect()
+ self.assertEqual(len(vectors), 2)
+ for v in vectors:
+ if isinstance(v, SparseVector):
+ self.assertEqual(v, self.sv1)
+ elif isinstance(v, DenseVector):
+ self.assertEqual(v, self.dv1)
+ else:
+ raise ValueError("expecting a vector but got %r of type %r" % (v, type(v)))
+
+
@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(PySparkTestCase):
@@ -363,6 +491,103 @@ def test_regression(self):
self.assertTrue(dt_model.predict(features[3]) > 0)
+class ChiSqTestTests(PySparkTestCase):
+ def test_goodness_of_fit(self):
+ from numpy import inf
+
+ observed = Vectors.dense([4, 6, 5])
+ pearson = Statistics.chiSqTest(observed)
+
+ # Validated against the R command `chisq.test(c(4, 6, 5), p=c(1/3, 1/3, 1/3))`
+ self.assertEqual(pearson.statistic, 0.4)
+ self.assertEqual(pearson.degreesOfFreedom, 2)
+ self.assertAlmostEqual(pearson.pValue, 0.8187, 4)
+
+ # Different expected and observed sum
+ observed1 = Vectors.dense([21, 38, 43, 80])
+ expected1 = Vectors.dense([3, 5, 7, 20])
+ pearson1 = Statistics.chiSqTest(observed1, expected1)
+
+ # Results validated against the R command
+ # `chisq.test(c(21, 38, 43, 80), p=c(3/35, 1/7, 1/5, 4/7))`
+ self.assertAlmostEqual(pearson1.statistic, 14.1429, 4)
+ self.assertEqual(pearson1.degreesOfFreedom, 3)
+ self.assertAlmostEqual(pearson1.pValue, 0.002717, 4)
+
+ # Vectors with different sizes
+ observed3 = Vectors.dense([1.0, 2.0, 3.0])
+ expected3 = Vectors.dense([1.0, 2.0, 3.0, 4.0])
+ self.assertRaises(ValueError, Statistics.chiSqTest, observed3, expected3)
+
+ # Negative counts in observed
+ neg_obs = Vectors.dense([1.0, 2.0, 3.0, -4.0])
+ self.assertRaises(Py4JJavaError, Statistics.chiSqTest, neg_obs, expected1)
+
+ # Count = 0.0 in expected but not observed
+ zero_expected = Vectors.dense([1.0, 0.0, 3.0])
+ pearson_inf = Statistics.chiSqTest(observed, zero_expected)
+ self.assertEqual(pearson_inf.statistic, inf)
+ self.assertEqual(pearson_inf.degreesOfFreedom, 2)
+ self.assertEqual(pearson_inf.pValue, 0.0)
+
+ # 0.0 in expected and observed simultaneously
+ zero_observed = Vectors.dense([2.0, 0.0, 1.0])
+ self.assertRaises(Py4JJavaError, Statistics.chiSqTest, zero_observed, zero_expected)
+
+ def test_matrix_independence(self):
+ data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0]
+ chi = Statistics.chiSqTest(Matrices.dense(3, 4, data))
+
+ # Results validated against R command
+ # `chisq.test(rbind(c(40, 56, 31, 30),c(24, 32, 10, 15), c(29, 42, 0, 12)))`
+ self.assertAlmostEqual(chi.statistic, 21.9958, 4)
+ self.assertEqual(chi.degreesOfFreedom, 6)
+ self.assertAlmostEqual(chi.pValue, 0.001213, 4)
+
+ # Negative counts
+ neg_counts = Matrices.dense(2, 2, [4.0, 5.0, 3.0, -3.0])
+ self.assertRaises(Py4JJavaError, Statistics.chiSqTest, neg_counts)
+
+ # Row sum = 0.0
+ row_zero = Matrices.dense(2, 2, [0.0, 1.0, 0.0, 2.0])
+ self.assertRaises(Py4JJavaError, Statistics.chiSqTest, row_zero)
+
+ # Column sum = 0.0
+ col_zero = Matrices.dense(2, 2, [0.0, 0.0, 2.0, 2.0])
+ self.assertRaises(Py4JJavaError, Statistics.chiSqTest, col_zero)
+
+ def test_chi_sq_pearson(self):
+ data = [
+ LabeledPoint(0.0, Vectors.dense([0.5, 10.0])),
+ LabeledPoint(0.0, Vectors.dense([1.5, 20.0])),
+ LabeledPoint(1.0, Vectors.dense([1.5, 30.0])),
+ LabeledPoint(0.0, Vectors.dense([3.5, 30.0])),
+ LabeledPoint(0.0, Vectors.dense([3.5, 40.0])),
+ LabeledPoint(1.0, Vectors.dense([3.5, 40.0]))
+ ]
+
+ for numParts in [2, 4, 6, 8]:
+ chi = Statistics.chiSqTest(self.sc.parallelize(data, numParts))
+ feature1 = chi[0]
+ self.assertEqual(feature1.statistic, 0.75)
+ self.assertEqual(feature1.degreesOfFreedom, 2)
+ self.assertAlmostEqual(feature1.pValue, 0.6873, 4)
+
+ feature2 = chi[1]
+ self.assertEqual(feature2.statistic, 1.5)
+ self.assertEqual(feature2.degreesOfFreedom, 3)
+ self.assertAlmostEqual(feature2.pValue, 0.6823, 4)
+
+ def test_right_number_of_results(self):
+ num_cols = 1001
+ sparse_data = [
+ LabeledPoint(0.0, Vectors.sparse(num_cols, [(100, 2.0)])),
+ LabeledPoint(0.1, Vectors.sparse(num_cols, [(200, 1.0)]))
+ ]
+ chi = Statistics.chiSqTest(self.sc.parallelize(sparse_data))
+ self.assertEqual(len(chi), num_cols)
+ self.assertIsNotNone(chi[1000])
+
if __name__ == "__main__":
if not _have_scipy:
print "NOTE: Skipping SciPy tests as it does not seem to be installed"
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index 64ee79d83e849..aae48f213246b 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -15,36 +15,58 @@
# limitations under the License.
#
-from py4j.java_collections import MapConverter
+from __future__ import absolute_import
+
+import random
from pyspark import SparkContext, RDD
-from pyspark.serializers import BatchedSerializer, PickleSerializer
-from pyspark.mllib.linalg import Vector, _convert_to_vector, _to_java_object_rdd
+from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
+from pyspark.mllib.linalg import _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
-__all__ = ['DecisionTreeModel', 'DecisionTree']
+__all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel',
+ 'RandomForest', 'GradientBoostedTrees']
-class DecisionTreeModel(object):
+class TreeEnsembleModel(JavaModelWrapper):
+ def predict(self, x):
+ """
+ Predict values for a single data point or an RDD of points using
+ the model trained.
+ """
+ if isinstance(x, RDD):
+ return self.call("predict", x.map(_convert_to_vector))
- """
- A decision tree model for classification or regression.
+ else:
+ return self.call("predict", _convert_to_vector(x))
- EXPERIMENTAL: This is an experimental API.
- It will probably be modified for Spark v1.2.
- """
+ def numTrees(self):
+ """
+ Get number of trees in ensemble.
+ """
+ return self.call("numTrees")
- def __init__(self, sc, java_model):
+ def totalNumNodes(self):
"""
- :param sc: Spark context
- :param java_model: Handle to Java model object
+ Get total number of nodes, summed over all trees in the ensemble.
"""
- self._sc = sc
- self._java_model = java_model
+ return self.call("totalNumNodes")
- def __del__(self):
- self._sc._gateway.detach(self._java_model)
+ def __repr__(self):
+ """ Summary of model """
+ return self._java_model.toString()
+ def toDebugString(self):
+ """ Full model """
+ return self._java_model.toDebugString()
+
+
+class DecisionTreeModel(JavaModelWrapper):
+ """
+ .. note:: Experimental
+
+ A decision tree model for classification or regression.
+ """
def predict(self, x):
"""
Predict the label of one or more examples.
@@ -52,24 +74,11 @@ def predict(self, x):
:param x: Data point (feature vector),
or an RDD of data points (feature vectors).
"""
- SerDe = self._sc._jvm.SerDe
- ser = PickleSerializer()
if isinstance(x, RDD):
- # Bulk prediction
- first = x.take(1)
- if not first:
- return self._sc.parallelize([])
- if not isinstance(first[0], Vector):
- x = x.map(_convert_to_vector)
- jPred = self._java_model.predict(_to_java_object_rdd(x)).toJavaRDD()
- jpyrdd = self._sc._jvm.SerDe.javaToPython(jPred)
- return RDD(jpyrdd, self._sc, BatchedSerializer(ser, 1024))
+ return self.call("predict", x.map(_convert_to_vector))
else:
- # Assume x is a single data point.
- bytes = bytearray(ser.dumps(_convert_to_vector(x)))
- vec = self._sc._jvm.SerDe.loads(bytes)
- return self._java_model.predict(vec)
+ return self.call("predict", _convert_to_vector(x))
def numNodes(self):
return self._java_model.numNodes()
@@ -78,42 +87,32 @@ def depth(self):
return self._java_model.depth()
def __repr__(self):
- """ Print summary of model. """
+ """ summary of model. """
return self._java_model.toString()
def toDebugString(self):
- """ Print full model. """
+ """ full model. """
return self._java_model.toDebugString()
class DecisionTree(object):
-
"""
- Learning algorithm for a decision tree model
- for classification or regression.
-
- EXPERIMENTAL: This is an experimental API.
- It will probably be modified for Spark v1.2.
+ .. note:: Experimental
+ Learning algorithm for a decision tree model for classification or regression.
"""
- @staticmethod
- def _train(data, type, numClasses, categoricalFeaturesInfo,
- impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1,
- minInfoGain=0.0):
+ @classmethod
+ def _train(cls, data, type, numClasses, features, impurity="gini", maxDepth=5, maxBins=32,
+ minInstancesPerNode=1, minInfoGain=0.0):
first = data.first()
assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint"
- sc = data.context
- jrdd = _to_java_object_rdd(data)
- cfiMap = MapConverter().convert(categoricalFeaturesInfo,
- sc._gateway._gateway_client)
- model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
- jrdd, type, numClasses, cfiMap,
- impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
- return DecisionTreeModel(sc, model)
-
- @staticmethod
- def trainClassifier(data, numClasses, categoricalFeaturesInfo,
+ model = callMLlibFunc("trainDecisionTreeModel", data, type, numClasses, features,
+ impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
+ return DecisionTreeModel(model)
+
+ @classmethod
+ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo,
impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0):
"""
@@ -131,8 +130,8 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
E.g., depth 0 means 1 leaf node.
Depth 1 means 1 internal node + 2 leaf nodes.
:param maxBins: Number of bins used for finding splits at each node.
- :param minInstancesPerNode: Min number of instances required at child nodes to create
- the parent split
+ :param minInstancesPerNode: Min number of instances required at child
+ nodes to create the parent split
:param minInfoGain: Min info gain required to create a split
:return: DecisionTreeModel
@@ -157,16 +156,19 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
Predict: 0.0
Else (feature 0 > 0.0)
Predict: 1.0
- >>> model.predict(array([1.0])) > 0
- True
- >>> model.predict(array([0.0])) == 0
- True
+ >>> model.predict(array([1.0]))
+ 1.0
+ >>> model.predict(array([0.0]))
+ 0.0
+ >>> rdd = sc.parallelize([[1.0], [0.0]])
+ >>> model.predict(rdd).collect()
+ [1.0, 0.0]
"""
- return DecisionTree._train(data, "classification", numClasses, categoricalFeaturesInfo,
- impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
+ return cls._train(data, "classification", numClasses, categoricalFeaturesInfo,
+ impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
- @staticmethod
- def trainRegressor(data, categoricalFeaturesInfo,
+ @classmethod
+ def trainRegressor(cls, data, categoricalFeaturesInfo,
impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0):
"""
@@ -183,14 +185,13 @@ def trainRegressor(data, categoricalFeaturesInfo,
E.g., depth 0 means 1 leaf node.
Depth 1 means 1 internal node + 2 leaf nodes.
:param maxBins: Number of bins used for finding splits at each node.
- :param minInstancesPerNode: Min number of instances required at child nodes to create
- the parent split
+ :param minInstancesPerNode: Min number of instances required at child
+ nodes to create the parent split
:param minInfoGain: Min info gain required to create a split
:return: DecisionTreeModel
Example usage:
- >>> from numpy import array
>>> from pyspark.mllib.regression import LabeledPoint
>>> from pyspark.mllib.tree import DecisionTree
>>> from pyspark.mllib.linalg import SparseVector
@@ -203,17 +204,312 @@ def trainRegressor(data, categoricalFeaturesInfo,
... ]
>>>
>>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), {})
- >>> model.predict(array([0.0, 1.0])) == 1
- True
- >>> model.predict(array([0.0, 0.0])) == 0
- True
- >>> model.predict(SparseVector(2, {1: 1.0})) == 1
- True
- >>> model.predict(SparseVector(2, {1: 0.0})) == 0
- True
- """
- return DecisionTree._train(data, "regression", 0, categoricalFeaturesInfo,
- impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
+ >>> model.predict(SparseVector(2, {1: 1.0}))
+ 1.0
+ >>> model.predict(SparseVector(2, {1: 0.0}))
+ 0.0
+ >>> rdd = sc.parallelize([[0.0, 1.0], [0.0, 0.0]])
+ >>> model.predict(rdd).collect()
+ [1.0, 0.0]
+ """
+ return cls._train(data, "regression", 0, categoricalFeaturesInfo,
+ impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
+
+
+class RandomForestModel(TreeEnsembleModel):
+ """
+ .. note:: Experimental
+
+ Represents a random forest model.
+ """
+
+
+class RandomForest(object):
+ """
+ .. note:: Experimental
+
+ Learning algorithm for a random forest model for classification or regression.
+ """
+
+ supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird")
+
+ @classmethod
+ def _train(cls, data, algo, numClasses, categoricalFeaturesInfo, numTrees,
+ featureSubsetStrategy, impurity, maxDepth, maxBins, seed):
+ first = data.first()
+ assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint"
+ if featureSubsetStrategy not in cls.supportedFeatureSubsetStrategies:
+ raise ValueError("unsupported featureSubsetStrategy: %s" % featureSubsetStrategy)
+ if seed is None:
+ seed = random.randint(0, 1 << 30)
+ model = callMLlibFunc("trainRandomForestModel", data, algo, numClasses,
+ categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity,
+ maxDepth, maxBins, seed)
+ return RandomForestModel(model)
+
+ @classmethod
+ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees,
+ featureSubsetStrategy="auto", impurity="gini", maxDepth=4, maxBins=32,
+ seed=None):
+ """
+        Method to train a random forest model for binary or multiclass
+ classification.
+
+ :param data: Training dataset: RDD of LabeledPoint. Labels should take
+ values {0, 1, ..., numClasses-1}.
+ :param numClasses: number of classes for classification.
+ :param categoricalFeaturesInfo: Map storing arity of categorical features.
+ E.g., an entry (n -> k) indicates that feature n is categorical
+ with k categories indexed from 0: {0, 1, ..., k-1}.
+ :param numTrees: Number of trees in the random forest.
+ :param featureSubsetStrategy: Number of features to consider for splits at
+ each node.
+ Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+ If "auto" is set, this parameter is set based on numTrees:
+ if numTrees == 1, set to "all";
+ if numTrees > 1 (forest) set to "sqrt".
+ :param impurity: Criterion used for information gain calculation.
+ Supported values: "gini" (recommended) or "entropy".
+ :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 leaf node;
+ depth 1 means 1 internal node + 2 leaf nodes. (default: 4)
+ :param maxBins: maximum number of bins used for splitting features
+          (default: 32)
+ :param seed: Random seed for bootstrapping and choosing feature subsets.
+ :return: RandomForestModel that can be used for prediction
+
+ Example usage:
+
+ >>> from pyspark.mllib.regression import LabeledPoint
+ >>> from pyspark.mllib.tree import RandomForest
+ >>>
+ >>> data = [
+ ... LabeledPoint(0.0, [0.0]),
+ ... LabeledPoint(0.0, [1.0]),
+ ... LabeledPoint(1.0, [2.0]),
+ ... LabeledPoint(1.0, [3.0])
+ ... ]
+ >>> model = RandomForest.trainClassifier(sc.parallelize(data), 2, {}, 3, seed=42)
+ >>> model.numTrees()
+ 3
+ >>> model.totalNumNodes()
+ 7
+ >>> print model,
+ TreeEnsembleModel classifier with 3 trees
+ >>> print model.toDebugString(),
+ TreeEnsembleModel classifier with 3 trees
+
+ Tree 0:
+ Predict: 1.0
+ Tree 1:
+ If (feature 0 <= 1.0)
+ Predict: 0.0
+ Else (feature 0 > 1.0)
+ Predict: 1.0
+ Tree 2:
+ If (feature 0 <= 1.0)
+ Predict: 0.0
+ Else (feature 0 > 1.0)
+ Predict: 1.0
+ >>> model.predict([2.0])
+ 1.0
+ >>> model.predict([0.0])
+ 0.0
+ >>> rdd = sc.parallelize([[3.0], [1.0]])
+ >>> model.predict(rdd).collect()
+ [1.0, 0.0]
+ """
+ return cls._train(data, "classification", numClasses,
+ categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity,
+ maxDepth, maxBins, seed)
+
+ @classmethod
+ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto",
+ impurity="variance", maxDepth=4, maxBins=32, seed=None):
+ """
+        Method to train a random forest model for regression.
+
+ :param data: Training dataset: RDD of LabeledPoint. Labels are
+ real numbers.
+ :param categoricalFeaturesInfo: Map storing arity of categorical
+ features. E.g., an entry (n -> k) indicates that feature
+ n is categorical with k categories indexed from 0:
+ {0, 1, ..., k-1}.
+ :param numTrees: Number of trees in the random forest.
+ :param featureSubsetStrategy: Number of features to consider for
+ splits at each node.
+ Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+ If "auto" is set, this parameter is set based on numTrees:
+ if numTrees == 1, set to "all";
+ if numTrees > 1 (forest) set to "onethird" for regression.
+ :param impurity: Criterion used for information gain calculation.
+ Supported values: "variance".
+ :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1
+ leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+ (default: 4)
+ :param maxBins: maximum number of bins used for splitting features
+          (default: 32)
+ :param seed: Random seed for bootstrapping and choosing feature subsets.
+ :return: RandomForestModel that can be used for prediction
+
+ Example usage:
+
+ >>> from pyspark.mllib.regression import LabeledPoint
+ >>> from pyspark.mllib.tree import RandomForest
+ >>> from pyspark.mllib.linalg import SparseVector
+ >>>
+ >>> sparse_data = [
+ ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})),
+ ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})),
+ ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})),
+ ... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
+ ... ]
+ >>>
+ >>> model = RandomForest.trainRegressor(sc.parallelize(sparse_data), {}, 2, seed=42)
+ >>> model.numTrees()
+ 2
+ >>> model.totalNumNodes()
+ 4
+ >>> model.predict(SparseVector(2, {1: 1.0}))
+ 1.0
+ >>> model.predict(SparseVector(2, {0: 1.0}))
+ 0.5
+ >>> rdd = sc.parallelize([[0.0, 1.0], [1.0, 0.0]])
+ >>> model.predict(rdd).collect()
+ [1.0, 0.5]
+ """
+ return cls._train(data, "regression", 0, categoricalFeaturesInfo, numTrees,
+ featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
+
+
+class GradientBoostedTreesModel(TreeEnsembleModel):
+ """
+ .. note:: Experimental
+
+ Represents a gradient-boosted tree model.
+ """
+
+
+class GradientBoostedTrees(object):
+ """
+ .. note:: Experimental
+
+ Learning algorithm for a gradient boosted trees model for classification or regression.
+ """
+
+ @classmethod
+ def _train(cls, data, algo, categoricalFeaturesInfo,
+ loss, numIterations, learningRate, maxDepth):
+ first = data.first()
+ assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint"
+ model = callMLlibFunc("trainGradientBoostedTreesModel", data, algo, categoricalFeaturesInfo,
+ loss, numIterations, learningRate, maxDepth)
+ return GradientBoostedTreesModel(model)
+
+ @classmethod
+ def trainClassifier(cls, data, categoricalFeaturesInfo,
+ loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3):
+ """
+ Method to train a gradient-boosted trees model for classification.
+
+ :param data: Training dataset: RDD of LabeledPoint. Labels should take values {0, 1}.
+ :param categoricalFeaturesInfo: Map storing arity of categorical
+ features. E.g., an entry (n -> k) indicates that feature
+ n is categorical with k categories indexed from 0:
+ {0, 1, ..., k-1}.
+ :param loss: Loss function used for minimization during gradient boosting.
+ Supported: {"logLoss" (default), "leastSquaresError", "leastAbsoluteError"}.
+ :param numIterations: Number of iterations of boosting.
+ (default: 100)
+ :param learningRate: Learning rate for shrinking the contribution of each estimator.
+                             The learning rate should be in the interval (0, 1]
+ (default: 0.1)
+ :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1
+ leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+ (default: 3)
+ :return: GradientBoostedTreesModel that can be used for prediction
+
+ Example usage:
+
+ >>> from pyspark.mllib.regression import LabeledPoint
+ >>> from pyspark.mllib.tree import GradientBoostedTrees
+ >>>
+ >>> data = [
+ ... LabeledPoint(0.0, [0.0]),
+ ... LabeledPoint(0.0, [1.0]),
+ ... LabeledPoint(1.0, [2.0]),
+ ... LabeledPoint(1.0, [3.0])
+ ... ]
+ >>>
+ >>> model = GradientBoostedTrees.trainClassifier(sc.parallelize(data), {})
+ >>> model.numTrees()
+ 100
+ >>> model.totalNumNodes()
+ 300
+ >>> print model, # it already has newline
+ TreeEnsembleModel classifier with 100 trees
+ >>> model.predict([2.0])
+ 1.0
+ >>> model.predict([0.0])
+ 0.0
+ >>> rdd = sc.parallelize([[2.0], [0.0]])
+ >>> model.predict(rdd).collect()
+ [1.0, 0.0]
+ """
+ return cls._train(data, "classification", categoricalFeaturesInfo,
+ loss, numIterations, learningRate, maxDepth)
+
+ @classmethod
+ def trainRegressor(cls, data, categoricalFeaturesInfo,
+ loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3):
+ """
+ Method to train a gradient-boosted trees model for regression.
+
+ :param data: Training dataset: RDD of LabeledPoint. Labels are
+ real numbers.
+ :param categoricalFeaturesInfo: Map storing arity of categorical
+ features. E.g., an entry (n -> k) indicates that feature
+ n is categorical with k categories indexed from 0:
+ {0, 1, ..., k-1}.
+ :param loss: Loss function used for minimization during gradient boosting.
+                     Supported: {"logLoss", "leastSquaresError" (default), "leastAbsoluteError"}.
+ :param numIterations: Number of iterations of boosting.
+ (default: 100)
+ :param learningRate: Learning rate for shrinking the contribution of each estimator.
+                             The learning rate should be in the interval (0, 1]
+ (default: 0.1)
+ :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1
+ leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+ (default: 3)
+ :return: GradientBoostedTreesModel that can be used for prediction
+
+ Example usage:
+
+ >>> from pyspark.mllib.regression import LabeledPoint
+ >>> from pyspark.mllib.tree import GradientBoostedTrees
+ >>> from pyspark.mllib.linalg import SparseVector
+ >>>
+ >>> sparse_data = [
+ ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})),
+ ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})),
+ ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})),
+ ... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
+ ... ]
+ >>>
+ >>> model = GradientBoostedTrees.trainRegressor(sc.parallelize(sparse_data), {})
+ >>> model.numTrees()
+ 100
+ >>> model.totalNumNodes()
+ 102
+ >>> model.predict(SparseVector(2, {1: 1.0}))
+ 1.0
+ >>> model.predict(SparseVector(2, {0: 1.0}))
+ 0.0
+ >>> rdd = sc.parallelize([[0.0, 1.0], [1.0, 0.0]])
+ >>> model.predict(rdd).collect()
+ [1.0, 0.0]
+ """
+ return cls._train(data, "regression", categoricalFeaturesInfo,
+ loss, numIterations, learningRate, maxDepth)
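As a quick orientation for the tree-ensemble additions above, here is a minimal, illustrative usage sketch (not part of the diff; it assumes a live SparkContext `sc`):

    from pyspark.mllib.regression import LabeledPoint
    from pyspark.mllib.tree import RandomForest, GradientBoostedTrees

    data = sc.parallelize([LabeledPoint(0.0, [0.0]), LabeledPoint(0.0, [1.0]),
                           LabeledPoint(1.0, [2.0]), LabeledPoint(1.0, [3.0])])
    rf = RandomForest.trainClassifier(data, numClasses=2, categoricalFeaturesInfo={},
                                      numTrees=3, seed=42)
    gbt = GradientBoostedTrees.trainRegressor(data, categoricalFeaturesInfo={},
                                              numIterations=10)
    # predict() accepts either a single vector or an RDD of vectors; both paths
    # go through JavaModelWrapper.call("predict", ...) on the JVM side.
    print rf.predict([2.0])
    print gbt.predict(sc.parallelize([[2.0], [0.0]])).collect()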
def _test():
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index 84b39a48619d2..4ed978b45409c 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -18,8 +18,7 @@
import numpy as np
import warnings
-from pyspark.rdd import RDD
-from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
+from pyspark.mllib.common import callMLlibFunc
from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
@@ -162,20 +161,11 @@ def loadLabeledPoints(sc, path, minPartitions=None):
>>> tempFile = NamedTemporaryFile(delete=True)
>>> tempFile.close()
>>> sc.parallelize(examples, 1).saveAsTextFile(tempFile.name)
- >>> loaded = MLUtils.loadLabeledPoints(sc, tempFile.name).collect()
- >>> type(loaded[0]) == LabeledPoint
- True
- >>> print examples[0]
- (1.1,(3,[0,2],[-1.23,4.56e-07]))
- >>> type(examples[1]) == LabeledPoint
- True
- >>> print examples[1]
- (0.0,[1.01,2.02,3.03])
+ >>> MLUtils.loadLabeledPoints(sc, tempFile.name).collect()
+ [LabeledPoint(1.1, (3,[0,2],[-1.23,4.56e-07])), LabeledPoint(0.0, [1.01,2.02,3.03])]
"""
minPartitions = minPartitions or min(sc.defaultParallelism, 2)
- jrdd = sc._jvm.PythonMLLibAPI().loadLabeledPoints(sc._jsc, path, minPartitions)
- jpyrdd = sc._jvm.SerDe.javaToPython(jrdd)
- return RDD(jpyrdd, sc, AutoBatchedSerializer(PickleSerializer()))
+ return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions)
def _test():
diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py
new file mode 100644
index 0000000000000..4408996db0790
--- /dev/null
+++ b/python/pyspark/profiler.py
@@ -0,0 +1,172 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import cProfile
+import pstats
+import os
+import atexit
+
+from pyspark.accumulators import AccumulatorParam
+
+
+class ProfilerCollector(object):
+ """
+    This class keeps track of different profilers on a per-stage
+    basis. It is also used to create new profilers for the
+    different stages.
+ """
+
+ def __init__(self, profiler_cls, dump_path=None):
+ self.profiler_cls = profiler_cls
+ self.profile_dump_path = dump_path
+ self.profilers = []
+
+ def new_profiler(self, ctx):
+ """ Create a new profiler using class `profiler_cls` """
+ return self.profiler_cls(ctx)
+
+ def add_profiler(self, id, profiler):
+ """ Add a profiler for RDD `id` """
+ if not self.profilers:
+ if self.profile_dump_path:
+ atexit.register(self.dump_profiles, self.profile_dump_path)
+ else:
+ atexit.register(self.show_profiles)
+
+ self.profilers.append([id, profiler, False])
+
+ def dump_profiles(self, path):
+ """ Dump the profile stats into directory `path` """
+ for id, profiler, _ in self.profilers:
+ profiler.dump(id, path)
+ self.profilers = []
+
+ def show_profiles(self):
+ """ Print the profile stats to stdout """
+ for i, (id, profiler, showed) in enumerate(self.profilers):
+ if not showed and profiler:
+ profiler.show(id)
+ # mark it as showed
+ self.profilers[i][2] = True
+
+
+class Profiler(object):
+ """
+ .. note:: DeveloperApi
+
+    PySpark supports custom profilers; this allows different profilers to be
+    used, as well as output formats other than those provided by the
+    BasicProfiler.
+
+ A custom profiler has to define or inherit the following methods:
+ profile - will produce a system profile of some sort.
+ stats - return the collected stats.
+ dump - dumps the profiles to a path
+ add - adds a profile to the existing accumulated profile
+
+    The profiler class is chosen when creating a SparkContext.
+
+ >>> from pyspark import SparkConf, SparkContext
+ >>> from pyspark import BasicProfiler
+ >>> class MyCustomProfiler(BasicProfiler):
+ ... def show(self, id):
+ ... print "My custom profiles for RDD:%s" % id
+ ...
+ >>> conf = SparkConf().set("spark.python.profile", "true")
+ >>> sc = SparkContext('local', 'test', conf=conf, profiler_cls=MyCustomProfiler)
+ >>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
+ [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
+ >>> sc.show_profiles()
+ My custom profiles for RDD:1
+ My custom profiles for RDD:2
+ >>> sc.stop()
+ """
+
+ def __init__(self, ctx):
+ pass
+
+ def profile(self, func):
+ """ Do profiling on the function `func`"""
+        raise NotImplementedError
+
+ def stats(self):
+ """ Return the collected profiling stats (pstats.Stats)"""
+        raise NotImplementedError
+
+ def show(self, id):
+ """ Print the profile stats to stdout, id is the RDD id """
+ stats = self.stats()
+ if stats:
+ print "=" * 60
+            print "Profile of RDD<id:%d>" % id
+ print "=" * 60
+ stats.sort_stats("time", "cumulative").print_stats()
+
+ def dump(self, id, path):
+ """ Dump the profile into path, id is the RDD id """
+ if not os.path.exists(path):
+ os.makedirs(path)
+ stats = self.stats()
+ if stats:
+ p = os.path.join(path, "rdd_%d.pstats" % id)
+ stats.dump_stats(p)
+
+
+class PStatsParam(AccumulatorParam):
+ """PStatsParam is used to merge pstats.Stats"""
+
+ @staticmethod
+ def zero(value):
+ return None
+
+ @staticmethod
+ def addInPlace(value1, value2):
+ if value1 is None:
+ return value2
+ value1.add(value2)
+ return value1
+
+
+class BasicProfiler(Profiler):
+ """
+ BasicProfiler is the default profiler, which is implemented based on
+ cProfile and Accumulator
+ """
+ def __init__(self, ctx):
+ Profiler.__init__(self, ctx)
+ # Creates a new accumulator for combining the profiles of different
+ # partitions of a stage
+ self._accumulator = ctx.accumulator(None, PStatsParam)
+
+ def profile(self, func):
+        """ Runs and profiles the function `func` passed in.
+        The collected stats are added to this profiler's accumulator. """
+ pr = cProfile.Profile()
+ pr.runcall(func)
+ st = pstats.Stats(pr)
+ st.stream = None # make it picklable
+ st.strip_dirs()
+
+ # Adds a new profile to the existing accumulated value
+ self._accumulator.add(st)
+
+ def stats(self):
+ return self._accumulator.value
+
+
+if __name__ == "__main__":
+ import doctest
+ doctest.testmod()
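A rough sketch of how the profiler plumbing above is meant to be used (illustrative only; the `spark.python.profile.dump` key is assumed here, its handling lives in context.py outside this file):

    from pyspark import SparkConf, SparkContext

    conf = (SparkConf()
            .set("spark.python.profile", "true")                     # enable per-stage profiling
            .set("spark.python.profile.dump", "/tmp/pyspark-prof"))  # assumed config key
    sc = SparkContext('local', 'profile-demo', conf=conf)
    sc.parallelize(range(1000)).map(lambda x: 2 * x).count()
    sc.stop()
    # ProfilerCollector registers an atexit hook, so rdd_<id>.pstats files are
    # written under the dump path when the interpreter exits; they can be read
    # back with the standard pstats module, e.g.
    #   pstats.Stats("/tmp/pyspark-prof/rdd_<id>.pstats").sort_stats("time").print_stats(10)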
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 15be4bfec92f9..6e029bf7f13fc 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -28,17 +28,16 @@
import warnings
import heapq
import bisect
-from random import Random
-from math import sqrt, log, isinf, isnan
+import random
+from math import sqrt, log, isinf, isnan, pow, ceil
-from pyspark.accumulators import PStatsParam
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
PickleSerializer, pack_long, AutoBatchedSerializer
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_full_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
-from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler
+from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler
from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
@@ -120,7 +119,7 @@ class RDD(object):
operated on in parallel.
"""
- def __init__(self, jrdd, ctx, jrdd_deserializer):
+ def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(PickleSerializer())):
self._jrdd = jrdd
self.is_cached = False
self.is_checkpointed = False
@@ -129,12 +128,8 @@ def __init__(self, jrdd, ctx, jrdd_deserializer):
self._id = jrdd.id()
self._partitionFunc = None
- def _toPickleSerialization(self):
- if (self._jrdd_deserializer == PickleSerializer() or
- self._jrdd_deserializer == BatchedSerializer(PickleSerializer())):
- return self
- else:
- return self._reserialize(BatchedSerializer(PickleSerializer(), 10))
+ def _pickled(self):
+ return self._reserialize(AutoBatchedSerializer(PickleSerializer()))
def id(self):
"""
@@ -145,6 +140,17 @@ def id(self):
def __repr__(self):
return self._jrdd.toString()
+ def __getnewargs__(self):
+ # This method is called when attempting to pickle an RDD, which is always an error:
+ raise Exception(
+ "It appears that you are attempting to broadcast an RDD or reference an RDD from an "
+ "action or transformation. RDD transformations and actions can only be invoked by the "
+ "driver, not inside of other transformations; for example, "
+ "rdd1.map(lambda x: rdd2.values.count() * x) is invalid because the values "
+ "transformation and count action cannot be performed inside of the rdd1.map "
+ "transformation. For more information, see SPARK-5063."
+ )
+
@property
def context(self):
"""
@@ -314,20 +320,43 @@ def distinct(self, numPartitions=None):
def sample(self, withReplacement, fraction, seed=None):
"""
- Return a sampled subset of this RDD (relies on numpy and falls back
- on default random generator if numpy is unavailable).
+ Return a sampled subset of this RDD.
- >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP
- [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98]
+ >>> rdd = sc.parallelize(range(100), 4)
+ >>> rdd.sample(False, 0.1, 81).count()
+ 10
"""
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
+ def randomSplit(self, weights, seed=None):
+ """
+ Randomly splits this RDD with the provided weights.
+
+ :param weights: weights for splits, will be normalized if they don't sum to 1
+ :param seed: random seed
+ :return: split RDDs in a list
+
+ >>> rdd = sc.parallelize(range(5), 1)
+ >>> rdd1, rdd2 = rdd.randomSplit([2, 3], 17)
+ >>> rdd1.collect()
+ [1, 3]
+ >>> rdd2.collect()
+ [0, 2, 4]
+ """
+ s = float(sum(weights))
+ cweights = [0.0]
+ for w in weights:
+ cweights.append(cweights[-1] + w / s)
+ if seed is None:
+ seed = random.randint(0, 2 ** 32 - 1)
+ return [self.mapPartitionsWithIndex(RDDRangeSampler(lb, ub, seed).func, True)
+ for lb, ub in zip(cweights, cweights[1:])]
+
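To make the weight handling in randomSplit concrete, the same normalization in isolation (illustrative only):

    weights = [2, 3]
    s = float(sum(weights))            # 5.0
    cweights = [0.0]
    for w in weights:
        cweights.append(cweights[-1] + w / s)
    # cweights == [0.0, 0.4, 1.0]; each split keeps the rows whose uniform
    # sample falls in [lb, ub), i.e. roughly 40% and 60% of the data.
    print zip(cweights, cweights[1:])  # [(0.0, 0.4), (0.4, 1.0)]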
# this is ported from scala/spark/RDD.scala
def takeSample(self, withReplacement, num, seed=None):
"""
- Return a fixed-size sampled subset of this RDD (currently requires
- numpy).
+ Return a fixed-size sampled subset of this RDD.
>>> rdd = sc.parallelize(range(0, 10))
>>> len(rdd.takeSample(True, 20, 1))
@@ -348,7 +377,7 @@ def takeSample(self, withReplacement, num, seed=None):
if initialCount == 0:
return []
- rand = Random(seed)
+ rand = random.Random(seed)
if (not withReplacement) and num >= initialCount:
# shuffle current RDD and return
@@ -449,12 +478,10 @@ def intersection(self, other):
def _reserialize(self, serializer=None):
serializer = serializer or self.ctx.serializer
- if self._jrdd_deserializer == serializer:
- return self
- else:
- converted = self.map(lambda x: x, preservesPartitioning=True)
- converted._jrdd_deserializer = serializer
- return converted
+ if self._jrdd_deserializer != serializer:
+ self = self.map(lambda x: x, preservesPartitioning=True)
+ self._jrdd_deserializer = serializer
+ return self
def __add__(self, other):
"""
@@ -529,6 +556,8 @@ def sortPartition(iterator):
# the key-space into bins such that the bins have roughly the same
# number of (key, value) pairs falling into them
rddSize = self.count()
+ if not rddSize:
+ return self # empty RDD
maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner
fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect()
@@ -697,6 +726,43 @@ def func(iterator):
return reduce(f, vals)
raise ValueError("Can not reduce() empty RDD")
+ def treeReduce(self, f, depth=2):
+ """
+ Reduces the elements of this RDD in a multi-level tree pattern.
+
+ :param depth: suggested depth of the tree (default: 2)
+
+ >>> add = lambda x, y: x + y
+ >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10)
+ >>> rdd.treeReduce(add)
+ -5
+ >>> rdd.treeReduce(add, 1)
+ -5
+ >>> rdd.treeReduce(add, 2)
+ -5
+ >>> rdd.treeReduce(add, 5)
+ -5
+ >>> rdd.treeReduce(add, 10)
+ -5
+ """
+ if depth < 1:
+ raise ValueError("Depth cannot be smaller than 1 but got %d." % depth)
+
+ zeroValue = None, True # Use the second entry to indicate whether this is a dummy value.
+
+ def op(x, y):
+ if x[1]:
+ return y
+ elif y[1]:
+ return x
+ else:
+ return f(x[0], y[0]), False
+
+ reduced = self.map(lambda x: (x, False)).treeAggregate(zeroValue, op, op, depth)
+ if reduced[1]:
+ raise ValueError("Cannot reduce empty RDD.")
+ return reduced[0]
+
def fold(self, zeroValue, op):
"""
Aggregate the elements of each partition, and then the results for all
@@ -748,6 +814,58 @@ def func(iterator):
return self.mapPartitions(func).fold(zeroValue, combOp)
+ def treeAggregate(self, zeroValue, seqOp, combOp, depth=2):
+ """
+ Aggregates the elements of this RDD in a multi-level tree
+ pattern.
+
+ :param depth: suggested depth of the tree (default: 2)
+
+ >>> add = lambda x, y: x + y
+ >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10)
+ >>> rdd.treeAggregate(0, add, add)
+ -5
+ >>> rdd.treeAggregate(0, add, add, 1)
+ -5
+ >>> rdd.treeAggregate(0, add, add, 2)
+ -5
+ >>> rdd.treeAggregate(0, add, add, 5)
+ -5
+ >>> rdd.treeAggregate(0, add, add, 10)
+ -5
+ """
+ if depth < 1:
+ raise ValueError("Depth cannot be smaller than 1 but got %d." % depth)
+
+ if self.getNumPartitions() == 0:
+ return zeroValue
+
+ def aggregatePartition(iterator):
+ acc = zeroValue
+ for obj in iterator:
+ acc = seqOp(acc, obj)
+ yield acc
+
+ partiallyAggregated = self.mapPartitions(aggregatePartition)
+ numPartitions = partiallyAggregated.getNumPartitions()
+ scale = max(int(ceil(pow(numPartitions, 1.0 / depth))), 2)
+ # If creating an extra level doesn't help reduce the wall-clock time, we stop the tree
+ # aggregation.
+ while numPartitions > scale + numPartitions / scale:
+ numPartitions /= scale
+ curNumPartitions = numPartitions
+
+ def mapPartition(i, iterator):
+ for obj in iterator:
+ yield (i % curNumPartitions, obj)
+
+ partiallyAggregated = partiallyAggregated \
+ .mapPartitionsWithIndex(mapPartition) \
+ .reduceByKey(combOp, curNumPartitions) \
+ .values()
+
+ return partiallyAggregated.reduce(combOp)
+
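The while loop above repeatedly shrinks the number of partitions by `scale`; a standalone sketch of that schedule, using the same formulas (illustrative only, Python 2 integer division as in the code):

    from math import ceil, pow

    def schedule(numPartitions, depth=2):
        scale = max(int(ceil(pow(numPartitions, 1.0 / depth))), 2)
        levels = []
        while numPartitions > scale + numPartitions / scale:
            numPartitions /= scale
            levels.append(numPartitions)
        return scale, levels

    print schedule(100, depth=2)  # (10, [10]): one tree level, then the final reduce
    print schedule(100, depth=3)  # (5, [20, 4]): two tree levels, then the final reduce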
def max(self, key=None):
"""
Find the maximum item in this RDD.
@@ -1111,6 +1229,18 @@ def first(self):
return rs[0]
raise ValueError("RDD is empty")
+ def isEmpty(self):
+ """
+ Returns true if and only if the RDD contains no elements at all. Note that an RDD
+ may be empty even when it has at least 1 partition.
+
+ >>> sc.parallelize([]).isEmpty()
+ True
+ >>> sc.parallelize([1]).isEmpty()
+ False
+ """
+ return self._jrdd.partitions().size() == 0 or len(self.take(1)) == 0
+
def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None):
"""
Output a Python RDD of key-value pairs (of form C{RDD[(K, V)]}) to any Hadoop file
@@ -1123,9 +1253,8 @@ def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None
:param valueConverter: (None by default)
"""
jconf = self.ctx._dictToJavaMap(conf)
- pickledRDD = self._toPickleSerialization()
- batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer)
- self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, batched, jconf,
+ pickledRDD = self._pickled()
+ self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, True, jconf,
keyConverter, valueConverter, True)
def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None,
@@ -1150,9 +1279,8 @@ def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueCl
:param conf: Hadoop job configuration, passed in as a dict (None by default)
"""
jconf = self.ctx._dictToJavaMap(conf)
- pickledRDD = self._toPickleSerialization()
- batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer)
- self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(pickledRDD._jrdd, batched, path,
+ pickledRDD = self._pickled()
+ self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(pickledRDD._jrdd, True, path,
outputFormatClass,
keyClass, valueClass,
keyConverter, valueConverter, jconf)
@@ -1169,9 +1297,8 @@ def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None):
:param valueConverter: (None by default)
"""
jconf = self.ctx._dictToJavaMap(conf)
- pickledRDD = self._toPickleSerialization()
- batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer)
- self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, batched, jconf,
+ pickledRDD = self._pickled()
+ self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, True, jconf,
keyConverter, valueConverter, False)
def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None,
@@ -1198,9 +1325,8 @@ def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=No
:param compressionCodecClass: (None by default)
"""
jconf = self.ctx._dictToJavaMap(conf)
- pickledRDD = self._toPickleSerialization()
- batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer)
- self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickledRDD._jrdd, batched, path,
+ pickledRDD = self._pickled()
+ self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickledRDD._jrdd, True, path,
outputFormatClass,
keyClass, valueClass,
keyConverter, valueConverter,
@@ -1218,9 +1344,8 @@ def saveAsSequenceFile(self, path, compressionCodecClass=None):
:param path: path to sequence file
:param compressionCodecClass: (None by default)
"""
- pickledRDD = self._toPickleSerialization()
- batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer)
- self.ctx._jvm.PythonRDD.saveAsSequenceFile(pickledRDD._jrdd, batched,
+ pickledRDD = self._pickled()
+ self.ctx._jvm.PythonRDD.saveAsSequenceFile(pickledRDD._jrdd, True,
path, compressionCodecClass)
def saveAsPickleFile(self, path, batchSize=10):
@@ -1235,8 +1360,11 @@ def saveAsPickleFile(self, path, batchSize=10):
>>> sorted(sc.pickleFile(tmpFile.name, 5).collect())
[1, 2, 'rdd', 'spark']
"""
- self._reserialize(BatchedSerializer(PickleSerializer(),
- batchSize))._jrdd.saveAsObjectFile(path)
+ if batchSize == 0:
+ ser = AutoBatchedSerializer(PickleSerializer())
+ else:
+ ser = BatchedSerializer(PickleSerializer(), batchSize)
+ self._reserialize(ser)._jrdd.saveAsObjectFile(path)
def saveAsTextFile(self, path):
"""
@@ -1594,8 +1722,8 @@ def groupByKey(self, numPartitions=None):
Hash-partitions the resulting RDD with into numPartitions partitions.
Note: If you are grouping in order to perform an aggregation (such as a
- sum or average) over each key, using reduceByKey will provide much
- better performance.
+ sum or average) over each key, using reduceByKey or aggregateByKey will
+ provide much better performance.
>>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
>>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
@@ -1777,28 +1905,27 @@ def zip(self, other):
>>> x.zip(y).collect()
[(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)]
"""
- if self.getNumPartitions() != other.getNumPartitions():
- raise ValueError("Can only zip with RDD which has the same number of partitions")
-
def get_batch_size(ser):
if isinstance(ser, BatchedSerializer):
return ser.batchSize
- return 0
+ return 1 # not batched
def batch_as(rdd, batchSize):
- ser = rdd._jrdd_deserializer
- if isinstance(ser, BatchedSerializer):
- ser = ser.serializer
- return rdd._reserialize(BatchedSerializer(ser, batchSize))
+ return rdd._reserialize(BatchedSerializer(PickleSerializer(), batchSize))
my_batch = get_batch_size(self._jrdd_deserializer)
other_batch = get_batch_size(other._jrdd_deserializer)
if my_batch != other_batch:
- # use the greatest batchSize to batch the other one.
- if my_batch > other_batch:
- other = batch_as(other, my_batch)
- else:
- self = batch_as(self, other_batch)
+ # use the smallest batchSize for both of them
+ batchSize = min(my_batch, other_batch)
+ if batchSize <= 0:
+ # auto batched or unlimited
+ batchSize = 100
+ other = batch_as(other, batchSize)
+ self = batch_as(self, batchSize)
+
+ if self.getNumPartitions() != other.getNumPartitions():
+ raise ValueError("Can only zip with RDD which has the same number of partitions")
# There will be an Exception in JVM if there are different number
# of items in each partitions.
@@ -1867,11 +1994,11 @@ def setName(self, name):
Assign a name to this RDD.
>>> rdd1 = sc.parallelize([1,2])
- >>> rdd1.setName('RDD1')
- >>> rdd1.name()
+ >>> rdd1.setName('RDD1').name()
'RDD1'
"""
self._jrdd.setName(name)
+ return self
def toDebugString(self):
"""
@@ -1937,29 +2064,18 @@ def lookup(self, key):
return values.collect()
- def _is_pickled(self):
- """ Return this RDD is serialized by Pickle or not. """
- der = self._jrdd_deserializer
- if isinstance(der, PickleSerializer):
- return True
- if isinstance(der, BatchedSerializer) and isinstance(der.serializer, PickleSerializer):
- return True
- return False
-
def _to_java_object_rdd(self):
""" Return an JavaRDD of Object by unpickling
It will convert each Python object into Java object by Pyrolite, whenever the
RDD is serialized in batch or not.
"""
- rdd = self._reserialize(AutoBatchedSerializer(PickleSerializer())) \
- if not self._is_pickled() else self
- is_batch = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
- return self.ctx._jvm.PythonRDD.pythonToJava(rdd._jrdd, is_batch)
+ rdd = self._pickled()
+ return self.ctx._jvm.SerDeUtil.pythonToJava(rdd._jrdd, True)
def countApprox(self, timeout, confidence=0.95):
"""
- :: Experimental ::
+ .. note:: Experimental
Approximate version of count() that returns a potentially incomplete
result within a timeout, even if not all tasks have finished.
@@ -1972,7 +2088,7 @@ def countApprox(self, timeout, confidence=0.95):
def sumApprox(self, timeout, confidence=0.95):
"""
- :: Experimental ::
+ .. note:: Experimental
Approximate operation to return the sum within a timeout
or meet the confidence.
@@ -1988,7 +2104,7 @@ def sumApprox(self, timeout, confidence=0.95):
def meanApprox(self, timeout, confidence=0.95):
"""
- :: Experimental ::
+ .. note:: Experimental
Approximate operation to return the mean within a timeout
or meet the confidence.
@@ -2004,7 +2120,7 @@ def meanApprox(self, timeout, confidence=0.95):
def countApproxDistinct(self, relativeSD=0.05):
"""
- :: Experimental ::
+ .. note:: Experimental
Return approximate number of distinct elements in the RDD.
The algorithm used is based on streamlib's implementation of
@@ -2031,6 +2147,39 @@ def countApproxDistinct(self, relativeSD=0.05):
hashRDD = self.map(lambda x: portable_hash(x) & 0xFFFFFFFF)
return hashRDD._to_java_object_rdd().countApproxDistinct(relativeSD)
+ def toLocalIterator(self):
+ """
+ Return an iterator that contains all of the elements in this RDD.
+ The iterator will consume as much memory as the largest partition in this RDD.
+ >>> rdd = sc.parallelize(range(10))
+ >>> [x for x in rdd.toLocalIterator()]
+ [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+ """
+ partitions = xrange(self.getNumPartitions())
+ for partition in partitions:
+ rows = self.context.runJob(self, lambda x: x, [partition])
+ for row in rows:
+ yield row
+
+
+def _prepare_for_python_RDD(sc, command, obj=None):
+ # the serialized command will be compressed by broadcast
+ ser = CloudPickleSerializer()
+ pickled_command = ser.dumps(command)
+ if len(pickled_command) > (1 << 20): # 1M
+ broadcast = sc.broadcast(pickled_command)
+ pickled_command = ser.dumps(broadcast)
+ # tracking the life cycle by obj
+ if obj is not None:
+ obj._broadcast = broadcast
+ broadcast_vars = ListConverter().convert(
+ [x._jbroadcast for x in sc._pickled_broadcast_vars],
+ sc._gateway._gateway_client)
+ sc._pickled_broadcast_vars.clear()
+ env = MapConverter().convert(sc.environment, sc._gateway._gateway_client)
+ includes = ListConverter().convert(sc._python_includes, sc._gateway._gateway_client)
+ return pickled_command, broadcast_vars, env, includes
+
class PipelinedRDD(RDD):
@@ -2090,34 +2239,25 @@ def _jrdd(self):
return self._jrdd_val
if self._bypass_serializer:
self._jrdd_deserializer = NoOpSerializer()
- enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true"
- profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None
- command = (self.func, profileStats, self._prev_jrdd_deserializer,
+
+ if self.ctx.profiler_collector:
+ profiler = self.ctx.profiler_collector.new_profiler(self.ctx)
+ else:
+ profiler = None
+
+ command = (self.func, profiler, self._prev_jrdd_deserializer,
self._jrdd_deserializer)
- # the serialized command will be compressed by broadcast
- ser = CloudPickleSerializer()
- pickled_command = ser.dumps(command)
- if len(pickled_command) > (1 << 20): # 1M
- self._broadcast = self.ctx.broadcast(pickled_command)
- pickled_command = ser.dumps(self._broadcast)
- broadcast_vars = ListConverter().convert(
- [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
- self.ctx._gateway._gateway_client)
- self.ctx._pickled_broadcast_vars.clear()
- env = MapConverter().convert(self.ctx.environment,
- self.ctx._gateway._gateway_client)
- includes = ListConverter().convert(self.ctx._python_includes,
- self.ctx._gateway._gateway_client)
+ pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self.ctx, command, self)
python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
- bytearray(pickled_command),
+ bytearray(pickled_cmd),
env, includes, self.preservesPartitioning,
self.ctx.pythonExec,
- broadcast_vars, self.ctx._javaAccumulator)
+ bvars, self.ctx._javaAccumulator)
self._jrdd_val = python_rdd.asJavaRDD()
- if enable_profile:
+ if profiler:
self._id = self._jrdd_val.id()
- self.ctx._add_profile(self._id, profileStats)
+ self.ctx.profiler_collector.add_profiler(self._id, profiler)
return self._jrdd_val
def id(self):
@@ -2135,7 +2275,7 @@ def _test():
globs = globals().copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
- globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ globs['sc'] = SparkContext('local[4]', 'PythonTest')
(failure_count, test_count) = doctest.testmod(
globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
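One behavioural change in rdd.py worth calling out: referencing an RDD from inside a closure now fails fast with an explanatory message instead of failing obscurely in the worker. A small illustration (assuming a live `sc`):

    rdd1 = sc.parallelize([1, 2, 3])
    rdd2 = sc.parallelize([10, 20, 30])
    try:
        # pickling the closure hits RDD.__getnewargs__ and raises
        rdd1.map(lambda x: rdd2.count() * x).collect()
    except Exception as e:
        print e  # mentions SPARK-5063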
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
index 528a181e8905a..459e1427803cb 100644
--- a/python/pyspark/rddsampler.py
+++ b/python/pyspark/rddsampler.py
@@ -17,82 +17,48 @@
import sys
import random
+import math
class RDDSamplerBase(object):
def __init__(self, withReplacement, seed=None):
- try:
- import numpy
- self._use_numpy = True
- except ImportError:
- print >> sys.stderr, (
- "NumPy does not appear to be installed. "
- "Falling back to default random generator for sampling.")
- self._use_numpy = False
-
- self._seed = seed if seed is not None else random.randint(0, 2 ** 32 - 1)
+ self._seed = seed if seed is not None else random.randint(0, sys.maxint)
self._withReplacement = withReplacement
self._random = None
- self._split = None
- self._rand_initialized = False
def initRandomGenerator(self, split):
- if self._use_numpy:
- import numpy
- self._random = numpy.random.RandomState(self._seed)
+ self._random = random.Random(self._seed ^ split)
+
+ # mixing because the initial seeds are close to each other
+ for _ in xrange(10):
+ self._random.randint(0, 1)
+
+ def getUniformSample(self):
+ return self._random.random()
+
+ def getPoissonSample(self, mean):
+ # Using Knuth's algorithm described in
+ # http://en.wikipedia.org/wiki/Poisson_distribution
+ if mean < 20.0:
+ # one exp and k+1 random calls
+ l = math.exp(-mean)
+ p = self._random.random()
+ k = 0
+ while p > l:
+ k += 1
+ p *= self._random.random()
else:
- self._random = random.Random(self._seed)
+ # switch to the log domain, k+1 expovariate (random + log) calls
+ p = self._random.expovariate(mean)
+ k = 0
+ while p < 1.0:
+ k += 1
+ p += self._random.expovariate(mean)
+ return k
- for _ in range(0, split):
- # discard the next few values in the sequence to have a
- # different seed for the different splits
- self._random.randint(0, 2 ** 32 - 1)
-
- self._split = split
- self._rand_initialized = True
-
- def getUniformSample(self, split):
- if not self._rand_initialized or split != self._split:
- self.initRandomGenerator(split)
-
- if self._use_numpy:
- return self._random.random_sample()
- else:
- return self._random.uniform(0.0, 1.0)
-
- def getPoissonSample(self, split, mean):
- if not self._rand_initialized or split != self._split:
- self.initRandomGenerator(split)
-
- if self._use_numpy:
- return self._random.poisson(mean)
- else:
- # here we simulate drawing numbers n_i ~ Poisson(lambda = 1/mean) by
- # drawing a sequence of numbers delta_j ~ Exp(mean)
- num_arrivals = 1
- cur_time = 0.0
-
- cur_time += self._random.expovariate(mean)
-
- if cur_time > 1.0:
- return 0
-
- while(cur_time <= 1.0):
- cur_time += self._random.expovariate(mean)
- num_arrivals += 1
-
- return (num_arrivals - 1)
-
- def shuffle(self, vals):
- if self._random is None:
- self.initRandomGenerator(0) # this should only ever called on the master so
- # the split does not matter
-
- if self._use_numpy:
- self._random.shuffle(vals)
- else:
- self._random.shuffle(vals, self._random.random)
+ def func(self, split, iterator):
+ raise NotImplementedError
class RDDSampler(RDDSamplerBase):
@@ -102,20 +68,35 @@ def __init__(self, withReplacement, fraction, seed=None):
self._fraction = fraction
def func(self, split, iterator):
+ self.initRandomGenerator(split)
if self._withReplacement:
for obj in iterator:
# For large datasets, the expected number of occurrences of each element in
# a sample with replacement is Poisson(frac). We use that to get a count for
# each element.
- count = self.getPoissonSample(split, mean=self._fraction)
+ count = self.getPoissonSample(self._fraction)
for _ in range(0, count):
yield obj
else:
for obj in iterator:
- if self.getUniformSample(split) <= self._fraction:
+ if self.getUniformSample() < self._fraction:
yield obj
+class RDDRangeSampler(RDDSamplerBase):
+
+ def __init__(self, lowerBound, upperBound, seed=None):
+ RDDSamplerBase.__init__(self, False, seed)
+ self._lowerBound = lowerBound
+ self._upperBound = upperBound
+
+ def func(self, split, iterator):
+ self.initRandomGenerator(split)
+ for obj in iterator:
+ if self._lowerBound <= self.getUniformSample() < self._upperBound:
+ yield obj
+
+
class RDDStratifiedSampler(RDDSamplerBase):
def __init__(self, withReplacement, fractions, seed=None):
@@ -123,15 +104,16 @@ def __init__(self, withReplacement, fractions, seed=None):
self._fractions = fractions
def func(self, split, iterator):
+ self.initRandomGenerator(split)
if self._withReplacement:
for key, val in iterator:
# For large datasets, the expected number of occurrences of each element in
# a sample with replacement is Poisson(frac). We use that to get a count for
# each element.
- count = self.getPoissonSample(split, mean=self._fractions[key])
+ count = self.getPoissonSample(self._fractions[key])
for _ in range(0, count):
yield key, val
else:
for key, val in iterator:
- if self.getUniformSample(split) <= self._fractions[key]:
+ if self.getUniformSample() < self._fractions[key]:
yield key, val
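As a sanity check on the Knuth-style Poisson sampler introduced above, the sample mean should approach the requested rate (illustrative only, mirroring the small-mean branch of getPoissonSample):

    import math
    import random

    def poisson(rng, mean):
        # one exp and k+1 uniform draws, as in RDDSamplerBase.getPoissonSample
        l = math.exp(-mean)
        p, k = rng.random(), 0
        while p > l:
            k += 1
            p *= rng.random()
        return k

    rng = random.Random(42)
    samples = [poisson(rng, 0.3) for _ in xrange(100000)]
    print sum(samples) / float(len(samples))  # close to 0.3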
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 08a0f0d8ffb3e..0ffb41d02f6f6 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -33,9 +33,8 @@
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
>>> sc.stop()
-By default, PySpark serialize objects in batches; the batch size can be
-controlled through SparkContext's C{batchSize} parameter
-(the default size is 1024 objects):
+PySpark serializes objects in batches; by default, the batch size is chosen based
+on the size of the objects, and it can also be configured through SparkContext's
+C{batchSize} parameter:
>>> sc = SparkContext('local', 'test', batchSize=2)
>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
@@ -48,16 +47,6 @@
>>> rdd._jrdd.count()
8L
>>> sc.stop()
-
-A batch size of -1 uses an unlimited batch size, and a size of 1 disables
-batching:
-
->>> sc = SparkContext('local', 'test', batchSize=1)
->>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
->>> rdd.glom().collect()
-[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
->>> rdd._jrdd.count()
-16L
"""
import cPickle
@@ -73,13 +62,15 @@
from pyspark import cloudpickle
-__all__ = ["PickleSerializer", "MarshalSerializer"]
+__all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"]
class SpecialLengths(object):
END_OF_DATA_SECTION = -1
PYTHON_EXCEPTION_THROWN = -2
TIMING_DATA = -3
+ END_OF_STREAM = -4
+ NULL = -5
class Serializer(object):
@@ -112,7 +103,7 @@ def __ne__(self, other):
return not self.__eq__(other)
def __repr__(self):
- return "<%s object>" % self.__class__.__name__
+ return "%s()" % self.__class__.__name__
def __hash__(self):
return hash(str(self))
@@ -143,6 +134,10 @@ def load_stream(self, stream):
def _write_with_length(self, obj, stream):
serialized = self.dumps(obj)
+ if serialized is None:
+ raise ValueError("serialized value should not be None")
+ if len(serialized) > (1 << 31):
+ raise ValueError("can not serialize object larger than 2G")
write_int(len(serialized), stream)
if self._only_write_strings:
stream.write(str(serialized))
@@ -153,8 +148,10 @@ def _read_with_length(self, stream):
length = read_int(stream)
if length == SpecialLengths.END_OF_DATA_SECTION:
raise EOFError
+ elif length == SpecialLengths.NULL:
+ return None
obj = stream.read(length)
- if obj == "":
+ if len(obj) < length:
raise EOFError
return self.loads(obj)
@@ -180,6 +177,7 @@ class BatchedSerializer(Serializer):
"""
UNLIMITED_BATCH_SIZE = -1
+ UNKNOWN_BATCH_SIZE = 0
def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
self.serializer = serializer
@@ -188,6 +186,10 @@ def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
def _batched(self, iterator):
if self.batchSize == self.UNLIMITED_BATCH_SIZE:
yield list(iterator)
+ elif hasattr(iterator, "__len__") and hasattr(iterator, "__getslice__"):
+ n = len(iterator)
+ for i in xrange(0, n, self.batchSize):
+ yield iterator[i: i + self.batchSize]
else:
items = []
count = 0
@@ -212,10 +214,10 @@ def _load_stream_without_unbatching(self, stream):
def __eq__(self, other):
return (isinstance(other, BatchedSerializer) and
- other.serializer == self.serializer)
+ other.serializer == self.serializer and other.batchSize == self.batchSize)
def __repr__(self):
- return "BatchedSerializer<%s>" % str(self.serializer)
+ return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize)
class AutoBatchedSerializer(BatchedSerializer):
@@ -224,7 +226,7 @@ class AutoBatchedSerializer(BatchedSerializer):
"""
def __init__(self, serializer, bestSize=1 << 16):
- BatchedSerializer.__init__(self, serializer, -1)
+ BatchedSerializer.__init__(self, serializer, self.UNKNOWN_BATCH_SIZE)
self.bestSize = bestSize
def dump_stream(self, iterator, stream):
@@ -247,10 +249,10 @@ def dump_stream(self, iterator, stream):
def __eq__(self, other):
return (isinstance(other, AutoBatchedSerializer) and
- other.serializer == self.serializer)
+ other.serializer == self.serializer and other.bestSize == self.bestSize)
def __str__(self):
- return "AutoBatchedSerializer<%s>" % str(self.serializer)
+ return "AutoBatchedSerializer(%s)" % str(self.serializer)
class CartesianDeserializer(FramedSerializer):
@@ -283,7 +285,7 @@ def __eq__(self, other):
self.key_ser == other.key_ser and self.val_ser == other.val_ser)
def __repr__(self):
- return "CartesianDeserializer<%s, %s>" % \
+ return "CartesianDeserializer(%s, %s)" % \
(str(self.key_ser), str(self.val_ser))
@@ -310,7 +312,7 @@ def __eq__(self, other):
self.key_ser == other.key_ser and self.val_ser == other.val_ser)
def __repr__(self):
- return "PairDeserializer<%s, %s>" % (str(self.key_ser), str(self.val_ser))
+ return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser))
class NoOpSerializer(FramedSerializer):
@@ -429,7 +431,7 @@ def loads(self, obj):
class AutoSerializer(FramedSerializer):
"""
- Choose marshal or cPickle as serialization protocol autumatically
+ Choose marshal or cPickle as serialization protocol automatically
"""
def __init__(self):
@@ -459,9 +461,9 @@ class CompressedSerializer(FramedSerializer):
"""
Compress the serialized data
"""
-
def __init__(self, serializer):
FramedSerializer.__init__(self)
+ assert isinstance(serializer, FramedSerializer), "serializer must be a FramedSerializer"
self.serializer = serializer
def dumps(self, obj):
@@ -470,6 +472,9 @@ def dumps(self, obj):
def loads(self, obj):
return self.serializer.loads(zlib.decompress(obj))
+ def __eq__(self, other):
+ return isinstance(other, CompressedSerializer) and self.serializer == other.serializer
+
class UTF8Deserializer(Serializer):
@@ -484,6 +489,8 @@ def loads(self, stream):
length = read_int(stream)
if length == SpecialLengths.END_OF_DATA_SECTION:
raise EOFError
+ elif length == SpecialLengths.NULL:
+ return None
s = stream.read(length)
return s.decode("utf-8") if self.use_unicode else s
@@ -496,6 +503,9 @@ def load_stream(self, stream):
except EOFError:
return
+ def __eq__(self, other):
+ return isinstance(other, UTF8Deserializer) and self.use_unicode == other.use_unicode
+
def read_long(stream):
length = stream.read(8)
@@ -526,3 +536,8 @@ def write_int(value, stream):
def write_with_length(obj, stream):
write_int(len(obj), stream)
stream.write(obj)
+
+
+if __name__ == '__main__':
+ import doctest
+ doctest.testmod()
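A quick illustration of the repr and equality changes above (illustrative only):

    from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer

    print BatchedSerializer(PickleSerializer(), 10)   # BatchedSerializer(PickleSerializer(), 10)
    print AutoBatchedSerializer(PickleSerializer())   # AutoBatchedSerializer(PickleSerializer())
    # __eq__ now also compares the batch size, not just the wrapped serializer:
    print BatchedSerializer(PickleSerializer(), 10) == BatchedSerializer(PickleSerializer(), 1024)  # False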
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index d57a802e4734a..10a7ccd502000 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -25,7 +25,7 @@
import random
import pyspark.heapq3 as heapq
-from pyspark.serializers import BatchedSerializer, PickleSerializer
+from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
try:
import psutil
@@ -213,8 +213,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None,
Merger.__init__(self, aggregator)
self.memory_limit = memory_limit
# default serializer is only used for tests
- self.serializer = serializer or \
- BatchedSerializer(PickleSerializer(), 1024)
+ self.serializer = serializer or AutoBatchedSerializer(PickleSerializer())
self.localdirs = localdirs or _get_local_dirs(str(id(self)))
# number of partitions when spill data into disks
self.partitions = partitions
@@ -470,7 +469,7 @@ class ExternalSorter(object):
def __init__(self, memory_limit, serializer=None):
self.memory_limit = memory_limit
self.local_dirs = _get_local_dirs("sort")
- self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024)
+ self.serializer = serializer or AutoBatchedSerializer(PickleSerializer())
def _get_path(self, n):
""" Choose one directory for spill by number n """
@@ -479,13 +478,21 @@ def _get_path(self, n):
os.makedirs(d)
return os.path.join(d, str(n))
+ def _next_limit(self):
+ """
+ Return the next memory limit. If the memory is not released
+ after spilling, it will dump the data only when the used memory
+ starts to increase.
+ """
+ return max(self.memory_limit, get_used_memory() * 1.05)
+
def sorted(self, iterator, key=None, reverse=False):
"""
Sort the elements in iterator, do external sort when the memory
goes above the limit.
"""
global MemoryBytesSpilled, DiskBytesSpilled
- batch = 100
+ batch, limit = 100, self._next_limit()
chunks, current_chunk = [], []
iterator = iter(iterator)
while True:
@@ -505,6 +512,7 @@ def sorted(self, iterator, key=None, reverse=False):
chunks.append(self.serializer.load_stream(open(path)))
current_chunk = []
gc.collect()
+ limit = self._next_limit()
MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
DiskBytesSpilled += os.path.getsize(path)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index b31a82f9b19ac..3ac8ea597e142 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -20,21 +20,29 @@
- L{SQLContext}
Main entry point for SQL functionality.
- - L{SchemaRDD}
+ - L{DataFrame}
A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In
- addition to normal RDD operations, SchemaRDDs also support SQL.
+ addition to normal RDD operations, DataFrames also support SQL.
+ - L{GroupedData}
+ - L{Column}
+ Column is a DataFrame with a single column.
- L{Row}
A Row of data returned by a Spark SQL query.
- L{HiveContext}
      Main entry point for accessing data stored in Apache Hive.
"""
+import sys
import itertools
import decimal
import datetime
import keyword
import warnings
import json
+import re
+import random
+import os
+from tempfile import NamedTemporaryFile
from array import array
from operator import itemgetter
from itertools import imap
@@ -42,17 +50,20 @@
from py4j.protocol import Py4JError
from py4j.java_collections import ListConverter, MapConverter
-from pyspark.rdd import RDD
-from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer
+from pyspark.context import SparkContext
+from pyspark.rdd import RDD, _prepare_for_python_RDD
+from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
+ CloudPickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
__all__ = [
- "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType",
+ "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType",
"DoubleType", "FloatType", "ByteType", "IntegerType", "LongType",
"ShortType", "ArrayType", "MapType", "StructField", "StructType",
- "SQLContext", "HiveContext", "SchemaRDD", "Row"]
+ "SQLContext", "HiveContext", "DataFrame", "GroupedData", "Column", "Row", "Dsl",
+ "SchemaRDD"]
class DataType(object):
@@ -108,6 +119,15 @@ def __eq__(self, other):
return self is other
+class NullType(PrimitiveType):
+
+ """Spark SQL NullType
+
+ The data type representing None, used for types that have not
+ been inferred.
+ """
+
+
class StringType(PrimitiveType):
"""Spark SQL StringType
@@ -132,6 +152,14 @@ class BooleanType(PrimitiveType):
"""
+class DateType(PrimitiveType):
+
+ """Spark SQL DateType
+
+ The data type representing datetime.date values.
+ """
+
+
class TimestampType(PrimitiveType):
"""Spark SQL TimestampType
@@ -140,13 +168,30 @@ class TimestampType(PrimitiveType):
"""
-class DecimalType(PrimitiveType):
+class DecimalType(DataType):
"""Spark SQL DecimalType
The data type representing decimal.Decimal values.
"""
+ def __init__(self, precision=None, scale=None):
+ self.precision = precision
+ self.scale = scale
+ self.hasPrecisionInfo = precision is not None
+
+ def jsonValue(self):
+ if self.hasPrecisionInfo:
+ return "decimal(%d,%d)" % (self.precision, self.scale)
+ else:
+ return "decimal"
+
+ def __repr__(self):
+ if self.hasPrecisionInfo:
+ return "DecimalType(%d,%d)" % (self.precision, self.scale)
+ else:
+ return "DecimalType()"
+
class DoubleType(PrimitiveType):
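# Illustrative sketch (not part of the patch): DecimalType is no longer a primitive
# singleton and can now carry explicit precision and scale.
from pyspark.sql import DecimalType

assert DecimalType().jsonValue() == "decimal"              # no precision info
assert DecimalType(10, 2).jsonValue() == "decimal(10,2)"   # fixed precision and scale
assert repr(DecimalType(10, 2)) == "DecimalType(10,2)"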
@@ -305,12 +350,15 @@ class StructField(DataType):
"""
- def __init__(self, name, dataType, nullable):
+ def __init__(self, name, dataType, nullable=True, metadata=None):
"""Creates a StructField
:param name: the name of this field.
:param dataType: the data type of this field.
:param nullable: indicates whether values of this field
can be null.
+ :param metadata: metadata of this field, which is a map from string
+ to simple type that can be serialized to JSON
+ automatically
>>> (StructField("f1", StringType, True)
... == StructField("f1", StringType, True))
@@ -322,6 +370,7 @@ def __init__(self, name, dataType, nullable):
self.name = name
self.dataType = dataType
self.nullable = nullable
+ self.metadata = metadata or {}
def __repr__(self):
return "StructField(%s,%s,%s)" % (self.name, self.dataType,
@@ -330,13 +379,15 @@ def __repr__(self):
def jsonValue(self):
return {"name": self.name,
"type": self.dataType.jsonValue(),
- "nullable": self.nullable}
+ "nullable": self.nullable,
+ "metadata": self.metadata}
@classmethod
def fromJson(cls, json):
return StructField(json["name"],
_parse_datatype_json_value(json["type"]),
- json["nullable"])
+ json["nullable"],
+ json["metadata"])
class StructType(DataType):
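# Illustrative sketch (not part of the patch): a field's metadata now survives the
# JSON round trip used when schemas are passed to the JVM.
import json
from pyspark.sql import StructField, StringType

f = StructField("name", StringType(), nullable=True, metadata={"comment": "user name"})
assert StructField.fromJson(json.loads(f.json())).metadata == {"comment": "user name"}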
@@ -376,6 +427,75 @@ def fromJson(cls, json):
return StructType([StructField.fromJson(f) for f in json["fields"]])
+class UserDefinedType(DataType):
+ """
+ .. note:: WARN: Spark Internal Use Only
+ SQL User-Defined Type (UDT).
+ """
+
+ @classmethod
+ def typeName(cls):
+ return cls.__name__.lower()
+
+ @classmethod
+ def sqlType(cls):
+ """
+ Underlying SQL storage type for this UDT.
+ """
+ raise NotImplementedError("UDT must implement sqlType().")
+
+ @classmethod
+ def module(cls):
+ """
+ The Python module of the UDT.
+ """
+ raise NotImplementedError("UDT must implement module().")
+
+ @classmethod
+ def scalaUDT(cls):
+ """
+ The class name of the paired Scala UDT.
+ """
+ raise NotImplementedError("UDT must have a paired Scala UDT.")
+
+ def serialize(self, obj):
+ """
+ Converts a user-type object into a SQL datum.
+ """
+ raise NotImplementedError("UDT must implement serialize().")
+
+ def deserialize(self, datum):
+ """
+ Converts a SQL datum into a user-type object.
+ """
+ raise NotImplementedError("UDT must implement deserialize().")
+
+ def json(self):
+ return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)
+
+ def jsonValue(self):
+ schema = {
+ "type": "udt",
+ "class": self.scalaUDT(),
+ "pyClass": "%s.%s" % (self.module(), type(self).__name__),
+ "sqlType": self.sqlType().jsonValue()
+ }
+ return schema
+
+ @classmethod
+ def fromJson(cls, json):
+ pyUDT = json["pyClass"]
+ split = pyUDT.rfind(".")
+ pyModule = pyUDT[:split]
+ pyClass = pyUDT[split+1:]
+ m = __import__(pyModule, globals(), locals(), [pyClass], -1)
+ UDT = getattr(m, pyClass)
+ return UDT()
+
+ def __eq__(self, other):
+ return type(self) == type(other)
+
+
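# Illustrative sketch (not part of the patch): the hooks a UDT subclass has to provide.
# The ExamplePoint/ExamplePointUDT pair used by the doctests lives in pyspark.tests;
# this Point/PointUDT pair is hypothetical.
from pyspark.sql import UserDefinedType, ArrayType, DoubleType

class PointUDT(UserDefinedType):
    @classmethod
    def sqlType(cls):
        return ArrayType(DoubleType(), False)   # underlying SQL storage type
    @classmethod
    def module(cls):
        return "mypackage.point"                # module the class can be re-imported from
    @classmethod
    def scalaUDT(cls):
        return "org.example.PointUDT"           # name of the paired Scala UDT (assumed)
    def serialize(self, obj):
        return [obj.x, obj.y]                   # Python object -> SQL datum
    def deserialize(self, datum):
        return Point(datum[0], datum[1])        # SQL datum -> Python object

class Point(object):
    __UDT__ = PointUDT()                        # lets _infer_type() pick up the UDT
    def __init__(self, x, y):
        self.x, self.y = x, y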
_all_primitive_types = dict((v.typeName(), v)
for v in globals().itervalues()
if type(v) is PrimitiveTypeSingleton and
@@ -415,7 +535,8 @@ def _parse_datatype_json_string(json_string):
... StructField("simpleArray", simple_arraytype, True),
... StructField("simpleMap", simple_maptype, True),
... StructField("simpleStruct", simple_structtype, True),
- ... StructField("boolean", BooleanType(), False)])
+ ... StructField("boolean", BooleanType(), False),
+ ... StructField("withMeta", DoubleType(), False, {"name": "age"})])
>>> check_datatype(complex_structtype)
True
>>> # Complex ArrayType.
@@ -427,19 +548,43 @@ def _parse_datatype_json_string(json_string):
... complex_arraytype, False)
>>> check_datatype(complex_maptype)
True
+ >>> check_datatype(ExamplePointUDT())
+ True
+ >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
+ ... StructField("point", ExamplePointUDT(), False)])
+ >>> check_datatype(structtype_with_udt)
+ True
"""
return _parse_datatype_json_value(json.loads(json_string))
+_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)")
+
+
def _parse_datatype_json_value(json_value):
- if type(json_value) is unicode and json_value in _all_primitive_types.keys():
- return _all_primitive_types[json_value]()
+ if type(json_value) is unicode:
+ if json_value in _all_primitive_types.keys():
+ return _all_primitive_types[json_value]()
+ elif json_value == u'decimal':
+ return DecimalType()
+ elif _FIXED_DECIMAL.match(json_value):
+ m = _FIXED_DECIMAL.match(json_value)
+ return DecimalType(int(m.group(1)), int(m.group(2)))
+ else:
+ raise ValueError("Could not parse datatype: %s" % json_value)
else:
- return _all_complex_types[json_value["type"]].fromJson(json_value)
+ tpe = json_value["type"]
+ if tpe in _all_complex_types:
+ return _all_complex_types[tpe].fromJson(json_value)
+ elif tpe == 'udt':
+ return UserDefinedType.fromJson(json_value)
+ else:
+ raise ValueError("not supported type: %s" % tpe)
-# Mapping Python types to Spark SQL DateType
+# Mapping Python types to Spark SQL DataType
_type_mappings = {
+ type(None): NullType,
bool: BooleanType,
int: IntegerType,
long: LongType,
@@ -448,30 +593,41 @@ def _parse_datatype_json_value(json_value):
unicode: StringType,
bytearray: BinaryType,
decimal.Decimal: DecimalType,
+ datetime.date: DateType,
datetime.datetime: TimestampType,
- datetime.date: TimestampType,
datetime.time: TimestampType,
}
def _infer_type(obj):
- """Infer the DataType from obj"""
+ """Infer the DataType from obj
+
+ >>> p = ExamplePoint(1.0, 2.0)
+ >>> _infer_type(p)
+ ExamplePointUDT
+ """
if obj is None:
raise ValueError("Can not infer type for None")
+ if hasattr(obj, '__UDT__'):
+ return obj.__UDT__
+
dataType = _type_mappings.get(type(obj))
if dataType is not None:
return dataType()
if isinstance(obj, dict):
- if not obj:
- raise ValueError("Can not infer type for empty dict")
- key, value = obj.iteritems().next()
- return MapType(_infer_type(key), _infer_type(value), True)
+ for key, value in obj.iteritems():
+ if key is not None and value is not None:
+ return MapType(_infer_type(key), _infer_type(value), True)
+ else:
+ return MapType(NullType(), NullType(), True)
elif isinstance(obj, (list, array)):
- if not obj:
- raise ValueError("Can not infer type for empty list/array")
- return ArrayType(_infer_type(obj[0]), True)
+ for v in obj:
+ if v is not None:
+ return ArrayType(_infer_type(obj[0]), True)
+ else:
+ return ArrayType(NullType(), True)
else:
try:
return _infer_schema(obj)
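# Illustrative sketch (not part of the patch): empty containers no longer raise during
# inference; they get NullType placeholders that later rows can refine via _merge_type.
from pyspark.sql import _infer_type, ArrayType, MapType, NullType

assert isinstance(_infer_type([]), ArrayType)
assert isinstance(_infer_type([]).elementType, NullType)
assert isinstance(_infer_type({}), MapType)
assert isinstance(_infer_type({}).keyType, NullType)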
@@ -504,60 +660,181 @@ def _infer_schema(row):
return StructType(fields)
-def _create_converter(obj, dataType):
+def _need_python_to_sql_conversion(dataType):
+ """
+ Checks whether we need python to sql conversion for the given type.
+ For now, only UDTs need this conversion.
+
+ >>> _need_python_to_sql_conversion(DoubleType())
+ False
+ >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False),
+ ... StructField("values", ArrayType(DoubleType(), False), False)])
+ >>> _need_python_to_sql_conversion(schema0)
+ False
+ >>> _need_python_to_sql_conversion(ExamplePointUDT())
+ True
+ >>> schema1 = ArrayType(ExamplePointUDT(), False)
+ >>> _need_python_to_sql_conversion(schema1)
+ True
+ >>> schema2 = StructType([StructField("label", DoubleType(), False),
+ ... StructField("point", ExamplePointUDT(), False)])
+ >>> _need_python_to_sql_conversion(schema2)
+ True
+ """
+ if isinstance(dataType, StructType):
+ return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields])
+ elif isinstance(dataType, ArrayType):
+ return _need_python_to_sql_conversion(dataType.elementType)
+ elif isinstance(dataType, MapType):
+ return _need_python_to_sql_conversion(dataType.keyType) or \
+ _need_python_to_sql_conversion(dataType.valueType)
+ elif isinstance(dataType, UserDefinedType):
+ return True
+ else:
+ return False
+
+
+def _python_to_sql_converter(dataType):
+ """
+ Returns a converter that converts a Python object into a SQL datum for the given type.
+
+ >>> conv = _python_to_sql_converter(DoubleType())
+ >>> conv(1.0)
+ 1.0
+ >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False))
+ >>> conv([1.0, 2.0])
+ [1.0, 2.0]
+ >>> conv = _python_to_sql_converter(ExamplePointUDT())
+ >>> conv(ExamplePoint(1.0, 2.0))
+ [1.0, 2.0]
+ >>> schema = StructType([StructField("label", DoubleType(), False),
+ ... StructField("point", ExamplePointUDT(), False)])
+ >>> conv = _python_to_sql_converter(schema)
+ >>> conv((1.0, ExamplePoint(1.0, 2.0)))
+ (1.0, [1.0, 2.0])
+ """
+ if not _need_python_to_sql_conversion(dataType):
+ return lambda x: x
+
+ if isinstance(dataType, StructType):
+ names, types = zip(*[(f.name, f.dataType) for f in dataType.fields])
+ converters = map(_python_to_sql_converter, types)
+
+ def converter(obj):
+ if isinstance(obj, dict):
+ return tuple(c(obj.get(n)) for n, c in zip(names, converters))
+ elif isinstance(obj, tuple):
+ if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"):
+ return tuple(c(v) for c, v in zip(converters, obj))
+ elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs
+ d = dict(obj)
+ return tuple(c(d.get(n)) for n, c in zip(names, converters))
+ else:
+ return tuple(c(v) for c, v in zip(converters, obj))
+ else:
+ raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
+ return converter
+ elif isinstance(dataType, ArrayType):
+ element_converter = _python_to_sql_converter(dataType.elementType)
+ return lambda a: [element_converter(v) for v in a]
+ elif isinstance(dataType, MapType):
+ key_converter = _python_to_sql_converter(dataType.keyType)
+ value_converter = _python_to_sql_converter(dataType.valueType)
+ return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
+ elif isinstance(dataType, UserDefinedType):
+ return lambda obj: dataType.serialize(obj)
+ else:
+ raise ValueError("Unexpected type %r" % dataType)
+
+
+def _has_nulltype(dt):
+ """ Return whether there is NullType in `dt` or not """
+ if isinstance(dt, StructType):
+ return any(_has_nulltype(f.dataType) for f in dt.fields)
+ elif isinstance(dt, ArrayType):
+ return _has_nulltype(dt.elementType)
+ elif isinstance(dt, MapType):
+ return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType)
+ else:
+ return isinstance(dt, NullType)
+
+
+def _merge_type(a, b):
+ if isinstance(a, NullType):
+ return b
+ elif isinstance(b, NullType):
+ return a
+ elif type(a) is not type(b):
+ # TODO: type cast (such as int -> long)
+ raise TypeError("Can not merge type %s and %s" % (a, b))
+
+ # same type
+ if isinstance(a, StructType):
+ nfs = dict((f.name, f.dataType) for f in b.fields)
+ fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType())))
+ for f in a.fields]
+ names = set([f.name for f in fields])
+ for n in nfs:
+ if n not in names:
+ fields.append(StructField(n, nfs[n]))
+ return StructType(fields)
+
+ elif isinstance(a, ArrayType):
+ return ArrayType(_merge_type(a.elementType, b.elementType), True)
+
+ elif isinstance(a, MapType):
+ return MapType(_merge_type(a.keyType, b.keyType),
+ _merge_type(a.valueType, b.valueType),
+ True)
+ else:
+ return a
+
+
+def _create_converter(dataType):
"""Create an converter to drop the names of fields in obj """
if isinstance(dataType, ArrayType):
- conv = _create_converter(obj[0], dataType.elementType)
+ conv = _create_converter(dataType.elementType)
return lambda row: map(conv, row)
elif isinstance(dataType, MapType):
- value = obj.values()[0]
- conv = _create_converter(value, dataType.valueType)
- return lambda row: dict((k, conv(v)) for k, v in row.iteritems())
+ kconv = _create_converter(dataType.keyType)
+ vconv = _create_converter(dataType.valueType)
+ return lambda row: dict((kconv(k), vconv(v)) for k, v in row.iteritems())
+
+ elif isinstance(dataType, NullType):
+ return lambda x: None
elif not isinstance(dataType, StructType):
return lambda x: x
# dataType must be StructType
names = [f.name for f in dataType.fields]
+ converters = [_create_converter(f.dataType) for f in dataType.fields]
+
+ def convert_struct(obj):
+ if obj is None:
+ return
+
+ if isinstance(obj, tuple):
+ if hasattr(obj, "_fields"):
+ d = dict(zip(obj._fields, obj))
+ elif hasattr(obj, "__FIELDS__"):
+ d = dict(zip(obj.__FIELDS__, obj))
+ elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
+ d = dict(obj)
+ else:
+ raise ValueError("unexpected tuple: %s" % str(obj))
- if isinstance(obj, dict):
- conv = lambda o: tuple(o.get(n) for n in names)
-
- elif isinstance(obj, tuple):
- if hasattr(obj, "_fields"): # namedtuple
- conv = tuple
- elif hasattr(obj, "__FIELDS__"):
- conv = tuple
- elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
- conv = lambda o: tuple(v for k, v in o)
+ elif isinstance(obj, dict):
+ d = obj
+ elif hasattr(obj, "__dict__"): # object
+ d = obj.__dict__
else:
- raise ValueError("unexpected tuple")
-
- elif hasattr(obj, "__dict__"): # object
- conv = lambda o: [o.__dict__.get(n, None) for n in names]
-
- if all(isinstance(f.dataType, PrimitiveType) for f in dataType.fields):
- return conv
-
- row = conv(obj)
- convs = [_create_converter(v, f.dataType)
- for v, f in zip(row, dataType.fields)]
-
- def nested_conv(row):
- return tuple(f(v) for f, v in zip(convs, conv(row)))
+ raise ValueError("Unexpected obj: %s" % obj)
- return nested_conv
+ return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
-
-def _drop_schema(rows, schema):
- """ all the names of fields, becoming tuples"""
- iterator = iter(rows)
- row = iterator.next()
- converter = _create_converter(row, schema)
- yield converter(row)
- for i in iterator:
- yield converter(i)
+ return convert_struct
_BRACKETS = {'(': ')', '[': ']', '{': '}'}
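# Illustrative sketch (not part of the patch): how _merge_type resolves the NullType
# placeholders when schemas inferred from different rows are combined.
from pyspark.sql import Row, _infer_schema, _merge_type, _has_nulltype

s1 = _infer_schema(Row(name=u"Alice", hobbies=[]))        # hobbies -> ArrayType(NullType)
s2 = _infer_schema(Row(name=u"Bob", hobbies=[u"chess"]))  # hobbies -> ArrayType(StringType)
merged = _merge_type(s1, s2)
assert not _has_nulltype(merged)  # the empty list now has a concrete element type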
@@ -654,12 +931,12 @@ def _parse_schema_abstract(s):
def _infer_schema_type(obj, dataType):
"""
- Fill the dataType with types infered from obj
+ Fill the dataType with types inferred from obj
- >>> schema = _parse_schema_abstract("a b c")
- >>> row = (1, 1.0, "str")
+ >>> schema = _parse_schema_abstract("a b c d")
+ >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10))
>>> _infer_schema_type(row, schema)
- StructType...IntegerType...DoubleType...StringType...
+ StructType...IntegerType...DoubleType...StringType...DateType...
>>> row = [[1], {"key": (1, 2.0)}]
>>> schema = _parse_schema_abstract("a[] b{c d}")
>>> _infer_schema_type(row, schema)
@@ -669,7 +946,7 @@ def _infer_schema_type(obj, dataType):
return _infer_type(obj)
if not obj:
- raise ValueError("Can not infer type from empty value")
+ return NullType()
if isinstance(dataType, ArrayType):
eType = _infer_schema_type(obj[0], dataType.elementType)
@@ -677,7 +954,7 @@ def _infer_schema_type(obj, dataType):
elif isinstance(dataType, MapType):
k, v = obj.iteritems().next()
- return MapType(_infer_type(k),
+ return MapType(_infer_schema_type(k, dataType.keyType),
_infer_schema_type(v, dataType.valueType))
elif isinstance(dataType, StructType):
@@ -703,6 +980,7 @@ def _infer_schema_type(obj, dataType):
DecimalType: (decimal.Decimal,),
StringType: (str, unicode),
BinaryType: (bytearray,),
+ DateType: (datetime.date,),
TimestampType: (datetime.datetime,),
ArrayType: (list, tuple, array),
MapType: (dict,),
@@ -730,17 +1008,28 @@ def _verify_type(obj, dataType):
Traceback (most recent call last):
...
ValueError:...
+ >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
+ >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
"""
# all objects are nullable
if obj is None:
return
+ if isinstance(dataType, UserDefinedType):
+ if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
+ raise ValueError("%r is not an instance of type %r" % (obj, dataType))
+ _verify_type(dataType.serialize(obj), dataType.sqlType())
+ return
+
_type = type(dataType)
assert _type in _acceptable_types, "unkown datatype: %s" % dataType
# subclass of them can not be deserialized in JVM
if type(obj) not in _acceptable_types[_type]:
- raise TypeError("%s can not accept abject in type %s"
+ raise TypeError("%s can not accept object in type %s"
% (dataType, type(obj)))
if isinstance(dataType, ArrayType):
@@ -767,7 +1056,7 @@ def _restore_object(dataType, obj):
""" Restore object during unpickling. """
# use id(dataType) as key to speed up lookup in dict
# Because of batched pickling, dataType will be the
- # same object in mose cases.
+ # same object in most cases.
k = id(dataType)
cls = _cached_cls.get(k)
if cls is None:
@@ -782,6 +1071,10 @@ def _restore_object(dataType, obj):
def _create_object(cls, v):
""" Create an customized object with class `cls`. """
+ # datetime.date would be deserialized as datetime.datetime
+ # from java type, so we need to set it back.
+ if cls is datetime.date and isinstance(v, datetime.datetime):
+ return v.date()
return cls(v) if v is not None else v
@@ -795,14 +1088,18 @@ def getter(self):
return getter
-def _has_struct(dt):
- """Return whether `dt` is or has StructType in it"""
+def _has_struct_or_date(dt):
+ """Return whether `dt` is or has StructType/DateType in it"""
if isinstance(dt, StructType):
return True
elif isinstance(dt, ArrayType):
- return _has_struct(dt.elementType)
+ return _has_struct_or_date(dt.elementType)
elif isinstance(dt, MapType):
- return _has_struct(dt.valueType)
+ return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType)
+ elif isinstance(dt, DateType):
+ return True
+ elif isinstance(dt, UserDefinedType):
+ return True
return False
@@ -815,7 +1112,7 @@ def _create_properties(fields):
or keyword.iskeyword(name)):
warnings.warn("field name %s can not be accessed in Python,"
"use position to access it instead" % name)
- if _has_struct(f.dataType):
+ if _has_struct_or_date(f.dataType):
# delay creating object until accessing it
getter = _create_getter(f.dataType, i)
else:
@@ -861,21 +1158,29 @@ def List(l):
return List
elif isinstance(dataType, MapType):
- cls = _create_cls(dataType.valueType)
+ kcls = _create_cls(dataType.keyType)
+ vcls = _create_cls(dataType.valueType)
def Dict(d):
if d is None:
return
- return dict((k, _create_object(cls, v)) for k, v in d.items())
+ return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items())
return Dict
+ elif isinstance(dataType, DateType):
+ return datetime.date
+
+ elif isinstance(dataType, UserDefinedType):
+ return lambda datum: dataType.deserialize(datum)
+
elif not isinstance(dataType, StructType):
- raise Exception("unexpected data type: %s" % dataType)
+ # no wrapper for primitive types
+ return lambda x: x
class Row(tuple):
- """ Row in SchemaRDD """
+ """ Row in DataFrame """
__DATATYPE__ = dataType
__FIELDS__ = tuple(f.name for f in dataType.fields)
__slots__ = ()
@@ -883,6 +1188,10 @@ class Row(tuple):
# create property for fast access
locals().update(_create_properties(dataType.fields))
+ def asDict(self):
+ """ Return as a dict """
+ return dict((n, getattr(self, n)) for n in self.__FIELDS__)
+
def __repr__(self):
# call collect __repr__ for nested objects
return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n))
@@ -898,7 +1207,7 @@ class SQLContext(object):
"""Main entry point for Spark SQL functionality.
- A SQLContext can be used create L{SchemaRDD}, register L{SchemaRDD} as
+ A SQLContext can be used create L{DataFrame}, register L{DataFrame} as
tables, execute SQL over tables, cache tables, and read parquet files.
"""
@@ -909,8 +1218,8 @@ def __init__(self, sparkContext, sqlContext=None):
:param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new
SQLContext in the JVM, instead we make all calls to this object.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
TypeError:...
@@ -925,19 +1234,18 @@ def __init__(self, sparkContext, sqlContext=None):
>>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L,
... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
... time=datetime(2014, 8, 1, 14, 1, 5))])
- >>> srdd = sqlCtx.inferSchema(allTypes)
- >>> srdd.registerTempTable("allTypes")
+ >>> df = sqlCtx.inferSchema(allTypes)
+ >>> df.registerTempTable("allTypes")
>>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
[Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)]
- >>> srdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time,
+ >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time,
... x.row.a, x.list)).collect()
[(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])]
"""
self._sc = sparkContext
self._jsc = self._sc._jsc
self._jvm = self._sc._jvm
- self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray
self._scala_SQLContext = sqlContext
@property
@@ -966,70 +1274,69 @@ def registerFunction(self, name, f, returnType=StringType()):
[Row(c0=4)]
"""
func = lambda _, it: imap(lambda x: f(*x), it)
- command = (func, None,
- BatchedSerializer(PickleSerializer(), 1024),
- BatchedSerializer(PickleSerializer(), 1024))
- ser = CloudPickleSerializer()
- pickled_command = ser.dumps(command)
- if len(pickled_command) > (1 << 20): # 1M
- broadcast = self._sc.broadcast(pickled_command)
- pickled_command = ser.dumps(broadcast)
- broadcast_vars = ListConverter().convert(
- [x._jbroadcast for x in self._sc._pickled_broadcast_vars],
- self._sc._gateway._gateway_client)
- self._sc._pickled_broadcast_vars.clear()
- env = MapConverter().convert(self._sc.environment,
- self._sc._gateway._gateway_client)
- includes = ListConverter().convert(self._sc._python_includes,
- self._sc._gateway._gateway_client)
- self._ssql_ctx.registerPython(name,
- bytearray(pickled_command),
- env,
- includes,
- self._sc.pythonExec,
- broadcast_vars,
- self._sc._javaAccumulator,
- returnType.json())
-
- def inferSchema(self, rdd):
+ ser = AutoBatchedSerializer(PickleSerializer())
+ command = (func, None, ser, ser)
+ pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self)
+ self._ssql_ctx.udf().registerPython(name,
+ bytearray(pickled_cmd),
+ env,
+ includes,
+ self._sc.pythonExec,
+ bvars,
+ self._sc._javaAccumulator,
+ returnType.json())
+
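# Illustrative sketch (not part of the patch): registering a Python UDF with an explicit
# return type through the new udf().registerPython path, assuming a live SQLContext `sqlCtx`.
from pyspark.sql import IntegerType

sqlCtx.registerFunction("strlen", lambda s: len(s), IntegerType())
sqlCtx.sql("SELECT strlen('test')").collect()   # [Row(c0=4)]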
+ def inferSchema(self, rdd, samplingRatio=None):
"""Infer and apply a schema to an RDD of L{Row}.
- We peek at the first row of the RDD to determine the fields' names
- and types. Nested collections are supported, which include array,
- dict, list, Row, tuple, namedtuple, or object.
+ When samplingRatio is specified, the schema is inferred by looking
+ at the types of each row in the sampled dataset. Otherwise, the
+ first 100 rows of the RDD are inspected. Nested collections are
+ supported, which can include array, dict, list, Row, tuple,
+ namedtuple, or object.
- All the rows in `rdd` should have the same type with the first one,
- or it will cause runtime exceptions.
+ Each row could be a L{pyspark.sql.Row} object, a namedtuple, or a plain object.
+ Using top level dicts is deprecated, as dict is used to represent Maps.
- Each row could be L{pyspark.sql.Row} object or namedtuple or objects,
- using dict is deprecated.
+ If a single column has multiple distinct inferred types, it may cause
+ runtime exceptions.
>>> rdd = sc.parallelize(
... [Row(field1=1, field2="row1"),
... Row(field1=2, field2="row2"),
... Row(field1=3, field2="row3")])
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.collect()[0]
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.collect()[0]
Row(field1=1, field2=u'row1')
>>> NestedRow = Row("f1", "f2")
>>> nestedRdd1 = sc.parallelize([
... NestedRow(array('i', [1, 2]), {"row1": 1.0}),
... NestedRow(array('i', [2, 3]), {"row2": 2.0})])
- >>> srdd = sqlCtx.inferSchema(nestedRdd1)
- >>> srdd.collect()
+ >>> df = sqlCtx.inferSchema(nestedRdd1)
+ >>> df.collect()
[Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})]
>>> nestedRdd2 = sc.parallelize([
... NestedRow([[1, 2], [2, 3]], [1, 2]),
... NestedRow([[2, 3], [3, 4]], [2, 3])])
- >>> srdd = sqlCtx.inferSchema(nestedRdd2)
- >>> srdd.collect()
+ >>> df = sqlCtx.inferSchema(nestedRdd2)
+ >>> df.collect()
[Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])]
+
+ >>> from collections import namedtuple
+ >>> CustomRow = namedtuple('CustomRow', 'field1 field2')
+ >>> rdd = sc.parallelize(
+ ... [CustomRow(field1=1, field2="row1"),
+ ... CustomRow(field1=2, field2="row2"),
+ ... CustomRow(field1=3, field2="row3")])
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.collect()[0]
+ Row(field1=1, field2=u'row1')
"""
- if isinstance(rdd, SchemaRDD):
- raise TypeError("Cannot apply schema to SchemaRDD")
+ if isinstance(rdd, DataFrame):
+ raise TypeError("Cannot apply schema to DataFrame")
first = rdd.first()
if not first:
@@ -1039,8 +1346,23 @@ def inferSchema(self, rdd):
warnings.warn("Using RDD of dict to inferSchema is deprecated,"
"please use pyspark.sql.Row instead")
- schema = _infer_schema(first)
- rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema))
+ if samplingRatio is None:
+ schema = _infer_schema(first)
+ if _has_nulltype(schema):
+ for row in rdd.take(100)[1:]:
+ schema = _merge_type(schema, _infer_schema(row))
+ if not _has_nulltype(schema):
+ break
+ else:
+ warnings.warn("Some of types cannot be determined by the "
+ "first 100 rows, please try again with sampling")
+ else:
+ if samplingRatio < 0.99:
+ rdd = rdd.sample(False, float(samplingRatio))
+ schema = rdd.map(_infer_schema).reduce(_merge_type)
+
+ converter = _create_converter(schema)
+ rdd = rdd.map(converter)
return self.applySchema(rdd, schema)
def applySchema(self, rdd, schema):
@@ -1058,14 +1380,15 @@ def applySchema(self, rdd, schema):
>>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")])
>>> schema = StructType([StructField("field1", IntegerType(), False),
... StructField("field2", StringType(), False)])
- >>> srdd = sqlCtx.applySchema(rdd2, schema)
- >>> sqlCtx.registerRDDAsTable(srdd, "table1")
- >>> srdd2 = sqlCtx.sql("SELECT * from table1")
- >>> srdd2.collect()
+ >>> df = sqlCtx.applySchema(rdd2, schema)
+ >>> sqlCtx.registerRDDAsTable(df, "table1")
+ >>> df2 = sqlCtx.sql("SELECT * from table1")
+ >>> df2.collect()
[Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
- >>> from datetime import datetime
+ >>> from datetime import date, datetime
>>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
+ ... date(2010, 1, 1),
... datetime(2010, 1, 1, 1, 1, 1),
... {"a": 1}, (2,), [1, 2, 3], None)])
>>> schema = StructType([
@@ -1075,6 +1398,7 @@ def applySchema(self, rdd, schema):
... StructField("short2", ShortType(), False),
... StructField("int", IntegerType(), False),
... StructField("float", FloatType(), False),
+ ... StructField("date", DateType(), False),
... StructField("time", TimestampType(), False),
... StructField("map",
... MapType(StringType(), IntegerType(), False), False),
@@ -1082,14 +1406,15 @@ def applySchema(self, rdd, schema):
... StructType([StructField("b", ShortType(), False)]), False),
... StructField("list", ArrayType(ByteType(), False), False),
... StructField("null", DoubleType(), True)])
- >>> srdd = sqlCtx.applySchema(rdd, schema)
- >>> results = srdd.map(
- ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.time,
- ... x.map["a"], x.struct.b, x.list, x.null))
- >>> results.collect()[0]
- (127, -128, -32768, 32767, 2147483647, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
-
- >>> srdd.registerTempTable("table2")
+ >>> df = sqlCtx.applySchema(rdd, schema)
+ >>> results = df.map(
+ ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date,
+ ... x.time, x.map["a"], x.struct.b, x.list, x.null))
+ >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE
+ (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1),
+ datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
+
+ >>> df.registerTempTable("table2")
>>> sqlCtx.sql(
... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " +
@@ -1102,13 +1427,13 @@ def applySchema(self, rdd, schema):
>>> abstract = "byte short float time map{} struct(b) list[]"
>>> schema = _parse_schema_abstract(abstract)
>>> typedSchema = _infer_schema_type(rdd.first(), schema)
- >>> srdd = sqlCtx.applySchema(rdd, typedSchema)
- >>> srdd.collect()
+ >>> df = sqlCtx.applySchema(rdd, typedSchema)
+ >>> df.collect()
[Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])]
"""
- if isinstance(rdd, SchemaRDD):
- raise TypeError("Cannot apply schema to SchemaRDD")
+ if isinstance(rdd, DataFrame):
+ raise TypeError("Cannot apply schema to DataFrame")
if not isinstance(schema, StructType):
raise TypeError("schema should be StructType")
@@ -1123,10 +1448,13 @@ def applySchema(self, rdd, schema):
for row in rows:
_verify_type(row, schema)
- batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
- jrdd = self._pythonToJava(rdd._jrdd, batched)
- srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
- return SchemaRDD(srdd.toJavaSchemaRDD(), self)
+ # convert python objects to sql data
+ converter = _python_to_sql_converter(schema)
+ rdd = rdd.map(converter)
+
+ jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
+ df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
+ return DataFrame(df, self)
def registerRDDAsTable(self, rdd, tableName):
"""Registers the given RDD as a temporary table in the catalog.
@@ -1134,40 +1462,40 @@ def registerRDDAsTable(self, rdd, tableName):
Temporary tables exist only during the lifetime of this instance of
SQLContext.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.registerRDDAsTable(srdd, "table1")
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.registerRDDAsTable(df, "table1")
"""
- if (rdd.__class__ is SchemaRDD):
- srdd = rdd._jschema_rdd.baseSchemaRDD()
- self._ssql_ctx.registerRDDAsTable(srdd, tableName)
+ if (rdd.__class__ is DataFrame):
+ df = rdd._jdf
+ self._ssql_ctx.registerRDDAsTable(df, tableName)
else:
- raise ValueError("Can only register SchemaRDD as table")
+ raise ValueError("Can only register DataFrame as table")
def parquetFile(self, path):
- """Loads a Parquet file, returning the result as a L{SchemaRDD}.
+ """Loads a Parquet file, returning the result as a L{DataFrame}.
>>> import tempfile, shutil
>>> parquetFile = tempfile.mkdtemp()
>>> shutil.rmtree(parquetFile)
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.saveAsParquetFile(parquetFile)
- >>> srdd2 = sqlCtx.parquetFile(parquetFile)
- >>> sorted(srdd.collect()) == sorted(srdd2.collect())
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.saveAsParquetFile(parquetFile)
+ >>> df2 = sqlCtx.parquetFile(parquetFile)
+ >>> sorted(df.collect()) == sorted(df2.collect())
True
"""
- jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD()
- return SchemaRDD(jschema_rdd, self)
+ jdf = self._ssql_ctx.parquetFile(path)
+ return DataFrame(jdf, self)
- def jsonFile(self, path, schema=None):
+ def jsonFile(self, path, schema=None, samplingRatio=1.0):
"""
Loads a text file storing one JSON object per line as a
- L{SchemaRDD}.
+ L{DataFrame}.
If the schema is provided, applies the given schema to this
JSON dataset.
- Otherwise, it goes through the entire dataset once to determine
- the schema.
+ Otherwise, it samples the dataset with ratio `samplingRatio` to
+ determine the schema.
>>> import tempfile, shutil
>>> jsonFile = tempfile.mkdtemp()
@@ -1176,23 +1504,23 @@ def jsonFile(self, path, schema=None):
>>> for json in jsonStrings:
... print>>ofn, json
>>> ofn.close()
- >>> srdd1 = sqlCtx.jsonFile(jsonFile)
- >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
- >>> srdd2 = sqlCtx.sql(
+ >>> df1 = sqlCtx.jsonFile(jsonFile)
+ >>> sqlCtx.registerRDDAsTable(df1, "table1")
+ >>> df2 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
... "field6 as f4 from table1")
- >>> for r in srdd2.collect():
+ >>> for r in df2.collect():
... print r
Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
- >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema())
- >>> sqlCtx.registerRDDAsTable(srdd3, "table2")
- >>> srdd4 = sqlCtx.sql(
+ >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema())
+ >>> sqlCtx.registerRDDAsTable(df3, "table2")
+ >>> df4 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
... "field6 as f4 from table2")
- >>> for r in srdd4.collect():
+ >>> for r in df4.collect():
... print r
Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
@@ -1204,47 +1532,47 @@ def jsonFile(self, path, schema=None):
... StructType([
... StructField("field5",
... ArrayType(IntegerType(), False), True)]), False)])
- >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema)
- >>> sqlCtx.registerRDDAsTable(srdd5, "table3")
- >>> srdd6 = sqlCtx.sql(
+ >>> df5 = sqlCtx.jsonFile(jsonFile, schema)
+ >>> sqlCtx.registerRDDAsTable(df5, "table3")
+ >>> df6 = sqlCtx.sql(
... "SELECT field2 AS f1, field3.field5 as f2, "
... "field3.field5[0] as f3 from table3")
- >>> srdd6.collect()
+ >>> df6.collect()
[Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
"""
if schema is None:
- srdd = self._ssql_ctx.jsonFile(path)
+ df = self._ssql_ctx.jsonFile(path, samplingRatio)
else:
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
- srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
- return SchemaRDD(srdd.toJavaSchemaRDD(), self)
+ df = self._ssql_ctx.jsonFile(path, scala_datatype)
+ return DataFrame(df, self)
- def jsonRDD(self, rdd, schema=None):
- """Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
+ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
+ """Loads an RDD storing one JSON object per string as a L{DataFrame}.
If the schema is provided, applies the given schema to this
JSON dataset.
- Otherwise, it goes through the entire dataset once to determine
- the schema.
+ Otherwise, it samples the dataset with ratio `samplingRatio` to
+ determine the schema.
- >>> srdd1 = sqlCtx.jsonRDD(json)
- >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
- >>> srdd2 = sqlCtx.sql(
+ >>> df1 = sqlCtx.jsonRDD(json)
+ >>> sqlCtx.registerRDDAsTable(df1, "table1")
+ >>> df2 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
... "field6 as f4 from table1")
- >>> for r in srdd2.collect():
+ >>> for r in df2.collect():
... print r
Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
- >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema())
- >>> sqlCtx.registerRDDAsTable(srdd3, "table2")
- >>> srdd4 = sqlCtx.sql(
+ >>> df3 = sqlCtx.jsonRDD(json, df1.schema())
+ >>> sqlCtx.registerRDDAsTable(df3, "table2")
+ >>> df4 = sqlCtx.sql(
... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
... "field6 as f4 from table2")
- >>> for r in srdd4.collect():
+ >>> for r in df4.collect():
... print r
Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
@@ -1256,12 +1584,12 @@ def jsonRDD(self, rdd, schema=None):
... StructType([
... StructField("field5",
... ArrayType(IntegerType(), False), True)]), False)])
- >>> srdd5 = sqlCtx.jsonRDD(json, schema)
- >>> sqlCtx.registerRDDAsTable(srdd5, "table3")
- >>> srdd6 = sqlCtx.sql(
+ >>> df5 = sqlCtx.jsonRDD(json, schema)
+ >>> sqlCtx.registerRDDAsTable(df5, "table3")
+ >>> df6 = sqlCtx.sql(
... "SELECT field2 AS f1, field3.field5 as f2, "
... "field3.field5[0] as f3 from table3")
- >>> srdd6.collect()
+ >>> df6.collect()
[Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)]
>>> sqlCtx.jsonRDD(sc.parallelize(['{}',
@@ -1283,33 +1611,33 @@ def func(iterator):
keyed._bypass_serializer = True
jrdd = keyed._jrdd.map(self._jvm.BytesToString())
if schema is None:
- srdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
+ df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio)
else:
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
- srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
- return SchemaRDD(srdd.toJavaSchemaRDD(), self)
+ df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
+ return DataFrame(df, self)
def sql(self, sqlQuery):
- """Return a L{SchemaRDD} representing the result of the given query.
+ """Return a L{DataFrame} representing the result of the given query.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.registerRDDAsTable(srdd, "table1")
- >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
- >>> srdd2.collect()
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.registerRDDAsTable(df, "table1")
+ >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
+ >>> df2.collect()
[Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
"""
- return SchemaRDD(self._ssql_ctx.sql(sqlQuery).toJavaSchemaRDD(), self)
+ return DataFrame(self._ssql_ctx.sql(sqlQuery), self)
def table(self, tableName):
- """Returns the specified table as a L{SchemaRDD}.
+ """Returns the specified table as a L{DataFrame}.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.registerRDDAsTable(srdd, "table1")
- >>> srdd2 = sqlCtx.table("table1")
- >>> sorted(srdd.collect()) == sorted(srdd2.collect())
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> sqlCtx.registerRDDAsTable(df, "table1")
+ >>> df2 = sqlCtx.table("table1")
+ >>> sorted(df.collect()) == sorted(df2.collect())
True
"""
- return SchemaRDD(self._ssql_ctx.table(tableName).toJavaSchemaRDD(), self)
+ return DataFrame(self._ssql_ctx.table(tableName), self)
def cacheTable(self, tableName):
"""Caches the specified table in-memory."""
@@ -1349,73 +1677,11 @@ def _ssql_ctx(self):
except Py4JError as e:
raise Exception("You must build Spark with Hive. "
"Export 'SPARK_HIVE=true' and run "
- "sbt/sbt assembly", e)
+ "build/sbt assembly", e)
def _get_hive_ctx(self):
return self._jvm.HiveContext(self._jsc.sc())
- def hiveql(self, hqlQuery):
- """
- DEPRECATED: Use sql()
- """
- warnings.warn("hiveql() is deprecated as the sql function now parses using HiveQL by" +
- "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'",
- DeprecationWarning)
- return SchemaRDD(self._ssql_ctx.hiveql(hqlQuery).toJavaSchemaRDD(), self)
-
- def hql(self, hqlQuery):
- """
- DEPRECATED: Use sql()
- """
- warnings.warn("hql() is deprecated as the sql function now parses using HiveQL by" +
- "default. The SQL dialect for parsing can be set using 'spark.sql.dialect'",
- DeprecationWarning)
- return self.hiveql(hqlQuery)
-
-
-class LocalHiveContext(HiveContext):
-
- """Starts up an instance of hive where metadata is stored locally.
-
- An in-process metadata data is created with data stored in ./metadata.
- Warehouse data is stored in in ./warehouse.
-
- >>> import os
- >>> hiveCtx = LocalHiveContext(sc)
- >>> try:
- ... supress = hiveCtx.sql("DROP TABLE src")
- ... except Exception:
- ... pass
- >>> kv1 = os.path.join(os.environ["SPARK_HOME"],
- ... 'examples/src/main/resources/kv1.txt')
- >>> supress = hiveCtx.sql(
- ... "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
- >>> supress = hiveCtx.sql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src"
- ... % kv1)
- >>> results = hiveCtx.sql("FROM src SELECT value"
- ... ).map(lambda r: int(r.value.split('_')[1]))
- >>> num = results.count()
- >>> reduce_sum = results.reduce(lambda x, y: x + y)
- >>> num
- 500
- >>> reduce_sum
- 130091
- """
-
- def __init__(self, sparkContext, sqlContext=None):
- HiveContext.__init__(self, sparkContext, sqlContext)
- warnings.warn("LocalHiveContext is deprecated. "
- "Use HiveContext instead.", DeprecationWarning)
-
- def _get_hive_ctx(self):
- return self._jvm.LocalHiveContext(self._jsc.sc())
-
-
-class TestHiveContext(HiveContext):
-
- def _get_hive_ctx(self):
- return self._jvm.TestHiveContext(self._jsc.sc())
-
def _create_row(fields, values):
row = Row(*values)
@@ -1426,7 +1692,7 @@ def _create_row(fields, values):
class Row(tuple):
"""
- A row in L{SchemaRDD}. The fields in it can be accessed like attributes.
+ A row in L{DataFrame}. The fields in it can be accessed like attributes.
Row can be used to create a row object by using named arguments,
the fields will be sorted by names.
@@ -1466,6 +1732,14 @@ def __new__(self, *args, **kwargs):
else:
raise ValueError("No args or kwargs")
+ def asDict(self):
+ """
+ Return as a dict
+ """
+ if not hasattr(self, "__FIELDS__"):
+ raise TypeError("Cannot convert a Row class into dict")
+ return dict(zip(self.__FIELDS__, self))
+
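# Illustrative sketch (not part of the patch): the new asDict() on rows built from
# named arguments.
from pyspark.sql import Row

row = Row(name=u"Alice", age=11)
assert row.asDict() == {"name": u"Alice", "age": 11}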
# let obect acs like class
def __call__(self, *args):
"""create new Row object"""
@@ -1496,110 +1770,107 @@ def __repr__(self):
return "" % ", ".join(self)
-def inherit_doc(cls):
- for name, func in vars(cls).items():
- # only inherit docstring for public functions
- if name.startswith("_"):
- continue
- if not func.__doc__:
- for parent in cls.__bases__:
- parent_func = getattr(parent, name, None)
- if parent_func and getattr(parent_func, "__doc__", None):
- func.__doc__ = parent_func.__doc__
- break
- return cls
+class DataFrame(object):
+
+ """A collection of rows that have the same columns.
+
+ A :class:`DataFrame` is equivalent to a relational table in Spark SQL,
+ and can be created using various functions in :class:`SQLContext`::
+
+ people = sqlContext.parquetFile("...")
+
+ Once created, it can be manipulated using the various domain-specific-language
+ (DSL) functions defined in: :class:`DataFrame`, :class:`Column`.
+
+ To select a column from the data frame, use the apply method::
+ ageCol = people.age
-@inherit_doc
-class SchemaRDD(RDD):
+ Note that the :class:`Column` type can also be manipulated
+ through its various functions::
- """An RDD of L{Row} objects that has an associated schema.
+ # The following creates a new column that increases everybody's age by 10.
+ people.age + 10
- The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can
- utilize the relational query api exposed by Spark SQL.
- For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the
- L{SchemaRDD} is not operated on directly, as it's underlying
- implementation is an RDD composed of Java objects. Instead it is
- converted to a PythonRDD in the JVM, on which Python operations can
- be done.
+ A more concrete example::
- This class receives raw tuples from Java but assigns a class to it in
- all its data-collection methods (mapPartitionsWithIndex, collect, take,
- etc) so that PySpark sees them as Row objects with named fields.
+ # To create DataFrame using SQLContext
+ people = sqlContext.parquetFile("...")
+ department = sqlContext.parquetFile("...")
+
+ people.filter(people.age > 30).join(department, people.deptId == department.id) \
+ .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"})
"""
- def __init__(self, jschema_rdd, sql_ctx):
+ def __init__(self, jdf, sql_ctx):
+ self._jdf = jdf
self.sql_ctx = sql_ctx
- self._sc = sql_ctx._sc
- clsName = jschema_rdd.getClass().getName()
- assert clsName.endswith("JavaSchemaRDD"), "jschema_rdd must be JavaSchemaRDD"
- self._jschema_rdd = jschema_rdd
- self._id = None
+ self._sc = sql_ctx and sql_ctx._sc
self.is_cached = False
- self.is_checkpointed = False
- self.ctx = self.sql_ctx._sc
- # the _jrdd is created by javaToPython(), serialized by pickle
- self._jrdd_deserializer = BatchedSerializer(PickleSerializer())
@property
- def _jrdd(self):
- """Lazy evaluation of PythonRDD object.
-
- Only done when a user calls methods defined by the
- L{pyspark.rdd.RDD} super class (map, filter, etc.).
+ def rdd(self):
+ """
+ Return the content of the :class:`DataFrame` as an :class:`RDD`
+ of :class:`Row` s.
"""
- if not hasattr(self, '_lazy_jrdd'):
- self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython()
- return self._lazy_jrdd
+ if not hasattr(self, '_lazy_rdd'):
+ jrdd = self._jdf.javaToPython()
+ rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
+ schema = self.schema()
- def id(self):
- if self._id is None:
- self._id = self._jrdd.id()
- return self._id
+ def applySchema(it):
+ cls = _create_cls(schema)
+ return itertools.imap(cls, it)
- def limit(self, num):
- """Limit the result count to the number specified.
+ self._lazy_rdd = rdd.mapPartitions(applySchema)
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.limit(2).collect()
- [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
- >>> srdd.limit(0).collect()
- []
+ return self._lazy_rdd
+
+ def toJSON(self, use_unicode=False):
+ """Convert a DataFrame into a MappedRDD of JSON documents; one document per row.
+
+ >>> df1 = sqlCtx.jsonRDD(json)
+ >>> sqlCtx.registerRDDAsTable(df1, "table1")
+ >>> df2 = sqlCtx.sql( "SELECT * from table1")
+ >>> df2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}'
+ True
+ >>> df3 = sqlCtx.sql( "SELECT field3.field4 from table1")
+ >>> df3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}']
+ True
"""
- rdd = self._jschema_rdd.baseSchemaRDD().limit(num).toJavaSchemaRDD()
- return SchemaRDD(rdd, self.sql_ctx)
+ rdd = self._jdf.toJSON()
+ return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
def saveAsParquetFile(self, path):
"""Save the contents as a Parquet file, preserving the schema.
Files that are written out using this method can be read back in as
- a SchemaRDD using the L{SQLContext.parquetFile} method.
+ a DataFrame using the L{SQLContext.parquetFile} method.
>>> import tempfile, shutil
>>> parquetFile = tempfile.mkdtemp()
>>> shutil.rmtree(parquetFile)
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.saveAsParquetFile(parquetFile)
- >>> srdd2 = sqlCtx.parquetFile(parquetFile)
- >>> sorted(srdd2.collect()) == sorted(srdd.collect())
+ >>> df.saveAsParquetFile(parquetFile)
+ >>> df2 = sqlCtx.parquetFile(parquetFile)
+ >>> sorted(df2.collect()) == sorted(df.collect())
True
"""
- self._jschema_rdd.saveAsParquetFile(path)
+ self._jdf.saveAsParquetFile(path)
def registerTempTable(self, name):
"""Registers this RDD as a temporary table using the given name.
The lifetime of this temporary table is tied to the L{SQLContext}
- that was used to create this SchemaRDD.
+ that was used to create this DataFrame.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.registerTempTable("test")
- >>> srdd2 = sqlCtx.sql("select * from test")
- >>> sorted(srdd.collect()) == sorted(srdd2.collect())
+ >>> df.registerTempTable("people")
+ >>> df2 = sqlCtx.sql("select * from people")
+ >>> sorted(df.collect()) == sorted(df2.collect())
True
"""
- self._jschema_rdd.registerTempTable(name)
+ self._jdf.registerTempTable(name)
def registerAsTable(self, name):
"""DEPRECATED: use registerTempTable() instead"""
@@ -1607,62 +1878,79 @@ def registerAsTable(self, name):
self.registerTempTable(name)
def insertInto(self, tableName, overwrite=False):
- """Inserts the contents of this SchemaRDD into the specified table.
+ """Inserts the contents of this DataFrame into the specified table.
Optionally overwriting any existing data.
"""
- self._jschema_rdd.insertInto(tableName, overwrite)
+ self._jdf.insertInto(tableName, overwrite)
def saveAsTable(self, tableName):
- """Creates a new table with the contents of this SchemaRDD."""
- self._jschema_rdd.saveAsTable(tableName)
+ """Creates a new table with the contents of this DataFrame."""
+ self._jdf.saveAsTable(tableName)
def schema(self):
- """Returns the schema of this SchemaRDD (represented by
- a L{StructType})."""
- return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json())
+ """Returns the schema of this DataFrame (represented by
+ a L{StructType}).
- def schemaString(self):
- """Returns the output schema in the tree format."""
- return self._jschema_rdd.schemaString()
+ >>> df.schema()
+ StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true)))
+ """
+ return _parse_datatype_json_string(self._jdf.schema().json())
def printSchema(self):
- """Prints out the schema in the tree format."""
- print self.schemaString()
+ """Prints out the schema in the tree format.
+
+ >>> df.printSchema()
+ root
+ |-- age: integer (nullable = true)
+ |-- name: string (nullable = true)
+
+ """
+ print (self._jdf.schema().treeString())
def count(self):
"""Return the number of elements in this RDD.
Unlike the base RDD implementation of count, this implementation
- leverages the query optimizer to compute the count on the SchemaRDD,
+ leverages the query optimizer to compute the count on the DataFrame,
which supports features such as filter pushdown.
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.count()
- 3L
- >>> srdd.count() == srdd.map(lambda x: x).count()
- True
+ >>> df.count()
+ 2L
"""
- return self._jschema_rdd.count()
+ return self._jdf.count()
def collect(self):
- """Return a list that contains all of the rows in this RDD.
+ """Return a list that contains all of the rows.
Each object in the list is a Row, the fields can be accessed as
attributes.
- Unlike the base RDD implementation of collect, this implementation
- leverages the query optimizer to perform a collect on the SchemaRDD,
- which supports features such as filter pushdown.
-
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.collect()
- [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')]
+ >>> df.collect()
+ [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""
- with SCCallSiteSync(self.context) as css:
- bytesInJava = self._jschema_rdd.baseSchemaRDD().collectToPython().iterator()
+ with SCCallSiteSync(self._sc) as css:
+ bytesInJava = self._jdf.javaToPython().collect().iterator()
+ tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
+ tempFile.close()
+ self._sc._writeToFile(bytesInJava, tempFile.name)
+ # Read the data into Python and deserialize it:
+ with open(tempFile.name, 'rb') as tempFile:
+ rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile))
+ os.unlink(tempFile.name)
cls = _create_cls(self.schema())
- return map(cls, self._collect_iterator_through_file(bytesInJava))
+ return [cls(r) for r in rs]
+
+ def limit(self, num):
+ """Limit the result count to the number specified.
+
+ >>> df.limit(1).collect()
+ [Row(age=2, name=u'Alice')]
+ >>> df.limit(0).collect()
+ []
+ """
+ jdf = self._jdf.limit(num)
+ return DataFrame(jdf, self.sql_ctx)
def take(self, num):
"""Take the first num rows of the RDD.
@@ -1670,122 +1958,734 @@ def take(self, num):
Each object in the list is a Row, the fields can be accessed as
attributes.
- Unlike the base RDD implementation of take, this implementation
- leverages the query optimizer to perform a collect on a SchemaRDD,
- which supports features such as filter pushdown.
-
- >>> srdd = sqlCtx.inferSchema(rdd)
- >>> srdd.take(2)
- [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')]
+ >>> df.take(2)
+ [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""
return self.limit(num).collect()
- # Convert each object in the RDD to a Row with the right class
- # for this SchemaRDD, so that fields can be accessed as attributes.
- def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
- """
- Return a new RDD by applying a function to each partition of this RDD,
- while tracking the index of the original partition.
+ def map(self, f):
+ """ Return a new RDD by applying a function to each Row, it's a
+ shorthand for df.rdd.map()
- >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
- >>> def f(splitIndex, iterator): yield splitIndex
- >>> rdd.mapPartitionsWithIndex(f).sum()
- 6
+ >>> df.map(lambda p: p.name).collect()
+ [u'Alice', u'Bob']
"""
- rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer)
-
- schema = self.schema()
+ return self.rdd.map(f)
- def applySchema(_, it):
- cls = _create_cls(schema)
- return itertools.imap(cls, it)
+ def mapPartitions(self, f, preservesPartitioning=False):
+ """
+ Return a new RDD by applying a function to each partition.
- objrdd = rdd.mapPartitionsWithIndex(applySchema, preservesPartitioning)
- return objrdd.mapPartitionsWithIndex(f, preservesPartitioning)
+ >>> rdd = sc.parallelize([1, 2, 3, 4], 4)
+ >>> def f(iterator): yield 1
+ >>> rdd.mapPartitions(f).sum()
+ 4
+ """
+ return self.rdd.mapPartitions(f, preservesPartitioning)
- # We override the default cache/persist/checkpoint behavior
- # as we want to cache the underlying SchemaRDD object in the JVM,
- # not the PythonRDD checkpointed by the super class
def cache(self):
+ """ Persist with the default storage level (C{MEMORY_ONLY_SER}).
+ """
self.is_cached = True
- self._jschema_rdd.cache()
+ self._jdf.cache()
return self
def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER):
+ """ Set the storage level to persist its values across operations
+ after the first time it is computed. This can only be used to assign
+ a new storage level if the RDD does not have a storage level set yet.
+ If no storage level is specified defaults to (C{MEMORY_ONLY_SER}).
+ """
self.is_cached = True
- javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel)
- self._jschema_rdd.persist(javaStorageLevel)
+ javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
+ self._jdf.persist(javaStorageLevel)
return self
def unpersist(self, blocking=True):
+ """ Mark it as non-persistent, and remove all blocks for it from
+ memory and disk.
+ """
self.is_cached = False
- self._jschema_rdd.unpersist(blocking)
+ self._jdf.unpersist(blocking)
return self
- def checkpoint(self):
- self.is_checkpointed = True
- self._jschema_rdd.checkpoint()
+ # def coalesce(self, numPartitions, shuffle=False):
+ # rdd = self._jdf.coalesce(numPartitions, shuffle, None)
+ # return DataFrame(rdd, self.sql_ctx)
+
+ def repartition(self, numPartitions):
+ """ Return a new :class:`DataFrame` that has exactly `numPartitions`
+ partitions.
+ """
+ rdd = self._jdf.repartition(numPartitions, None)
+ return DataFrame(rdd, self.sql_ctx)
+
+ def sample(self, withReplacement, fraction, seed=None):
+ """
+ Return a sampled subset of this DataFrame.
+
+ >>> df = sqlCtx.inferSchema(rdd)
+ >>> df.sample(False, 0.5, 97).count()
+ 2L
+ """
+ assert fraction >= 0.0, "Negative fraction value: %s" % fraction
+ seed = seed if seed is not None else random.randint(0, sys.maxint)
+ rdd = self._jdf.sample(withReplacement, fraction, long(seed))
+ return DataFrame(rdd, self.sql_ctx)
+
+ # def takeSample(self, withReplacement, num, seed=None):
+ # """Return a fixed-size sampled subset of this DataFrame.
+ #
+ # >>> df = sqlCtx.inferSchema(rdd)
+ # >>> df.takeSample(False, 2, 97)
+ # [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')]
+ # """
+ # seed = seed if seed is not None else random.randint(0, sys.maxint)
+ # with SCCallSiteSync(self.context) as css:
+ # bytesInJava = self._jdf \
+ # .takeSampleToPython(withReplacement, num, long(seed)) \
+ # .iterator()
+ # cls = _create_cls(self.schema())
+ # return map(cls, self._collect_iterator_through_file(bytesInJava))
+
+ @property
+ def dtypes(self):
+ """Return all column names and their data types as a list.
+
+ >>> df.dtypes
+ [('age', 'integer'), ('name', 'string')]
+ """
+ return [(str(f.name), f.dataType.jsonValue()) for f in self.schema().fields]
+
+ @property
+ def columns(self):
+ """ Return all column names as a list.
+
+ >>> df.columns
+ [u'age', u'name']
+ """
+ return [f.name for f in self.schema().fields]
- def isCheckpointed(self):
- return self._jschema_rdd.isCheckpointed()
+ def join(self, other, joinExprs=None, joinType=None):
+ """
+ Join with another DataFrame, using the given join expression.
+ The following example performs a full outer join between `df` and `df2`:
- def getCheckpointFile(self):
- checkpointFile = self._jschema_rdd.getCheckpointFile()
- if checkpointFile.isPresent():
- return checkpointFile.get()
+ :param other: Right side of the join
+ :param joinExprs: Join expression
+ :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.
- def coalesce(self, numPartitions, shuffle=False):
- rdd = self._jschema_rdd.coalesce(numPartitions, shuffle)
- return SchemaRDD(rdd, self.sql_ctx)
+ >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
+ [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)]
+ """
- def distinct(self, numPartitions=None):
- if numPartitions is None:
- rdd = self._jschema_rdd.distinct()
+ if joinExprs is None:
+ jdf = self._jdf.join(other._jdf)
else:
- rdd = self._jschema_rdd.distinct(numPartitions)
- return SchemaRDD(rdd, self.sql_ctx)
+ assert isinstance(joinExprs, Column), "joinExprs should be Column"
+ if joinType is None:
+ jdf = self._jdf.join(other._jdf, joinExprs._jc)
+ else:
+ assert isinstance(joinType, basestring), "joinType should be basestring"
+ jdf = self._jdf.join(other._jdf, joinExprs._jc, joinType)
+ return DataFrame(jdf, self.sql_ctx)
+
+ def sort(self, *cols):
+ """ Return a new :class:`DataFrame` sorted by the specified column(s).
+
+ :param cols: The columns or expressions used for sorting
+
+ >>> df.sort(df.age.desc()).collect()
+ [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
+ >>> df.sortBy(df.age.desc()).collect()
+ [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
+ """
+ if not cols:
+ raise ValueError("should sort by at least one column")
+ jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+ self._sc._gateway._gateway_client)
+ jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols))
+ return DataFrame(jdf, self.sql_ctx)
+
+ sortBy = sort
+
+ def head(self, n=None):
+ """ Return the first `n` rows or the first row if n is None.
+
+ >>> df.head()
+ Row(age=2, name=u'Alice')
+ >>> df.head(1)
+ [Row(age=2, name=u'Alice')]
+ """
+ if n is None:
+ rs = self.head(1)
+ return rs[0] if rs else None
+ return self.take(n)
+
+ def first(self):
+ """ Return the first row.
+
+ >>> df.first()
+ Row(age=2, name=u'Alice')
+ """
+ return self.head()
+
+ def __getitem__(self, item):
+ """ Return the column by the given name (a Column argument filters rows; a list selects columns).
- def intersection(self, other):
- if (other.__class__ is SchemaRDD):
- rdd = self._jschema_rdd.intersection(other._jschema_rdd)
- return SchemaRDD(rdd, self.sql_ctx)
+ >>> df['age'].collect()
+ [Row(age=2), Row(age=5)]
+ >>> df[ ["name", "age"]].collect()
+ [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
+ >>> df[ df.age > 3 ].collect()
+ [Row(age=5, name=u'Bob')]
+ """
+ if isinstance(item, basestring):
+ jc = self._jdf.apply(item)
+ return Column(jc, self.sql_ctx)
+ elif isinstance(item, Column):
+ return self.filter(item)
+ elif isinstance(item, list):
+ return self.select(*item)
else:
- raise ValueError("Can only intersect with another SchemaRDD")
+ raise IndexError("unexpected index: %s" % item)
- def repartition(self, numPartitions):
- rdd = self._jschema_rdd.repartition(numPartitions)
- return SchemaRDD(rdd, self.sql_ctx)
+ def __getattr__(self, name):
+ """ Return the column by the given name.
- def subtract(self, other, numPartitions=None):
- if (other.__class__ is SchemaRDD):
- if numPartitions is None:
- rdd = self._jschema_rdd.subtract(other._jschema_rdd)
- else:
- rdd = self._jschema_rdd.subtract(other._jschema_rdd,
- numPartitions)
- return SchemaRDD(rdd, self.sql_ctx)
+ >>> df.age.collect()
+ [Row(age=2), Row(age=5)]
+ """
+ if name.startswith("__"):
+ raise AttributeError(name)
+ jc = self._jdf.apply(name)
+ return Column(jc, self.sql_ctx)
+
+ def select(self, *cols):
+ """ Select a set of expressions.
+
+ >>> df.select().collect()
+ [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
+ >>> df.select('*').collect()
+ [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
+ >>> df.select('name', 'age').collect()
+ [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
+ >>> df.select(df.name, (df.age + 10).alias('age')).collect()
+ [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
+ """
+ if not cols:
+ cols = ["*"]
+ jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+ self._sc._gateway._gateway_client)
+ jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+ return DataFrame(jdf, self.sql_ctx)
+
+ def selectExpr(self, *expr):
+ """
+ Selects a set of SQL expressions. This is a variant of
+ `select` that accepts SQL expressions.
+
+ >>> df.selectExpr("age * 2", "abs(age)").collect()
+ [Row(('age * 2)=4, Abs('age)=2), Row(('age * 2)=10, Abs('age)=5)]
+ """
+ jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client)
+ jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr))
+ return DataFrame(jdf, self.sql_ctx)
+
+ def filter(self, condition):
+ """ Filter rows using the given condition, which may be either a
+ Column expression or a string of SQL expression.
+
+ where() is an alias for filter().
+
+ >>> df.filter(df.age > 3).collect()
+ [Row(age=5, name=u'Bob')]
+ >>> df.where(df.age == 2).collect()
+ [Row(age=2, name=u'Alice')]
+
+ >>> df.filter("age > 3").collect()
+ [Row(age=5, name=u'Bob')]
+ >>> df.where("age = 2").collect()
+ [Row(age=2, name=u'Alice')]
+ """
+ if isinstance(condition, basestring):
+ jdf = self._jdf.filter(condition)
+ elif isinstance(condition, Column):
+ jdf = self._jdf.filter(condition._jc)
else:
- raise ValueError("Can only subtract another SchemaRDD")
+ raise TypeError("condition should be string or Column")
+ return DataFrame(jdf, self.sql_ctx)
+
+ where = filter
+
+ def groupBy(self, *cols):
+ """ Group the :class:`DataFrame` using the specified columns,
+ so we can run aggregation on them. See :class:`GroupedData`
+ for all the available aggregate functions.
+
+ >>> df.groupBy().avg().collect()
+ [Row(AVG(age#0)=3.5)]
+ >>> df.groupBy('name').agg({'age': 'mean'}).collect()
+ [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)]
+ >>> df.groupBy(df.name).avg().collect()
+ [Row(name=u'Bob', AVG(age#0)=5.0), Row(name=u'Alice', AVG(age#0)=2.0)]
+ """
+ jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+ self._sc._gateway._gateway_client)
+ jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+ return GroupedData(jdf, self.sql_ctx)
+
+ def agg(self, *exprs):
+ """ Aggregate on the entire :class:`DataFrame` without groups
+ (shorthand for df.groupBy().agg()).
+
+ >>> df.agg({"age": "max"}).collect()
+ [Row(MAX(age#0)=5)]
+ >>> from pyspark.sql import Dsl
+ >>> df.agg(Dsl.min(df.age)).collect()
+ [Row(MIN(age#0)=2)]
+ """
+ return self.groupBy().agg(*exprs)
+
+ def unionAll(self, other):
+ """ Return a new :class:`DataFrame` containing the union of rows in this
+ frame and another frame.
+
+ This is equivalent to `UNION ALL` in SQL.
+ """
+ return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx)
+
+ def intersect(self, other):
+ """ Return a new :class:`DataFrame` containing rows only in
+ both this frame and another frame.
+
+ This is equivalent to `INTERSECT` in SQL.
+ """
+ return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx)
+
+ def subtract(self, other):
+ """ Return a new :class:`DataFrame` containing rows in this frame
+ but not in another frame.
+
+ This is equivalent to `EXCEPT` in SQL.
+ """
+ return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
+
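+ # A minimal illustrative sketch (comments only, not executed) of the three
+ # set-like operations above; `dfA` and `dfB` are placeholder DataFrames that
+ # are assumed to share the same schema:
+ #
+ #   dfA.unionAll(dfB)   # all rows from both frames, duplicates kept (UNION ALL)
+ #   dfA.intersect(dfB)  # rows present in both frames (INTERSECT)
+ #   dfA.subtract(dfB)   # rows of dfA that do not appear in dfB (EXCEPT)
+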
+ def addColumn(self, colName, col):
+ """ Return a new :class:`DataFrame` by adding a column.
+
+ >>> df.addColumn('age2', df.age + 2).collect()
+ [Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
+ """
+ return self.select('*', col.alias(colName))
+
+
+# Keep SchemaRDD as an alias for backward compatibility (for docs)
+class SchemaRDD(DataFrame):
+ """
+ SchemaRDD is deprecated, please use DataFrame
+ """
+
+
+def dfapi(f):
+ def _api(self):
+ name = f.__name__
+ jdf = getattr(self._jdf, name)()
+ return DataFrame(jdf, self.sql_ctx)
+ _api.__name__ = f.__name__
+ _api.__doc__ = f.__doc__
+ return _api
+
+
+class GroupedData(object):
+
+ """
+ A set of methods for aggregations on a :class:`DataFrame`,
+ created by DataFrame.groupBy().
+ """
+
+ def __init__(self, jdf, sql_ctx):
+ self._jdf = jdf
+ self.sql_ctx = sql_ctx
+
+ def agg(self, *exprs):
+ """ Compute aggregates for each group, specified either as a list of
+ aggregate Column expressions or as a map from column name to aggregate method.
+
+ The available aggregate methods are `avg`, `max`, `min`,
+ `sum`, `count`.
+
+ :param exprs: a list of aggregate Column expressions or a map from
+ column name to aggregate methods.
+
+ >>> gdf = df.groupBy(df.name)
+ >>> gdf.agg({"age": "max"}).collect()
+ [Row(name=u'Bob', MAX(age#0)=5), Row(name=u'Alice', MAX(age#0)=2)]
+ >>> from pyspark.sql import Dsl
+ >>> gdf.agg(Dsl.min(df.age)).collect()
+ [Row(MIN(age#0)=5), Row(MIN(age#0)=2)]
+ """
+ assert exprs, "exprs should not be empty"
+ if len(exprs) == 1 and isinstance(exprs[0], dict):
+ jmap = MapConverter().convert(exprs[0],
+ self.sql_ctx._sc._gateway._gateway_client)
+ jdf = self._jdf.agg(jmap)
+ else:
+ # Columns
+ assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
+ jcols = ListConverter().convert([c._jc for c in exprs[1:]],
+ self.sql_ctx._sc._gateway._gateway_client)
+ jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
+ return DataFrame(jdf, self.sql_ctx)
+
+ @dfapi
+ def count(self):
+ """ Count the number of rows for each group.
+
+ >>> df.groupBy(df.age).count().collect()
+ [Row(age=2, count=1), Row(age=5, count=1)]
+ """
+
+ @dfapi
+ def mean(self):
+ """Compute the average value for each numeric column
+ for each group. This is an alias for `avg`."""
+
+ @dfapi
+ def avg(self):
+ """Compute the average value for each numeric column
+ for each group."""
+
+ @dfapi
+ def max(self):
+ """Compute the max value for each numeric column for
+ each group. """
+
+ @dfapi
+ def min(self):
+ """Compute the min value for each numeric column for
+ each group."""
+
+ @dfapi
+ def sum(self):
+ """Compute the sum for each numeric column for each
+ group."""
+
+
+def _create_column_from_literal(literal):
+ sc = SparkContext._active_spark_context
+ return sc._jvm.Dsl.lit(literal)
+
+
+def _create_column_from_name(name):
+ sc = SparkContext._active_spark_context
+ return sc._jvm.Dsl.col(name)
+
+
+def _to_java_column(col):
+ if isinstance(col, Column):
+ jcol = col._jc
+ else:
+ jcol = _create_column_from_name(col)
+ return jcol
+
+
+def _unary_op(name, doc="unary operator"):
+ """ Create a method for the given unary operator """
+ def _(self):
+ jc = getattr(self._jc, name)()
+ return Column(jc, self.sql_ctx)
+ _.__doc__ = doc
+ return _
+
+
+def _dsl_op(name, doc=''):
+ def _(self):
+ jc = getattr(self._sc._jvm.Dsl, name)(self._jc)
+ return Column(jc, self.sql_ctx)
+ _.__doc__ = doc
+ return _
+
+
+def _bin_op(name, doc="binary operator"):
+ """ Create a method for the given binary operator
+ """
+ def _(self, other):
+ jc = other._jc if isinstance(other, Column) else other
+ njc = getattr(self._jc, name)(jc)
+ return Column(njc, self.sql_ctx)
+ _.__doc__ = doc
+ return _
+
+
+def _reverse_op(name, doc="binary operator"):
+ """ Create a method for a binary operator (this object is on the right side)
+ """
+ def _(self, other):
+ jother = _create_column_from_literal(other)
+ jc = getattr(jother, name)(self._jc)
+ return Column(jc, self.sql_ctx)
+ _.__doc__ = doc
+ return _
+
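+# Illustrative note on the helper factories above (for readers of this patch,
+# not executed): _bin_op("plus") builds Column.__add__, so `df.age + 1` calls the
+# underlying JVM column's `plus`; _reverse_op("minus") builds Column.__rsub__, so
+# `1 - df.age` first lifts the literal 1 into a column via
+# _create_column_from_literal before calling `minus` on it.
+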
+
+class Column(DataFrame):
+
+ """
+ A column in a DataFrame.
+
+ `Column` instances can be created by::
+
+ # 1. Select a column out of a DataFrame
+ df.colName
+ df["colName"]
+
+ # 2. Create from an expression
+ df.colName + 1
+ 1 / df.colName
+ """
+
+ def __init__(self, jc, sql_ctx=None):
+ self._jc = jc
+ super(Column, self).__init__(jc, sql_ctx)
+
+ # arithmetic operators
+ __neg__ = _dsl_op("negate")
+ __add__ = _bin_op("plus")
+ __sub__ = _bin_op("minus")
+ __mul__ = _bin_op("multiply")
+ __div__ = _bin_op("divide")
+ __mod__ = _bin_op("mod")
+ __radd__ = _bin_op("plus")
+ __rsub__ = _reverse_op("minus")
+ __rmul__ = _bin_op("multiply")
+ __rdiv__ = _reverse_op("divide")
+ __rmod__ = _reverse_op("mod")
+
+ # comparison operators
+ __eq__ = _bin_op("equalTo")
+ __ne__ = _bin_op("notEqual")
+ __lt__ = _bin_op("lt")
+ __le__ = _bin_op("leq")
+ __ge__ = _bin_op("geq")
+ __gt__ = _bin_op("gt")
+
+ # `and`, `or`, `not` cannot be overloaded in Python,
+ # so use bitwise operators as boolean operators
+ __and__ = _bin_op('and')
+ __or__ = _bin_op('or')
+ __invert__ = _dsl_op('not')
+ __rand__ = _bin_op("and")
+ __ror__ = _bin_op("or")
+
+ # container operators
+ __contains__ = _bin_op("contains")
+ __getitem__ = _bin_op("getItem")
+ getField = _bin_op("getField", "An expression that gets a field by name in a StructField.")
+
+ # string methods
+ rlike = _bin_op("rlike")
+ like = _bin_op("like")
+ startswith = _bin_op("startsWith")
+ endswith = _bin_op("endsWith")
+
+ def substr(self, startPos, length):
+ """
+ Return a Column which is a substring of the column
+
+ :param startPos: start position (int or Column)
+ :param length: length of the substring (int or Column)
+
+ >>> df.name.substr(1, 3).collect()
+ [Row(col=u'Ali'), Row(col=u'Bob')]
+ """
+ if type(startPos) != type(length):
+ raise TypeError("startPos and length must be the same type")
+ if isinstance(startPos, (int, long)):
+ jc = self._jc.substr(startPos, length)
+ elif isinstance(startPos, Column):
+ jc = self._jc.substr(startPos._jc, length._jc)
+ else:
+ raise TypeError("Unexpected type: %s" % type(startPos))
+ return Column(jc, self.sql_ctx)
+
+ __getslice__ = substr
+
+ # order
+ asc = _unary_op("asc")
+ desc = _unary_op("desc")
+
+ isNull = _unary_op("isNull", "True if the current expression is null.")
+ isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
+
+ def alias(self, alias):
+ """Return an alias for this column.
+
+ >>> df.age.alias("age2").collect()
+ [Row(age2=2), Row(age2=5)]
+ """
+ return Column(getattr(self._jc, "as")(alias), self.sql_ctx)
+
+ def cast(self, dataType):
+ """ Convert the column into type `dataType`
+
+ >>> df.select(df.age.cast("string").alias('ages')).collect()
+ [Row(ages=u'2'), Row(ages=u'5')]
+ >>> df.select(df.age.cast(StringType()).alias('ages')).collect()
+ [Row(ages=u'2'), Row(ages=u'5')]
+ """
+ if self.sql_ctx is None:
+ sc = SparkContext._active_spark_context
+ ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
+ else:
+ ssql_ctx = self.sql_ctx._ssql_ctx
+ if isinstance(dataType, basestring):
+ jc = self._jc.cast(dataType)
+ elif isinstance(dataType, DataType):
+ jdt = ssql_ctx.parseDataType(dataType.json())
+ jc = self._jc.cast(jdt)
+ else:
+ raise TypeError("unexpected type: %s" % type(dataType))
+ return Column(jc, self.sql_ctx)
+
+
+def _aggregate_func(name, doc=""):
+ """ Create a function for an aggregator by name """
+ def _(col):
+ sc = SparkContext._active_spark_context
+ jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
+ return Column(jc)
+ _.__name__ = name
+ _.__doc__ = doc
+ return staticmethod(_)
+
+
+class UserDefinedFunction(object):
+ def __init__(self, func, returnType):
+ self.func = func
+ self.returnType = returnType
+ self._broadcast = None
+ self._judf = self._create_judf()
+
+ def _create_judf(self):
+ f = self.func # put it in closure `func`
+ func = lambda _, it: imap(lambda x: f(*x), it)
+ ser = AutoBatchedSerializer(PickleSerializer())
+ command = (func, None, ser, ser)
+ sc = SparkContext._active_spark_context
+ pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
+ ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
+ jdt = ssql_ctx.parseDataType(self.returnType.json())
+ judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
+ includes, sc.pythonExec, broadcast_vars,
+ sc._javaAccumulator, jdt)
+ return judf
+
+ def __del__(self):
+ if self._broadcast is not None:
+ self._broadcast.unpersist()
+ self._broadcast = None
+
+ def __call__(self, *cols):
+ sc = SparkContext._active_spark_context
+ jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+ sc._gateway._gateway_client)
+ jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
+ return Column(jc)
+
+
+class Dsl(object):
+ """
+ A collection of builtin column functions and aggregators
+ """
+ DSLS = {
+ 'lit': 'Creates a :class:`Column` of literal value.',
+ 'col': 'Returns a :class:`Column` based on the given column name.',
+ 'column': 'Returns a :class:`Column` based on the given column name.',
+ 'upper': 'Converts a string expression to upper case.',
+ 'lower': 'Converts a string expression to lower case.',
+ 'sqrt': 'Computes the square root of the specified float value.',
+ 'abs': 'Computes the absolute value.',
+
+ 'max': 'Aggregate function: returns the maximum value of the expression in a group.',
+ 'min': 'Aggregate function: returns the minimum value of the expression in a group.',
+ 'first': 'Aggregate function: returns the first value in a group.',
+ 'last': 'Aggregate function: returns the last value in a group.',
+ 'count': 'Aggregate function: returns the number of items in a group.',
+ 'sum': 'Aggregate function: returns the sum of all values in the expression.',
+ 'avg': 'Aggregate function: returns the average of the values in a group.',
+ 'mean': 'Aggregate function: returns the average of the values in a group.',
+ 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
+ }
+
+ for _name, _doc in DSLS.items():
+ locals()[_name] = _aggregate_func(_name, _doc)
+ del _name, _doc
+
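+ # Illustrative sketch (comments only, not executed): the helpers generated from
+ # DSLS above are used like ordinary functions on columns; `df` and its columns
+ # are assumed to exist, e.g.
+ #
+ #   from pyspark.sql import Dsl
+ #   df.agg(Dsl.max(df.age), Dsl.count(df.name))
+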
+ @staticmethod
+ def countDistinct(col, *cols):
+ """ Return a new Column for distinct count of (col, *cols)
+
+ >>> from pyspark.sql import Dsl
+ >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect()
+ [Row(c=2)]
+
+ >>> df.agg(Dsl.countDistinct("age", "name").alias('c')).collect()
+ [Row(c=2)]
+ """
+ sc = SparkContext._active_spark_context
+ jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+ sc._gateway._gateway_client)
+ jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
+ sc._jvm.PythonUtils.toSeq(jcols))
+ return Column(jc)
+
+ @staticmethod
+ def approxCountDistinct(col, rsd=None):
+ """ Return a new Column for the approximate distinct count of `col`
+
+ >>> from pyspark.sql import Dsl
+ >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect()
+ [Row(c=2)]
+ """
+ sc = SparkContext._active_spark_context
+ if rsd is None:
+ jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col))
+ else:
+ jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
+ return Column(jc)
+
+ @staticmethod
+ def udf(f, returnType=StringType()):
+ """Create a user defined function (UDF)
+
+ >>> slen = Dsl.udf(lambda s: len(s), IntegerType())
+ >>> df.select(slen(df.name).alias('slen')).collect()
+ [Row(slen=5), Row(slen=3)]
+ """
+ return UserDefinedFunction(f, returnType)
def _test():
import doctest
- from array import array
from pyspark.context import SparkContext
# let doctest run in pyspark.sql, so DataTypes can be picklable
import pyspark.sql
from pyspark.sql import Row, SQLContext
+ from pyspark.sql_tests import ExamplePoint, ExamplePointUDT
globs = pyspark.sql.__dict__.copy()
- # The small batch size here ensures that we see multiple batches,
- # even in these small test examples:
- sc = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
- globs['sqlCtx'] = SQLContext(sc)
+ globs['sqlCtx'] = sqlCtx = SQLContext(sc)
globs['rdd'] = sc.parallelize(
[Row(field1=1, field2="row1"),
Row(field1=2, field2="row2"),
Row(field1=3, field2="row3")]
)
+ rdd2 = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)])
+ rdd3 = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)])
+ globs['df'] = sqlCtx.inferSchema(rdd2)
+ globs['df2'] = sqlCtx.inferSchema(rdd3)
+ globs['ExamplePoint'] = ExamplePoint
+ globs['ExamplePointUDT'] = ExamplePointUDT
jsonStrings = [
'{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
'{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
diff --git a/python/pyspark/sql_tests.py b/python/pyspark/sql_tests.py
new file mode 100644
index 0000000000000..d314f46e8d2d5
--- /dev/null
+++ b/python/pyspark/sql_tests.py
@@ -0,0 +1,299 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Unit tests for pyspark.sql; additional tests are implemented as doctests in
+individual modules.
+"""
+import os
+import sys
+import pydoc
+import shutil
+import tempfile
+
+if sys.version_info[:2] <= (2, 6):
+ try:
+ import unittest2 as unittest
+ except ImportError:
+ sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier')
+ sys.exit(1)
+else:
+ import unittest
+
+from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
+ UserDefinedType, DoubleType
+from pyspark.tests import ReusedPySparkTestCase
+
+
+class ExamplePointUDT(UserDefinedType):
+ """
+ User-defined type (UDT) for ExamplePoint.
+ """
+
+ @classmethod
+ def sqlType(self):
+ return ArrayType(DoubleType(), False)
+
+ @classmethod
+ def module(cls):
+ return 'pyspark.tests'
+
+ @classmethod
+ def scalaUDT(cls):
+ return 'org.apache.spark.sql.test.ExamplePointUDT'
+
+ def serialize(self, obj):
+ return [obj.x, obj.y]
+
+ def deserialize(self, datum):
+ return ExamplePoint(datum[0], datum[1])
+
+
+class ExamplePoint:
+ """
+ An example class to demonstrate UDT in Scala, Java, and Python.
+ """
+
+ __UDT__ = ExamplePointUDT()
+
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def __repr__(self):
+ return "ExamplePoint(%s,%s)" % (self.x, self.y)
+
+ def __str__(self):
+ return "(%s,%s)" % (self.x, self.y)
+
+ def __eq__(self, other):
+ return isinstance(other, ExamplePoint) and \
+ other.x == self.x and other.y == self.y
+
+
+class SQLTests(ReusedPySparkTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ ReusedPySparkTestCase.setUpClass()
+ cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+ os.unlink(cls.tempdir.name)
+ cls.sqlCtx = SQLContext(cls.sc)
+ cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
+ rdd = cls.sc.parallelize(cls.testData)
+ cls.df = cls.sqlCtx.inferSchema(rdd)
+
+ @classmethod
+ def tearDownClass(cls):
+ ReusedPySparkTestCase.tearDownClass()
+ shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+
+ def test_udf(self):
+ self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
+ [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
+ self.assertEqual(row[0], 5)
+
+ def test_udf2(self):
+ self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType())
+ self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
+ [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
+ self.assertEqual(4, res[0])
+
+ def test_udf_with_array_type(self):
+ d = [Row(l=range(3), d={"key": range(5)})]
+ rdd = self.sc.parallelize(d)
+ self.sqlCtx.inferSchema(rdd).registerTempTable("test")
+ self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
+ self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
+ [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
+ self.assertEqual(range(3), l1)
+ self.assertEqual(1, l2)
+
+ def test_broadcast_in_udf(self):
+ bar = {"a": "aa", "b": "bb", "c": "abc"}
+ foo = self.sc.broadcast(bar)
+ self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
+ [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect()
+ self.assertEqual("abc", res[0])
+ [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
+ self.assertEqual("", res[0])
+
+ def test_basic_functions(self):
+ rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
+ df = self.sqlCtx.jsonRDD(rdd)
+ df.count()
+ df.collect()
+ df.schema()
+
+ # cache and checkpoint
+ self.assertFalse(df.is_cached)
+ df.persist()
+ df.unpersist()
+ df.cache()
+ self.assertTrue(df.is_cached)
+ self.assertEqual(2, df.count())
+
+ df.registerTempTable("temp")
+ df = self.sqlCtx.sql("select foo from temp")
+ df.count()
+ df.collect()
+
+ def test_apply_schema_to_row(self):
+ df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
+ df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema())
+ self.assertEqual(df.collect(), df2.collect())
+
+ rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
+ df3 = self.sqlCtx.applySchema(rdd, df.schema())
+ self.assertEqual(10, df3.count())
+
+ def test_serialize_nested_array_and_map(self):
+ d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
+ rdd = self.sc.parallelize(d)
+ df = self.sqlCtx.inferSchema(rdd)
+ row = df.head()
+ self.assertEqual(1, len(row.l))
+ self.assertEqual(1, row.l[0].a)
+ self.assertEqual("2", row.d["key"].d)
+
+ l = df.map(lambda x: x.l).first()
+ self.assertEqual(1, len(l))
+ self.assertEqual('s', l[0].b)
+
+ d = df.map(lambda x: x.d).first()
+ self.assertEqual(1, len(d))
+ self.assertEqual(1.0, d["key"].c)
+
+ row = df.map(lambda x: x.d["key"]).first()
+ self.assertEqual(1.0, row.c)
+ self.assertEqual("2", row.d)
+
+ def test_infer_schema(self):
+ d = [Row(l=[], d={}),
+ Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
+ rdd = self.sc.parallelize(d)
+ df = self.sqlCtx.inferSchema(rdd)
+ self.assertEqual([], df.map(lambda r: r.l).first())
+ self.assertEqual([None, ""], df.map(lambda r: r.s).collect())
+ df.registerTempTable("test")
+ result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
+ self.assertEqual(1, result.head()[0])
+
+ df2 = self.sqlCtx.inferSchema(rdd, 1.0)
+ self.assertEqual(df.schema(), df2.schema())
+ self.assertEqual({}, df2.map(lambda r: r.d).first())
+ self.assertEqual([None, ""], df2.map(lambda r: r.s).collect())
+ df2.registerTempTable("test2")
+ result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
+ self.assertEqual(1, result.head()[0])
+
+ def test_struct_in_map(self):
+ d = [Row(m={Row(i=1): Row(s="")})]
+ rdd = self.sc.parallelize(d)
+ df = self.sqlCtx.inferSchema(rdd)
+ k, v = df.head().m.items()[0]
+ self.assertEqual(1, k.i)
+ self.assertEqual("", v.s)
+
+ def test_convert_row_to_dict(self):
+ row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
+ self.assertEqual(1, row.asDict()['l'][0].a)
+ rdd = self.sc.parallelize([row])
+ df = self.sqlCtx.inferSchema(rdd)
+ df.registerTempTable("test")
+ row = self.sqlCtx.sql("select l, d from test").head()
+ self.assertEqual(1, row.asDict()["l"][0].a)
+ self.assertEqual(1.0, row.asDict()['d']['key'].c)
+
+ def test_infer_schema_with_udt(self):
+ from pyspark.sql_tests import ExamplePoint, ExamplePointUDT
+ row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+ rdd = self.sc.parallelize([row])
+ df = self.sqlCtx.inferSchema(rdd)
+ schema = df.schema()
+ field = [f for f in schema.fields if f.name == "point"][0]
+ self.assertEqual(type(field.dataType), ExamplePointUDT)
+ df.registerTempTable("labeled_point")
+ point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point
+ self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+ def test_apply_schema_with_udt(self):
+ from pyspark.sql_tests import ExamplePoint, ExamplePointUDT
+ row = (1.0, ExamplePoint(1.0, 2.0))
+ rdd = self.sc.parallelize([row])
+ schema = StructType([StructField("label", DoubleType(), False),
+ StructField("point", ExamplePointUDT(), False)])
+ df = self.sqlCtx.applySchema(rdd, schema)
+ point = df.head().point
+ self.assertEquals(point, ExamplePoint(1.0, 2.0))
+
+ def test_parquet_with_udt(self):
+ from pyspark.sql_tests import ExamplePoint
+ row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+ rdd = self.sc.parallelize([row])
+ df0 = self.sqlCtx.inferSchema(rdd)
+ output_dir = os.path.join(self.tempdir.name, "labeled_point")
+ df0.saveAsParquetFile(output_dir)
+ df1 = self.sqlCtx.parquetFile(output_dir)
+ point = df1.head().point
+ self.assertEquals(point, ExamplePoint(1.0, 2.0))
+
+ def test_column_operators(self):
+ from pyspark.sql import Column, LongType
+ ci = self.df.key
+ cs = self.df.value
+ c = ci == cs
+ self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column))
+ rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci)
+ self.assertTrue(all(isinstance(c, Column) for c in rcc))
+ cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs]
+ self.assertTrue(all(isinstance(c, Column) for c in cb))
+ cbool = (ci & ci), (ci | ci), (~ci)
+ self.assertTrue(all(isinstance(c, Column) for c in cbool))
+ css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a')
+ self.assertTrue(all(isinstance(c, Column) for c in css))
+ self.assertTrue(isinstance(ci.cast(LongType()), Column))
+
+ def test_column_select(self):
+ df = self.df
+ self.assertEqual(self.testData, df.select("*").collect())
+ self.assertEqual(self.testData, df.select(df.key, df.value).collect())
+ self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())
+
+ def test_aggregator(self):
+ df = self.df
+ g = df.groupBy()
+ self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
+ self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
+
+ from pyspark.sql import Dsl
+ self.assertEqual((0, u'99'), tuple(g.agg(Dsl.first(df.key), Dsl.last(df.value)).first()))
+ self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0])
+ self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0])
+
+ def test_help_command(self):
+ # Regression test for SPARK-5464
+ rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
+ df = self.sqlCtx.jsonRDD(rdd)
+ # render_doc() reproduces the help() exception without printing output
+ pydoc.render_doc(df)
+ pydoc.render_doc(df.foo)
+ pydoc.render_doc(df.take(1))
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index 2f53fbd27b17a..b06ab650370bd 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -21,7 +21,7 @@
from py4j.java_gateway import java_import, JavaObject
from pyspark import RDD, SparkConf
-from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer
+from pyspark.serializers import NoOpSerializer, UTF8Deserializer, CloudPickleSerializer
from pyspark.context import SparkContext
from pyspark.storagelevel import StorageLevel
from pyspark.streaming.dstream import DStream
@@ -142,8 +142,8 @@ def getOrCreate(cls, checkpointPath, setupFunc):
recreated from the checkpoint data. If the data does not exist, then the provided setupFunc
will be used to create a JavaStreamingContext.
- @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program
- @param setupFunc Function to create a new JavaStreamingContext and setup DStreams
+ @param checkpointPath: Checkpoint directory used in an earlier JavaStreamingContext program
+ @param setupFunc: Function to create a new JavaStreamingContext and setup DStreams
"""
# TODO: support checkpoint in HDFS
if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath):
@@ -191,6 +191,15 @@ def awaitTermination(self, timeout=None):
else:
self._jssc.awaitTermination(int(timeout * 1000))
+ def awaitTerminationOrTimeout(self, timeout):
+ """
+ Wait for the execution to stop. Return `True` if it has stopped,
+ `False` if the waiting time elapsed before it stopped, or raise
+ the error reported during the execution.
+ @param timeout: time to wait in seconds
+ """
+ return self._jssc.awaitTerminationOrTimeout(int(timeout * 1000))
+
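+ # Illustrative sketch (comments only, not executed), assuming `ssc` is a started
+ # StreamingContext: poll with a timeout instead of blocking indefinitely.
+ #
+ #   while not ssc.awaitTerminationOrTimeout(5):
+ #       pass  # do periodic bookkeeping between 5-second waits
+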
def stop(self, stopSparkContext=True, stopGraceFully=False):
"""
Stop the execution of the streams, with option of ensuring all
@@ -251,6 +260,20 @@ def textFileStream(self, directory):
"""
return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer())
+ def binaryRecordsStream(self, directory, recordLength):
+ """
+ Create an input stream that monitors a Hadoop-compatible file system
+ for new files and reads them as flat binary files with records of
+ fixed length. Files must be written to the monitored directory by "moving"
+ them from another location within the same file system.
+ File names starting with . are ignored.
+
+ @param directory: Directory to load data from
+ @param recordLength: Length of each record in bytes
+ """
+ return DStream(self._jssc.binaryRecordsStream(directory, recordLength), self,
+ NoOpSerializer())
+
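+ # Illustrative sketch (comments only, not executed); `path` is a placeholder
+ # directory into which another process moves files of 8-byte records:
+ #
+ #   import struct
+ #   recs = ssc.binaryRecordsStream(path, 8)
+ #   recs.map(lambda r: struct.unpack("<q", r)[0]).pprint()
+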
def _check_serializers(self, rdds):
# make sure they have same serializer
if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1:
diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
index 0826ddc56e844..2fe39392ff081 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -157,18 +157,20 @@ def foreachRDD(self, func):
api = self._ssc._jvm.PythonDStream
api.callForeachRDD(self._jdstream, jfunc)
- def pprint(self):
+ def pprint(self, num=10):
"""
- Print the first ten elements of each RDD generated in this DStream.
+ Print the first num elements of each RDD generated in this DStream.
+
+ @param num: the number of elements to print from the beginning of each RDD.
"""
def takeAndPrint(time, rdd):
- taken = rdd.take(11)
+ taken = rdd.take(num + 1)
print "-------------------------------------------"
print "Time: %s" % time
print "-------------------------------------------"
- for record in taken[:10]:
+ for record in taken[:num]:
print record
- if len(taken) > 10:
+ if len(taken) > num:
print "..."
print
diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py
new file mode 100644
index 0000000000000..19ad71f99d4d5
--- /dev/null
+++ b/python/pyspark/streaming/kafka.py
@@ -0,0 +1,83 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from py4j.java_collections import MapConverter
+from py4j.java_gateway import java_import, Py4JError
+
+from pyspark.storagelevel import StorageLevel
+from pyspark.serializers import PairDeserializer, NoOpSerializer
+from pyspark.streaming import DStream
+
+__all__ = ['KafkaUtils', 'utf8_decoder']
+
+
+def utf8_decoder(s):
+ """ Decode the byte string as UTF-8 """
+ return s and s.decode('utf-8')
+
+
+class KafkaUtils(object):
+
+ @staticmethod
+ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={},
+ storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2,
+ keyDecoder=utf8_decoder, valueDecoder=utf8_decoder):
+ """
+ Create an input stream that pulls messages from a Kafka Broker.
+
+ :param ssc: StreamingContext object
+ :param zkQuorum: Zookeeper quorum (hostname:port,hostname:port,..).
+ :param groupId: The group id for this consumer.
+ :param topics: Dict of (topic_name -> numPartitions) to consume.
+ Each partition is consumed in its own thread.
+ :param kafkaParams: Additional params for Kafka
+ :param storageLevel: RDD storage level.
+ :param keyDecoder: A function used to decode key (default is utf8_decoder)
+ :param valueDecoder: A function used to decode value (default is utf8_decoder)
+ :return: A DStream object
+ """
+ java_import(ssc._jvm, "org.apache.spark.streaming.kafka.KafkaUtils")
+
+ kafkaParams.update({
+ "zookeeper.connect": zkQuorum,
+ "group.id": groupId,
+ "zookeeper.connection.timeout.ms": "10000",
+ })
+ if not isinstance(topics, dict):
+ raise TypeError("topics should be dict")
+ jtopics = MapConverter().convert(topics, ssc.sparkContext._gateway._gateway_client)
+ jparam = MapConverter().convert(kafkaParams, ssc.sparkContext._gateway._gateway_client)
+ jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
+
+ def getClassByName(name):
+ return ssc._jvm.org.apache.spark.util.Utils.classForName(name)
+
+ try:
+ array = getClassByName("[B")
+ decoder = getClassByName("kafka.serializer.DefaultDecoder")
+ jstream = ssc._jvm.KafkaUtils.createStream(ssc._jssc, array, array, decoder, decoder,
+ jparam, jtopics, jlevel)
+ except Py4JError, e:
+ # TODO: use --jar once it also work on driver
+ if not e.message or 'call a package' in e.message:
+ print "No Kafka streaming package found; add the assembly jar to the classpath:"
+ print " $ bin/spark-submit --driver-class-path external/kafka-assembly/target/" + \
+ "scala-*/spark-streaming-kafka-assembly-*.jar"
+ raise e
+ ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
+ stream = DStream(jstream, ssc, ser)
+ return stream.map(lambda (k, v): (keyDecoder(k), valueDecoder(v)))
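+
+# Illustrative usage sketch (comments only, not executed); the ZooKeeper quorum,
+# group id and topic name below are placeholders:
+#
+#   from pyspark.streaming import StreamingContext
+#   from pyspark.streaming.kafka import KafkaUtils
+#   ssc = StreamingContext(sc, 2)
+#   stream = KafkaUtils.createStream(ssc, "localhost:2181", "my-group", {"my-topic": 1})
+#   stream.map(lambda (k, v): v).pprint()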
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index a8d876d0fa3b3..608f8e26473a6 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -21,6 +21,7 @@
import operator
import unittest
import tempfile
+import struct
from pyspark.context import SparkConf, SparkContext, RDD
from pyspark.streaming.context import StreamingContext
@@ -455,6 +456,20 @@ def test_text_file_stream(self):
self.wait_for(result, 2)
self.assertEqual([range(10), range(10)], result)
+ def test_binary_records_stream(self):
+ d = tempfile.mkdtemp()
+ self.ssc = StreamingContext(self.sc, self.duration)
+ dstream = self.ssc.binaryRecordsStream(d, 10).map(
+ lambda v: struct.unpack("10b", str(v)))
+ result = self._collect(dstream, 2, block=False)
+ self.ssc.start()
+ for name in ('a', 'b'):
+ time.sleep(1)
+ with open(os.path.join(d, name), "wb") as f:
+ f.write(bytearray(range(10)))
+ self.wait_for(result, 2)
+ self.assertEqual([range(10), range(10)], map(lambda v: list(v[0]), result))
+
def test_union(self):
input = [range(i + 1) for i in range(3)]
dstream = self.ssc.queueStream(input)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 1a8e4150e63c3..b5e28c498040b 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -31,7 +31,8 @@
import time
import zipfile
import random
-from platform import python_implementation
+import threading
+import hashlib
if sys.version_info[:2] <= (2, 6):
try:
@@ -45,12 +46,13 @@
from pyspark.conf import SparkConf
from pyspark.context import SparkContext
+from pyspark.rdd import RDD
from pyspark.files import SparkFiles
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
- CloudPickleSerializer
+ CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
-from pyspark.sql import SQLContext, IntegerType, Row
from pyspark import shuffle
+from pyspark.profiler import BasicProfiler
_have_scipy = False
_have_numpy = False
@@ -235,13 +237,24 @@ def foo():
self.assertTrue("exit" in foo.func_code.co_names)
ser.dumps(foo)
+ def test_compressed_serializer(self):
+ ser = CompressedSerializer(PickleSerializer())
+ from StringIO import StringIO
+ io = StringIO()
+ ser.dump_stream(["abc", u"123", range(5)], io)
+ io.seek(0)
+ self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io)))
+ ser.dump_stream(range(1000), io)
+ io.seek(0)
+ self.assertEqual(["abc", u"123", range(5)] + range(1000), list(ser.load_stream(io)))
+
class PySparkTestCase(unittest.TestCase):
def setUp(self):
self._old_sys_path = list(sys.path)
class_name = self.__class__.__name__
- self.sc = SparkContext('local[4]', class_name, batchSize=2)
+ self.sc = SparkContext('local[4]', class_name)
def tearDown(self):
self.sc.stop()
@@ -252,7 +265,7 @@ class ReusedPySparkTestCase(unittest.TestCase):
@classmethod
def setUpClass(cls):
- cls.sc = SparkContext('local[4]', cls.__name__, batchSize=2)
+ cls.sc = SparkContext('local[4]', cls.__name__)
@classmethod
def tearDownClass(cls):
@@ -439,7 +452,7 @@ def test_sampling_default_seed(self):
subset = data.takeSample(False, 10)
self.assertEqual(len(subset), 10)
- def testAggregateByKey(self):
+ def test_aggregate_by_key(self):
data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)
def seqOp(x, y):
@@ -477,6 +490,32 @@ def test_large_broadcast(self):
m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
self.assertEquals(N, m)
+ def test_multiple_broadcasts(self):
+ N = 1 << 21
+ b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM
+ r = range(1 << 15)
+ random.shuffle(r)
+ s = str(r)
+ checksum = hashlib.md5(s).hexdigest()
+ b2 = self.sc.broadcast(s)
+ r = list(set(self.sc.parallelize(range(10), 10).map(
+ lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
+ self.assertEqual(1, len(r))
+ size, csum = r[0]
+ self.assertEqual(N, size)
+ self.assertEqual(checksum, csum)
+
+ random.shuffle(r)
+ s = str(r)
+ checksum = hashlib.md5(s).hexdigest()
+ b2 = self.sc.broadcast(s)
+ r = list(set(self.sc.parallelize(range(10), 10).map(
+ lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
+ self.assertEqual(1, len(r))
+ size, csum = r[0]
+ self.assertEqual(N, size)
+ self.assertEqual(checksum, csum)
+
def test_large_closure(self):
N = 1000000
data = [float(i) for i in xrange(N)]
@@ -494,6 +533,15 @@ def test_zip_with_different_serializers(self):
a = a._reserialize(BatchedSerializer(PickleSerializer(), 2))
b = b._reserialize(MarshalSerializer())
self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
+ # regression test for SPARK-4841
+ path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+ t = self.sc.textFile(path)
+ cnt = t.count()
+ self.assertEqual(cnt, t.zip(t).count())
+ rdd = t.map(str)
+ self.assertEqual(cnt, t.zip(rdd).count())
+ # regression test for bug in _reserialize()
+ self.assertEqual(cnt, t.zip(rdd).count())
def test_zip_with_different_number_of_items(self):
a = self.sc.parallelize(range(5), 2)
@@ -648,6 +696,50 @@ def test_distinct(self):
self.assertEquals(result.getNumPartitions(), 5)
self.assertEquals(result.count(), 3)
+ def test_sort_on_empty_rdd(self):
+ self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect())
+
+ def test_sample(self):
+ rdd = self.sc.parallelize(range(0, 100), 4)
+ wo = rdd.sample(False, 0.1, 2).collect()
+ wo_dup = rdd.sample(False, 0.1, 2).collect()
+ self.assertSetEqual(set(wo), set(wo_dup))
+ wr = rdd.sample(True, 0.2, 5).collect()
+ wr_dup = rdd.sample(True, 0.2, 5).collect()
+ self.assertSetEqual(set(wr), set(wr_dup))
+ wo_s10 = rdd.sample(False, 0.3, 10).collect()
+ wo_s20 = rdd.sample(False, 0.3, 20).collect()
+ self.assertNotEqual(set(wo_s10), set(wo_s20))
+ wr_s11 = rdd.sample(True, 0.4, 11).collect()
+ wr_s21 = rdd.sample(True, 0.4, 21).collect()
+ self.assertNotEqual(set(wr_s11), set(wr_s21))
+
+ def test_null_in_rdd(self):
+ jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc)
+ rdd = RDD(jrdd, self.sc, UTF8Deserializer())
+ self.assertEqual([u"a", None, u"b"], rdd.collect())
+ rdd = RDD(jrdd, self.sc, NoOpSerializer())
+ self.assertEqual(["a", None, "b"], rdd.collect())
+
+ def test_multiple_python_java_RDD_conversions(self):
+ # Regression test for SPARK-5361
+ data = [
+ (u'1', {u'director': u'David Lean'}),
+ (u'2', {u'director': u'Andrew Dominik'})
+ ]
+ from pyspark.rdd import RDD
+ data_rdd = self.sc.parallelize(data)
+ data_java_rdd = data_rdd._to_java_object_rdd()
+ data_python_rdd = self.sc._jvm.SerDe.javaToPython(data_java_rdd)
+ converted_rdd = RDD(data_python_rdd, self.sc)
+ self.assertEqual(2, converted_rdd.count())
+
+ # conversion between python and java RDD threw exceptions
+ data_java_rdd = converted_rdd._to_java_object_rdd()
+ data_python_rdd = self.sc._jvm.SerDe.javaToPython(data_java_rdd)
+ converted_rdd = RDD(data_python_rdd, self.sc)
+ self.assertEqual(2, converted_rdd.count())
+
class ProfilerTests(PySparkTestCase):
@@ -655,19 +747,15 @@ def setUp(self):
self._old_sys_path = list(sys.path)
class_name = self.__class__.__name__
conf = SparkConf().set("spark.python.profile", "true")
- self.sc = SparkContext('local[4]', class_name, batchSize=2, conf=conf)
+ self.sc = SparkContext('local[4]', class_name, conf=conf)
def test_profiler(self):
+ self.do_computation()
- def heavy_foo(x):
- for i in range(1 << 20):
- x = 1
- rdd = self.sc.parallelize(range(100))
- rdd.foreach(heavy_foo)
- profiles = self.sc._profile_stats
- self.assertEqual(1, len(profiles))
- id, acc, _ = profiles[0]
- stats = acc.value
+ profilers = self.sc.profiler_collector.profilers
+ self.assertEqual(1, len(profilers))
+ id, profiler, _ = profilers[0]
+ stats = profiler.stats()
self.assertTrue(stats is not None)
width, stat_list = stats.get_print_list([])
func_names = [func_name for fname, n, func_name in stat_list]
@@ -678,98 +766,30 @@ def heavy_foo(x):
self.sc.dump_profiles(d)
self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))
+ def test_custom_profiler(self):
+ class TestCustomProfiler(BasicProfiler):
+ def show(self, id):
+ self.result = "Custom formatting"
-class SQLTests(ReusedPySparkTestCase):
+ self.sc.profiler_collector.profiler_cls = TestCustomProfiler
- def setUp(self):
- self.sqlCtx = SQLContext(self.sc)
-
- def test_udf(self):
- self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
- [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
- self.assertEqual(row[0], 5)
-
- def test_udf2(self):
- self.sqlCtx.registerFunction("strlen", lambda string: len(string))
- self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
- [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
- self.assertEqual(u"4", res[0])
-
- def test_broadcast_in_udf(self):
- bar = {"a": "aa", "b": "bb", "c": "abc"}
- foo = self.sc.broadcast(bar)
- self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
- [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect()
- self.assertEqual("abc", res[0])
- [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
- self.assertEqual("", res[0])
-
- def test_basic_functions(self):
- rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
- srdd = self.sqlCtx.jsonRDD(rdd)
- srdd.count()
- srdd.collect()
- srdd.schemaString()
- srdd.schema()
-
- # cache and checkpoint
- self.assertFalse(srdd.is_cached)
- srdd.persist()
- srdd.unpersist()
- srdd.cache()
- self.assertTrue(srdd.is_cached)
- self.assertFalse(srdd.isCheckpointed())
- self.assertEqual(None, srdd.getCheckpointFile())
-
- srdd = srdd.coalesce(2, True)
- srdd = srdd.repartition(3)
- srdd = srdd.distinct()
- srdd.intersection(srdd)
- self.assertEqual(2, srdd.count())
-
- srdd.registerTempTable("temp")
- srdd = self.sqlCtx.sql("select foo from temp")
- srdd.count()
- srdd.collect()
+ self.do_computation()
- def test_distinct(self):
- rdd = self.sc.parallelize(['{"a": 1}', '{"b": 2}', '{"c": 3}']*10, 10)
- srdd = self.sqlCtx.jsonRDD(rdd)
- self.assertEquals(srdd.getNumPartitions(), 10)
- self.assertEquals(srdd.distinct().count(), 3)
- result = srdd.distinct(5)
- self.assertEquals(result.getNumPartitions(), 5)
- self.assertEquals(result.count(), 3)
+ profilers = self.sc.profiler_collector.profilers
+ self.assertEqual(1, len(profilers))
+ _, profiler, _ = profilers[0]
+ self.assertTrue(isinstance(profiler, TestCustomProfiler))
- def test_apply_schema_to_row(self):
- srdd = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
- srdd2 = self.sqlCtx.applySchema(srdd.map(lambda x: x), srdd.schema())
- self.assertEqual(srdd.collect(), srdd2.collect())
-
- rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
- srdd3 = self.sqlCtx.applySchema(rdd, srdd.schema())
- self.assertEqual(10, srdd3.count())
-
- def test_serialize_nested_array_and_map(self):
- d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
- rdd = self.sc.parallelize(d)
- srdd = self.sqlCtx.inferSchema(rdd)
- row = srdd.first()
- self.assertEqual(1, len(row.l))
- self.assertEqual(1, row.l[0].a)
- self.assertEqual("2", row.d["key"].d)
-
- l = srdd.map(lambda x: x.l).first()
- self.assertEqual(1, len(l))
- self.assertEqual('s', l[0].b)
+ self.sc.show_profiles()
+ self.assertEqual("Custom formatting", profiler.result)
- d = srdd.map(lambda x: x.d).first()
- self.assertEqual(1, len(d))
- self.assertEqual(1.0, d["key"].c)
+ def do_computation(self):
+ def heavy_foo(x):
+ for i in range(1 << 20):
+ x = 1
- row = srdd.map(lambda x: x.d["key"]).first()
- self.assertEqual(1.0, row.c)
- self.assertEqual("2", row.d)
+ rdd = self.sc.parallelize(range(100))
+ rdd.foreach(heavy_foo)
class InputFormatTests(ReusedPySparkTestCase):
@@ -868,16 +888,19 @@ def test_sequencefiles(self):
clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/",
"org.apache.hadoop.io.Text",
"org.apache.spark.api.python.TestWritable").collect())
- ec = (u'1',
- {u'__class__': u'org.apache.spark.api.python.TestWritable',
- u'double': 54.0, u'int': 123, u'str': u'test1'})
- self.assertEqual(clazz[0], ec)
+ cname = u'org.apache.spark.api.python.TestWritable'
+ ec = [(u'1', {u'__class__': cname, u'double': 1.0, u'int': 1, u'str': u'test1'}),
+ (u'2', {u'__class__': cname, u'double': 2.3, u'int': 2, u'str': u'test2'}),
+ (u'3', {u'__class__': cname, u'double': 3.1, u'int': 3, u'str': u'test3'}),
+ (u'4', {u'__class__': cname, u'double': 4.2, u'int': 4, u'str': u'test4'}),
+ (u'5', {u'__class__': cname, u'double': 5.5, u'int': 5, u'str': u'test56'})]
+ self.assertEqual(clazz, ec)
unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/",
"org.apache.hadoop.io.Text",
"org.apache.spark.api.python.TestWritable",
- batchSize=1).collect())
- self.assertEqual(unbatched_clazz[0], ec)
+ ).collect())
+ self.assertEqual(unbatched_clazz, ec)
def test_oldhadoop(self):
basepath = self.tempdir.name
@@ -963,6 +986,25 @@ def test_converters(self):
(u'\x03', [2.0])]
self.assertEqual(maps, em)
+ def test_binary_files(self):
+ path = os.path.join(self.tempdir.name, "binaryfiles")
+ os.mkdir(path)
+ data = "short binary data"
+ with open(os.path.join(path, "part-0000"), 'w') as f:
+ f.write(data)
+ [(p, d)] = self.sc.binaryFiles(path).collect()
+ self.assertTrue(p.endswith("part-0000"))
+ self.assertEqual(d, data)
+
+ def test_binary_records(self):
+ path = os.path.join(self.tempdir.name, "binaryrecords")
+ os.mkdir(path)
+ with open(os.path.join(path, "part-0000"), 'w') as f:
+ for i in range(100):
+ f.write('%04d' % i)
+ result = self.sc.binaryRecords(path, 4).map(int).collect()
+ self.assertEqual(range(100), result)
+
class OutputFormatTests(ReusedPySparkTestCase):
@@ -1197,51 +1239,6 @@ def test_reserialization(self):
result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect())
self.assertEqual(result5, data)
- def test_unbatched_save_and_read(self):
- basepath = self.tempdir.name
- ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
- self.sc.parallelize(ei, len(ei)).saveAsSequenceFile(
- basepath + "/unbatched/")
-
- unbatched_sequence = sorted(self.sc.sequenceFile(
- basepath + "/unbatched/",
- batchSize=1).collect())
- self.assertEqual(unbatched_sequence, ei)
-
- unbatched_hadoopFile = sorted(self.sc.hadoopFile(
- basepath + "/unbatched/",
- "org.apache.hadoop.mapred.SequenceFileInputFormat",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.Text",
- batchSize=1).collect())
- self.assertEqual(unbatched_hadoopFile, ei)
-
- unbatched_newAPIHadoopFile = sorted(self.sc.newAPIHadoopFile(
- basepath + "/unbatched/",
- "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.Text",
- batchSize=1).collect())
- self.assertEqual(unbatched_newAPIHadoopFile, ei)
-
- oldconf = {"mapred.input.dir": basepath + "/unbatched/"}
- unbatched_hadoopRDD = sorted(self.sc.hadoopRDD(
- "org.apache.hadoop.mapred.SequenceFileInputFormat",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.Text",
- conf=oldconf,
- batchSize=1).collect())
- self.assertEqual(unbatched_hadoopRDD, ei)
-
- newconf = {"mapred.input.dir": basepath + "/unbatched/"}
- unbatched_newAPIHadoopRDD = sorted(self.sc.newAPIHadoopRDD(
- "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.Text",
- conf=newconf,
- batchSize=1).collect())
- self.assertEqual(unbatched_newAPIHadoopRDD, ei)
-
def test_malformed_RDD(self):
basepath = self.tempdir.name
# non-batch-serialized RDD[[(K, V)]] should be rejected
@@ -1380,6 +1377,23 @@ def test_accumulator_when_reuse_worker(self):
self.assertEqual(sum(range(100)), acc2.value)
self.assertEqual(sum(range(100)), acc1.value)
+ def test_reuse_worker_after_take(self):
+ rdd = self.sc.parallelize(range(100000), 1)
+ self.assertEqual(0, rdd.first())
+
+ def count():
+ try:
+ rdd.count()
+ except Exception:
+ pass
+
+ t = threading.Thread(target=count)
+ t.daemon = True
+ t.start()
+ t.join(5)
+ self.assertTrue(not t.isAlive())
+ self.assertEqual(100000, rdd.count())
+
class SparkSubmitTests(unittest.TestCase):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 8257dddfee1c3..8a93c320ec5d3 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -23,15 +23,12 @@
import time
import socket
import traceback
-import cProfile
-import pstats
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
- write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
- CompressedSerializer
+ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
from pyspark import shuffle
pickleSer = PickleSerializer()
@@ -57,7 +54,7 @@ def main(infile, outfile):
boot_time = time.time()
split_index = read_int(infile)
if split_index == -1: # for unit tests
- return
+ exit(-1)
# initialize global state
shuffle.MemoryBytesSpilled = 0
@@ -78,12 +75,11 @@ def main(infile, outfile):
# fetch names and values of broadcast variables
num_broadcast_variables = read_int(infile)
- ser = CompressedSerializer(pickleSer)
for _ in range(num_broadcast_variables):
bid = read_long(infile)
if bid >= 0:
- value = ser._read_with_length(infile)
- _broadcastRegistry[bid] = Broadcast(bid, value)
+ path = utf8_deserializer.loads(infile)
+ _broadcastRegistry[bid] = Broadcast(path=path)
else:
bid = - bid - 1
_broadcastRegistry.pop(bid)
@@ -92,26 +88,21 @@ def main(infile, outfile):
command = pickleSer._read_with_length(infile)
if isinstance(command, Broadcast):
command = pickleSer.loads(command.value)
- (func, stats, deserializer, serializer) = command
+ (func, profiler, deserializer, serializer) = command
init_time = time.time()
def process():
iterator = deserializer.load_stream(infile)
serializer.dump_stream(func(split_index, iterator), outfile)
- if stats:
- p = cProfile.Profile()
- p.runcall(process)
- st = pstats.Stats(p)
- st.stream = None # make it picklable
- stats.add(st.strip_dirs())
+ if profiler:
+ profiler.profile(process)
else:
process()
except Exception:
try:
write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
write_with_length(traceback.format_exc(), outfile)
- outfile.flush()
except IOError:
# JVM close the socket
pass
@@ -131,6 +122,14 @@ def process():
for (aid, accum) in _accumulatorRegistry.items():
pickleSer._write_with_length((aid, accum._value), outfile)
+ # check end of stream
+ if read_int(infile) == SpecialLengths.END_OF_STREAM:
+ write_int(SpecialLengths.END_OF_STREAM, outfile)
+ else:
+ # write a different value to tell JVM to not reuse this worker
+ write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+ exit(-1)
+
if __name__ == '__main__':
# Read a local port to connect to from stdin
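
The hunk above also replaces the old implicit end of a task with an explicit handshake: after the accumulator updates are written, the worker reads one more int, echoes END_OF_STREAM if everything was consumed cleanly, and otherwise reports END_OF_DATA_SECTION and exits so the JVM will not reuse it. The Scala sketch below only illustrates the JVM side of that exchange; the object and method names are invented for illustration, and the constants are assumed to mirror pyspark.serializers.SpecialLengths.

    // Illustrative sketch only; not the actual PythonRDD / worker-factory code.
    import java.io.{DataInputStream, DataOutputStream}

    object WorkerHandshakeSketch {
      // Assumed to mirror pyspark.serializers.SpecialLengths
      val END_OF_DATA_SECTION = -1
      val END_OF_STREAM = -4

      /** Returns true when the worker finished cleanly and may be reused. */
      def canReuseWorker(toWorker: DataOutputStream, fromWorker: DataInputStream): Boolean = {
        toWorker.writeInt(END_OF_STREAM)      // tell the worker no more data is coming
        toWorker.flush()
        fromWorker.readInt() == END_OF_STREAM // a clean worker echoes END_OF_STREAM back
      }
    }
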
diff --git a/python/run-tests b/python/run-tests
index 80acd002ab7eb..649a2c44d187b 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -41,7 +41,7 @@ function run_test() {
# Fail and exit on the first test failure.
if [[ $FAILED != 0 ]]; then
- cat unit-tests.log | grep -v "^[0-9][0-9]*" # filter all lines starting with a number.
+ cat $LOG_FILE | grep -v "^[0-9][0-9]*" # filter all lines starting with a number.
echo -en "\033[31m" # Red
echo "Had test failures; see logs."
echo -en "\033[0m" # No color
@@ -56,7 +56,8 @@ function run_core_tests() {
run_test "pyspark/conf.py"
PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py"
PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py"
- PYSPARK_DOC_TEST=1 run_test "pyspark/serializers.py"
+ run_test "pyspark/serializers.py"
+ run_test "pyspark/profiler.py"
run_test "pyspark/shuffle.py"
run_test "pyspark/tests.py"
}
@@ -64,6 +65,7 @@ function run_core_tests() {
function run_sql_tests() {
echo "Run sql tests ..."
run_test "pyspark/sql.py"
+ run_test "pyspark/sql_tests.py"
}
function run_mllib_tests() {
@@ -72,22 +74,29 @@ function run_mllib_tests() {
run_test "pyspark/mllib/clustering.py"
run_test "pyspark/mllib/feature.py"
run_test "pyspark/mllib/linalg.py"
- run_test "pyspark/mllib/random.py"
+ run_test "pyspark/mllib/rand.py"
run_test "pyspark/mllib/recommendation.py"
run_test "pyspark/mllib/regression.py"
- run_test "pyspark/mllib/stat.py"
+ run_test "pyspark/mllib/stat/_statistics.py"
run_test "pyspark/mllib/tree.py"
run_test "pyspark/mllib/util.py"
run_test "pyspark/mllib/tests.py"
}
+function run_ml_tests() {
+ echo "Run ml tests ..."
+ run_test "pyspark/ml/feature.py"
+ run_test "pyspark/ml/classification.py"
+ run_test "pyspark/ml/tests.py"
+}
+
function run_streaming_tests() {
echo "Run streaming tests ..."
run_test "pyspark/streaming/util.py"
run_test "pyspark/streaming/tests.py"
}
-echo "Running PySpark tests. Output is in python/unit-tests.log."
+echo "Running PySpark tests. Output is in python/$LOG_FILE."
export PYSPARK_PYTHON="python"
@@ -102,6 +111,7 @@ $PYSPARK_PYTHON --version
run_core_tests
run_sql_tests
run_mllib_tests
+run_ml_tests
run_streaming_tests
# Try to test with PyPy
diff --git a/repl/pom.xml b/repl/pom.xml
index af528c8914335..bd39b90fd8714 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -21,7 +21,7 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent</artifactId>
-    <version>1.2.0-SNAPSHOT</version>
+    <version>1.3.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
   </parent>
@@ -35,9 +35,16 @@
     <sbt.project.name>repl</sbt.project.name>
     <deb.install.path>/usr/share/spark</deb.install.path>
     <deb.user>root</deb.user>
+    <extra.source.dir>scala-2.10/src/main/scala</extra.source.dir>
+    <extra.testsource.dir>scala-2.10/src/test/scala</extra.testsource.dir>
   </properties>

   <dependencies>
+    <dependency>
+      <groupId>${jline.groupid}</groupId>
+      <artifactId>jline</artifactId>
+      <version>${jline.version}</version>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-core_${scala.binary.version}</artifactId>
@@ -61,10 +68,6 @@
       <version>${project.version}</version>
       <scope>test</scope>
     </dependency>
-    <dependency>
-      <groupId>org.eclipse.jetty</groupId>
-      <artifactId>jetty-server</artifactId>
-    </dependency>
     <dependency>
       <groupId>org.scala-lang</groupId>
       <artifactId>scala-compiler</artifactId>
@@ -75,53 +78,89 @@
       <artifactId>scala-reflect</artifactId>
       <version>${scala.version}</version>
     </dependency>
-    <dependency>
-      <groupId>org.scala-lang</groupId>
-      <artifactId>jline</artifactId>
-      <version>${scala.version}</version>
-    </dependency>
     <dependency>
       <groupId>org.slf4j</groupId>
       <artifactId>jul-to-slf4j</artifactId>
     </dependency>
-    <dependency>
-      <groupId>org.scalatest</groupId>
-      <artifactId>scalatest_${scala.binary.version}</artifactId>
-      <scope>test</scope>
-    </dependency>
     <dependency>
       <groupId>org.scalacheck</groupId>
       <artifactId>scalacheck_${scala.binary.version}</artifactId>
       <scope>test</scope>
     </dependency>
+
+    <!-- Explicit listing of transitive dependencies that are shaded in the parent. -->
+    <dependency>
+      <groupId>org.eclipse.jetty</groupId>
+      <artifactId>jetty-server</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>org.eclipse.jetty</groupId>
+      <artifactId>jetty-plus</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>org.eclipse.jetty</groupId>
+      <artifactId>jetty-util</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>org.eclipse.jetty</groupId>
+      <artifactId>jetty-http</artifactId>
+    </dependency>
+
+    <dependency>
+      <groupId>org.scala-lang</groupId>
+      <artifactId>scala-library</artifactId>
+    </dependency>
   </dependencies>
   <build>
     <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
     <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
     <plugins>
       <plugin>
-        <groupId>org.apache.maven.plugins</groupId>
-        <artifactId>maven-deploy-plugin</artifactId>
-        <configuration>
-          <skip>true</skip>
-        </configuration>
-      </plugin>
-      <plugin>
-        <groupId>org.apache.maven.plugins</groupId>
-        <artifactId>maven-install-plugin</artifactId>
-        <configuration>
-          <skip>true</skip>
-        </configuration>
-      </plugin>
-      <plugin>
-        <groupId>org.scalatest</groupId>
-        <artifactId>scalatest-maven-plugin</artifactId>
-        <configuration>
-          <environmentVariables>
-            <SPARK_HOME>${basedir}/..</SPARK_HOME>
-          </environmentVariables>
-        </configuration>
+        <groupId>org.codehaus.mojo</groupId>
+        <artifactId>build-helper-maven-plugin</artifactId>
+        <executions>
+          <execution>
+            <id>add-scala-sources</id>
+            <phase>generate-sources</phase>
+            <goals>
+              <goal>add-source</goal>
+            </goals>
+            <configuration>
+              <sources>
+                <source>${extra.source.dir}</source>
+              </sources>
+            </configuration>
+          </execution>
+          <execution>
+            <id>add-scala-test-sources</id>
+            <phase>generate-test-sources</phase>
+            <goals>
+              <goal>add-test-source</goal>
+            </goals>
+            <configuration>
+              <sources>
+                <source>${extra.testsource.dir}</source>
+              </sources>
+            </configuration>
+          </execution>
+        </executions>
       </plugin>
     </plugins>
   </build>
+
+  <profiles>
+    <profile>
+      <id>scala-2.11</id>
+      <activation>
+        <property><name>scala-2.11</name></property>
+      </activation>
+      <properties>
+        <extra.source.dir>scala-2.11/src/main/scala</extra.source.dir>
+        <extra.testsource.dir>scala-2.11/src/test/scala</extra.testsource.dir>
+      </properties>
+    </profile>
+  </profiles>
 </project>
diff --git a/repl/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala
similarity index 100%
rename from repl/src/main/scala/org/apache/spark/repl/Main.scala
rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
similarity index 83%
rename from repl/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
index 05816941b54b3..6480e2d24e044 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
@@ -19,14 +19,21 @@ package org.apache.spark.repl
import scala.tools.nsc.{Settings, CompilerCommand}
import scala.Predef._
+import org.apache.spark.annotation.DeveloperApi
/**
* Command class enabling Spark-specific command line options (provided by
* org.apache.spark.repl.SparkRunnerSettings).
+ *
+ * @example new SparkCommandLine(Nil).settings
+ *
+ * @param args The list of command line arguments
+ * @param settings The underlying settings to associate with this set of
+ * command-line options
*/
+@DeveloperApi
class SparkCommandLine(args: List[String], override val settings: Settings)
extends CompilerCommand(args, settings) {
-
def this(args: List[String], error: String => Unit) {
this(args, new SparkRunnerSettings(error))
}
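
The @example tag in the doc comment above hints at the intended use: hand the REPL's command-line arguments to SparkCommandLine and read back a compiler Settings object. A minimal sketch follows; the flags shown are ordinary scalac options chosen for illustration, not something this patch prescribes.

    import org.apache.spark.repl.SparkCommandLine

    object SparkCommandLineExample {
      def main(args: Array[String]): Unit = {
        // The auxiliary constructor takes an error callback invoked for bad options.
        val command = new SparkCommandLine(List("-deprecation", "-usejavacp"), err => Console.err.println(err))
        if (command.ok) {
          println("deprecation enabled: " + command.settings.deprecation.value)
        }
      }
    }
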
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
new file mode 100644
index 0000000000000..5fb378112ef92
--- /dev/null
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
@@ -0,0 +1,114 @@
+// scalastyle:off
+
+/* NSC -- new Scala compiler
+ * Copyright 2005-2013 LAMP/EPFL
+ * @author Paul Phillips
+ */
+
+package org.apache.spark.repl
+
+import scala.tools.nsc._
+import scala.tools.nsc.interpreter._
+
+import scala.reflect.internal.util.BatchSourceFile
+import scala.tools.nsc.ast.parser.Tokens.EOF
+
+import org.apache.spark.Logging
+
+private[repl] trait SparkExprTyper extends Logging {
+ val repl: SparkIMain
+
+ import repl._
+ import global.{ reporter => _, Import => _, _ }
+ import definitions._
+ import syntaxAnalyzer.{ UnitParser, UnitScanner, token2name }
+ import naming.freshInternalVarName
+
+ object codeParser extends { val global: repl.global.type = repl.global } with CodeHandlers[Tree] {
+ def applyRule[T](code: String, rule: UnitParser => T): T = {
+ reporter.reset()
+ val scanner = newUnitParser(code)
+ val result = rule(scanner)
+
+ if (!reporter.hasErrors)
+ scanner.accept(EOF)
+
+ result
+ }
+
+ def defns(code: String) = stmts(code) collect { case x: DefTree => x }
+ def expr(code: String) = applyRule(code, _.expr())
+ def stmts(code: String) = applyRule(code, _.templateStats())
+ def stmt(code: String) = stmts(code).last // guaranteed nonempty
+ }
+
+ /** Parse a line into a sequence of trees. Returns None if the input is incomplete. */
+ def parse(line: String): Option[List[Tree]] = debugging(s"""parse("$line")""") {
+ var isIncomplete = false
+ reporter.withIncompleteHandler((_, _) => isIncomplete = true) {
+ val trees = codeParser.stmts(line)
+ if (reporter.hasErrors) {
+ Some(Nil)
+ } else if (isIncomplete) {
+ None
+ } else {
+ Some(trees)
+ }
+ }
+ }
+ // def parsesAsExpr(line: String) = {
+ // import codeParser._
+ // (opt expr line).isDefined
+ // }
+
+ def symbolOfLine(code: String): Symbol = {
+ def asExpr(): Symbol = {
+ val name = freshInternalVarName()
+ // Typing it with a lazy val would give us the right type, but runs
+ // into compiler bugs with things like existentials, so we compile it
+ // behind a def and strip the NullaryMethodType which wraps the expr.
+ val line = "def " + name + " = {\n" + code + "\n}"
+
+ interpretSynthetic(line) match {
+ case IR.Success =>
+ val sym0 = symbolOfTerm(name)
+ // drop NullaryMethodType
+ val sym = sym0.cloneSymbol setInfo afterTyper(sym0.info.finalResultType)
+ if (sym.info.typeSymbol eq UnitClass) NoSymbol else sym
+ case _ => NoSymbol
+ }
+ }
+ def asDefn(): Symbol = {
+ val old = repl.definedSymbolList.toSet
+
+ interpretSynthetic(code) match {
+ case IR.Success =>
+ repl.definedSymbolList filterNot old match {
+ case Nil => NoSymbol
+ case sym :: Nil => sym
+ case syms => NoSymbol.newOverloaded(NoPrefix, syms)
+ }
+ case _ => NoSymbol
+ }
+ }
+ beQuietDuring(asExpr()) orElse beQuietDuring(asDefn())
+ }
+
+ private var typeOfExpressionDepth = 0
+ def typeOfExpression(expr: String, silent: Boolean = true): Type = {
+ if (typeOfExpressionDepth > 2) {
+ logDebug("Terminating typeOfExpression recursion for expression: " + expr)
+ return NoType
+ }
+ typeOfExpressionDepth += 1
+ // Don't presently have a good way to suppress undesirable success output
+  // while letting errors through, so it first tries silently: if there
+ // is an error, and errors are desired, then it re-evaluates non-silently
+ // to induce the error message.
+ try beSilentDuring(symbolOfLine(expr).tpe) match {
+ case NoType if !silent => symbolOfLine(expr).tpe // generate error
+ case tpe => tpe
+ }
+ finally typeOfExpressionDepth -= 1
+ }
+}
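
SparkExprTyper types an arbitrary line by compiling it behind a synthetic def and stripping the NullaryMethodType wrapper, which is what backs the symbolOfLine and typeOfExpression calls made by the :type and :wrap commands later in this patch. A rough sketch of exercising that path through an interpreter instance; it assumes a locally runnable SparkIMain (defined later in this patch) and uses -usejavacp only so the embedded compiler can find a classpath.

    import scala.tools.nsc.Settings
    import org.apache.spark.repl.SparkIMain

    object ExprTyperSketch {
      def main(args: Array[String]): Unit = {
        val settings = new Settings()
        settings.usejavacp.value = true
        val intp = new SparkIMain(settings)
        intp.initializeSynchronous()              // the compiler must be initialized first
        println(intp.typeOfExpression("1 to 10")) // silent by default; NoType on failure
        intp.close()
      }
    }
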
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala
new file mode 100644
index 0000000000000..955be17a73b85
--- /dev/null
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package scala.tools.nsc
+
+import org.apache.spark.annotation.DeveloperApi
+
+// NOTE: Forced to be public (and in scala.tools.nsc package) to access the
+// settings "explicitParentLoader" method
+
+/**
+ * Provides exposure for the explicitParentLoader method on settings instances.
+ */
+@DeveloperApi
+object SparkHelper {
+ /**
+ * Retrieves the explicit parent loader for the provided settings.
+ *
+ * @param settings The settings whose explicit parent loader to retrieve
+ *
+ * @return The Optional classloader representing the explicit parent loader
+ */
+ @DeveloperApi
+ def explicitParentLoader(settings: Settings) = settings.explicitParentLoader
+}
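
Because explicitParentLoader is not reachable from outside the scala.tools.nsc package, this helper simply re-exports it; the interpreter defined later in this patch consumes it with a getOrElse fallback. A small sketch of that call pattern, with an illustrative fallback loader:

    import scala.tools.nsc.{Settings, SparkHelper}

    object ParentLoaderSketch {
      // Mirrors the pattern used by SparkILoopInterpreter further down in this patch.
      def parentLoaderFor(settings: Settings): ClassLoader =
        SparkHelper.explicitParentLoader(settings).getOrElse(getClass.getClassLoader)
    }
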
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
new file mode 100644
index 0000000000000..72c1a989999b4
--- /dev/null
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
@@ -0,0 +1,1119 @@
+// scalastyle:off
+
+/* NSC -- new Scala compiler
+ * Copyright 2005-2013 LAMP/EPFL
+ * @author Alexander Spoon
+ */
+
+package org.apache.spark.repl
+
+
+import java.net.URL
+
+import org.apache.spark.annotation.DeveloperApi
+
+import scala.reflect.io.AbstractFile
+import scala.tools.nsc._
+import scala.tools.nsc.backend.JavaPlatform
+import scala.tools.nsc.interpreter._
+
+import scala.tools.nsc.interpreter.{Results => IR}
+import Predef.{println => _, _}
+import java.io.{BufferedReader, FileReader}
+import java.net.URI
+import java.util.concurrent.locks.ReentrantLock
+import scala.sys.process.Process
+import scala.tools.nsc.interpreter.session._
+import scala.util.Properties.{jdkHome, javaVersion}
+import scala.tools.util.{Javap}
+import scala.annotation.tailrec
+import scala.collection.mutable.ListBuffer
+import scala.concurrent.ops
+import scala.tools.nsc.util._
+import scala.tools.nsc.interpreter._
+import scala.tools.nsc.io.{File, Directory}
+import scala.reflect.NameTransformer._
+import scala.tools.nsc.util.ScalaClassLoader._
+import scala.tools.util._
+import scala.language.{implicitConversions, existentials, postfixOps}
+import scala.reflect.{ClassTag, classTag}
+import scala.tools.reflect.StdRuntimeTags._
+
+import java.lang.{Class => jClass}
+import scala.reflect.api.{Mirror, TypeCreator, Universe => ApiUniverse}
+
+import org.apache.spark.Logging
+import org.apache.spark.SparkConf
+import org.apache.spark.SparkContext
+import org.apache.spark.util.Utils
+
+/** The Scala interactive shell. It provides a read-eval-print loop
+ * around the Interpreter class.
+ * After instantiation, clients should call the main() method.
+ *
+ * If no in0 is specified, then input will come from the console, and
+ * the class will attempt to provide input editing features such as
+ * input history.
+ *
+ * @author Moez A. Abdel-Gawad
+ * @author Lex Spoon
+ * @version 1.2
+ */
+@DeveloperApi
+class SparkILoop(
+ private val in0: Option[BufferedReader],
+ protected val out: JPrintWriter,
+ val master: Option[String]
+) extends AnyRef with LoopCommands with SparkILoopInit with Logging {
+ def this(in0: BufferedReader, out: JPrintWriter, master: String) = this(Some(in0), out, Some(master))
+ def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out, None)
+ def this() = this(None, new JPrintWriter(Console.out, true), None)
+
+ private var in: InteractiveReader = _ // the input stream from which commands come
+
+ // NOTE: Exposed in package for testing
+ private[repl] var settings: Settings = _
+
+ private[repl] var intp: SparkIMain = _
+
+ @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp
+ @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: SparkIMain): Unit = intp = i
+
+ /** Having inherited the difficult "var-ness" of the repl instance,
+ * I'm trying to work around it by moving operations into a class from
+ * which it will appear as a stable prefix.
+ */
+ private def onIntp[T](f: SparkIMain => T): T = f(intp)
+
+ class IMainOps[T <: SparkIMain](val intp: T) {
+ import intp._
+ import global._
+
+ def printAfterTyper(msg: => String) =
+ intp.reporter printMessage afterTyper(msg)
+
+ /** Strip NullaryMethodType artifacts. */
+ private def replInfo(sym: Symbol) = {
+ sym.info match {
+ case NullaryMethodType(restpe) if sym.isAccessor => restpe
+ case info => info
+ }
+ }
+ def echoTypeStructure(sym: Symbol) =
+ printAfterTyper("" + deconstruct.show(replInfo(sym)))
+
+ def echoTypeSignature(sym: Symbol, verbose: Boolean) = {
+ if (verbose) SparkILoop.this.echo("// Type signature")
+ printAfterTyper("" + replInfo(sym))
+
+ if (verbose) {
+ SparkILoop.this.echo("\n// Internal Type structure")
+ echoTypeStructure(sym)
+ }
+ }
+ }
+ implicit def stabilizeIMain(intp: SparkIMain) = new IMainOps[intp.type](intp)
+
+ /** TODO -
+ * -n normalize
+ * -l label with case class parameter names
+ * -c complete - leave nothing out
+ */
+ private def typeCommandInternal(expr: String, verbose: Boolean): Result = {
+ onIntp { intp =>
+ val sym = intp.symbolOfLine(expr)
+ if (sym.exists) intp.echoTypeSignature(sym, verbose)
+ else ""
+ }
+ }
+
+ // NOTE: Must be public for visibility
+ @DeveloperApi
+ var sparkContext: SparkContext = _
+
+ override def echoCommandMessage(msg: String) {
+ intp.reporter printMessage msg
+ }
+
+ // def isAsync = !settings.Yreplsync.value
+ private[repl] def isAsync = false
+ // lazy val power = new Power(intp, new StdReplVals(this))(tagOfStdReplVals, classTag[StdReplVals])
+ private def history = in.history
+
+ /** The context class loader at the time this object was created */
+ protected val originalClassLoader = Utils.getContextOrSparkClassLoader
+
+ // classpath entries added via :cp
+ private var addedClasspath: String = ""
+
+ /** A reverse list of commands to replay if the user requests a :replay */
+ private var replayCommandStack: List[String] = Nil
+
+ /** A list of commands to replay if the user requests a :replay */
+ private def replayCommands = replayCommandStack.reverse
+
+ /** Record a command for replay should the user request a :replay */
+ private def addReplay(cmd: String) = replayCommandStack ::= cmd
+
+ private def savingReplayStack[T](body: => T): T = {
+ val saved = replayCommandStack
+ try body
+ finally replayCommandStack = saved
+ }
+ private def savingReader[T](body: => T): T = {
+ val saved = in
+ try body
+ finally in = saved
+ }
+
+
+ private def sparkCleanUp(){
+ echo("Stopping spark context.")
+ intp.beQuietDuring {
+ command("sc.stop()")
+ }
+ }
+ /** Close the interpreter and set the var to null. */
+ private def closeInterpreter() {
+ if (intp ne null) {
+ sparkCleanUp()
+ intp.close()
+ intp = null
+ }
+ }
+
+ class SparkILoopInterpreter extends SparkIMain(settings, out) {
+ outer =>
+
+ override private[repl] lazy val formatting = new Formatting {
+ def prompt = SparkILoop.this.prompt
+ }
+ override protected def parentClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse(classOf[SparkILoop].getClassLoader)
+ }
+
+ /**
+ * Constructs a new interpreter.
+ */
+ protected def createInterpreter() {
+ require(settings != null)
+
+ if (addedClasspath != "") settings.classpath.append(addedClasspath)
+ val addedJars =
+ if (Utils.isWindows) {
+ // Strip any URI scheme prefix so we can add the correct path to the classpath
+ // e.g. file:/C:/my/path.jar -> C:/my/path.jar
+ SparkILoop.getAddedJars.map { jar => new URI(jar).getPath.stripPrefix("/") }
+ } else {
+ SparkILoop.getAddedJars
+ }
+ // work around for Scala bug
+ val totalClassPath = addedJars.foldLeft(
+ settings.classpath.value)((l, r) => ClassPath.join(l, r))
+ this.settings.classpath.value = totalClassPath
+
+ intp = new SparkILoopInterpreter
+ }
+
+ /** print a friendly help message */
+ private def helpCommand(line: String): Result = {
+ if (line == "") helpSummary()
+ else uniqueCommand(line) match {
+ case Some(lc) => echo("\n" + lc.longHelp)
+ case _ => ambiguousError(line)
+ }
+ }
+ private def helpSummary() = {
+ val usageWidth = commands map (_.usageMsg.length) max
+ val formatStr = "%-" + usageWidth + "s %s %s"
+
+ echo("All commands can be abbreviated, e.g. :he instead of :help.")
+ echo("Those marked with a * have more detailed help, e.g. :help imports.\n")
+
+ commands foreach { cmd =>
+ val star = if (cmd.hasLongHelp) "*" else " "
+ echo(formatStr.format(cmd.usageMsg, star, cmd.help))
+ }
+ }
+ private def ambiguousError(cmd: String): Result = {
+ matchingCommands(cmd) match {
+ case Nil => echo(cmd + ": no such command. Type :help for help.")
+ case xs => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?")
+ }
+ Result(true, None)
+ }
+ private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd)
+ private def uniqueCommand(cmd: String): Option[LoopCommand] = {
+ // this lets us add commands willy-nilly and only requires enough command to disambiguate
+ matchingCommands(cmd) match {
+ case List(x) => Some(x)
+ // exact match OK even if otherwise appears ambiguous
+ case xs => xs find (_.name == cmd)
+ }
+ }
+ private var fallbackMode = false
+
+ private def toggleFallbackMode() {
+ val old = fallbackMode
+ fallbackMode = !old
+ System.setProperty("spark.repl.fallback", fallbackMode.toString)
+ echo(s"""
+ |Switched ${if (old) "off" else "on"} fallback mode without restarting.
+ | If you have defined classes in the repl, it would
+ |be good to redefine them in case you plan to use them. If you still run
+ |into issues it would be good to restart the repl and turn on `:fallback`
+ |mode as first command.
+ """.stripMargin)
+ }
+
+ /** Show the history */
+ private lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") {
+ override def usage = "[num]"
+ def defaultLines = 20
+
+ def apply(line: String): Result = {
+ if (history eq NoHistory)
+ return "No history available."
+
+ val xs = words(line)
+ val current = history.index
+ val count = try xs.head.toInt catch { case _: Exception => defaultLines }
+ val lines = history.asStrings takeRight count
+ val offset = current - lines.size + 1
+
+ for ((line, index) <- lines.zipWithIndex)
+ echo("%3d %s".format(index + offset, line))
+ }
+ }
+
+ // When you know you are most likely breaking into the middle
+ // of a line being typed. This softens the blow.
+ private[repl] def echoAndRefresh(msg: String) = {
+ echo("\n" + msg)
+ in.redrawLine()
+ }
+ private[repl] def echo(msg: String) = {
+ out println msg
+ out.flush()
+ }
+ private def echoNoNL(msg: String) = {
+ out print msg
+ out.flush()
+ }
+
+ /** Search the history */
+ private def searchHistory(_cmdline: String) {
+ val cmdline = _cmdline.toLowerCase
+ val offset = history.index - history.size + 1
+
+ for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline)
+ echo("%d %s".format(index + offset, line))
+ }
+
+ private var currentPrompt = Properties.shellPromptString
+
+ /**
+ * Sets the prompt string used by the REPL.
+ *
+ * @param prompt The new prompt string
+ */
+ @DeveloperApi
+ def setPrompt(prompt: String) = currentPrompt = prompt
+
+ /**
+ * Represents the current prompt string used by the REPL.
+ *
+ * @return The current prompt string
+ */
+ @DeveloperApi
+ def prompt = currentPrompt
+
+ import LoopCommand.{ cmd, nullary }
+
+ /** Standard commands */
+ private lazy val standardCommands = List(
+ cmd("cp", "<path>", "add a jar or directory to the classpath", addClasspath),
+ cmd("help", "[command]", "print this summary or command-specific help", helpCommand),
+ historyCommand,
+ cmd("h?", "<string>", "search the history", searchHistory),
+ cmd("imports", "[name name ...]", "show import history, identifying sources of names", importsCommand),
+ cmd("implicits", "[-v]", "show the implicits in scope", implicitsCommand),
+ cmd("javap", "<path|class>", "disassemble a file or class name", javapCommand),
+ cmd("load", "<path>", "load and interpret a Scala file", loadCommand),
+ nullary("paste", "enter paste mode: all input up to ctrl-D compiled together", pasteCommand),
+// nullary("power", "enable power user mode", powerCmd),
+ nullary("quit", "exit the repl", () => Result(false, None)),
+ nullary("replay", "reset execution and replay all previous commands", replay),
+ nullary("reset", "reset the repl to its initial state, forgetting all session entries", resetCommand),
+ shCommand,
+ nullary("silent", "disable/enable automatic printing of results", verbosity),
+ nullary("fallback", """
+ |disable/enable advanced repl changes; these fix some issues but may introduce others.
+ |This mode will be removed once these fixes stabilize""".stripMargin, toggleFallbackMode),
+ cmd("type", "[-v] <expr>", "display the type of an expression without evaluating it", typeCommand),
+ nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand)
+ )
+
+ /** Power user commands */
+ private lazy val powerCommands: List[LoopCommand] = List(
+ // cmd("phase", "", "set the implicit phase for power commands", phaseCommand)
+ )
+
+ // private def dumpCommand(): Result = {
+ // echo("" + power)
+ // history.asStrings takeRight 30 foreach echo
+ // in.redrawLine()
+ // }
+ // private def valsCommand(): Result = power.valsDescription
+
+ private val typeTransforms = List(
+ "scala.collection.immutable." -> "immutable.",
+ "scala.collection.mutable." -> "mutable.",
+ "scala.collection.generic." -> "generic.",
+ "java.lang." -> "jl.",
+ "scala.runtime." -> "runtime."
+ )
+
+ private def importsCommand(line: String): Result = {
+ val tokens = words(line)
+ val handlers = intp.languageWildcardHandlers ++ intp.importHandlers
+ val isVerbose = tokens contains "-v"
+
+ handlers.filterNot(_.importedSymbols.isEmpty).zipWithIndex foreach {
+ case (handler, idx) =>
+ val (types, terms) = handler.importedSymbols partition (_.name.isTypeName)
+ val imps = handler.implicitSymbols
+ val found = tokens filter (handler importsSymbolNamed _)
+ val typeMsg = if (types.isEmpty) "" else types.size + " types"
+ val termMsg = if (terms.isEmpty) "" else terms.size + " terms"
+ val implicitMsg = if (imps.isEmpty) "" else imps.size + " are implicit"
+ val foundMsg = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "")
+ val statsMsg = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")")
+
+ intp.reporter.printMessage("%2d) %-30s %s%s".format(
+ idx + 1,
+ handler.importString,
+ statsMsg,
+ foundMsg
+ ))
+ }
+ }
+
+ private def implicitsCommand(line: String): Result = onIntp { intp =>
+ import intp._
+ import global._
+
+ def p(x: Any) = intp.reporter.printMessage("" + x)
+
+ // If an argument is given, only show a source with that
+ // in its name somewhere.
+ val args = line split "\\s+"
+ val filtered = intp.implicitSymbolsBySource filter {
+ case (source, syms) =>
+ (args contains "-v") || {
+ if (line == "") (source.fullName.toString != "scala.Predef")
+ else (args exists (source.name.toString contains _))
+ }
+ }
+
+ if (filtered.isEmpty)
+ return "No implicits have been imported other than those in Predef."
+
+ filtered foreach {
+ case (source, syms) =>
+ p("/* " + syms.size + " implicit members imported from " + source.fullName + " */")
+
+ // This groups the members by where the symbol is defined
+ val byOwner = syms groupBy (_.owner)
+ val sortedOwners = byOwner.toList sortBy { case (owner, _) => afterTyper(source.info.baseClasses indexOf owner) }
+
+ sortedOwners foreach {
+ case (owner, members) =>
+ // Within each owner, we cluster results based on the final result type
+ // if there are more than a couple, and sort each cluster based on name.
+ // This is really just trying to make the 100 or so implicits imported
+ // by default into something readable.
+ val memberGroups: List[List[Symbol]] = {
+ val groups = members groupBy (_.tpe.finalResultType) toList
+ val (big, small) = groups partition (_._2.size > 3)
+ val xss = (
+ (big sortBy (_._1.toString) map (_._2)) :+
+ (small flatMap (_._2))
+ )
+
+ xss map (xs => xs sortBy (_.name.toString))
+ }
+
+ val ownerMessage = if (owner == source) " defined in " else " inherited from "
+ p(" /* " + members.size + ownerMessage + owner.fullName + " */")
+
+ memberGroups foreach { group =>
+ group foreach (s => p(" " + intp.symbolDefString(s)))
+ p("")
+ }
+ }
+ p("")
+ }
+ }
+
+ private def findToolsJar() = {
+ val jdkPath = Directory(jdkHome)
+ val jar = jdkPath / "lib" / "tools.jar" toFile;
+
+ if (jar isFile)
+ Some(jar)
+ else if (jdkPath.isDirectory)
+ jdkPath.deepFiles find (_.name == "tools.jar")
+ else None
+ }
+ private def addToolsJarToLoader() = {
+ val cl = findToolsJar match {
+ case Some(tools) => ScalaClassLoader.fromURLs(Seq(tools.toURL), intp.classLoader)
+ case _ => intp.classLoader
+ }
+ if (Javap.isAvailable(cl)) {
+ logDebug(":javap available.")
+ cl
+ }
+ else {
+ logDebug(":javap unavailable: no tools.jar at " + jdkHome)
+ intp.classLoader
+ }
+ }
+
+ private def newJavap() = new JavapClass(addToolsJarToLoader(), new SparkIMain.ReplStrippingWriter(intp)) {
+ override def tryClass(path: String): Array[Byte] = {
+ val hd :: rest = path split '.' toList;
+ // If there are dots in the name, the first segment is the
+ // key to finding it.
+ if (rest.nonEmpty) {
+ intp optFlatName hd match {
+ case Some(flat) =>
+ val clazz = flat :: rest mkString NAME_JOIN_STRING
+ val bytes = super.tryClass(clazz)
+ if (bytes.nonEmpty) bytes
+ else super.tryClass(clazz + MODULE_SUFFIX_STRING)
+ case _ => super.tryClass(path)
+ }
+ }
+ else {
+ // Look for Foo first, then Foo$, but if Foo$ is given explicitly,
+ // we have to drop the $ to find object Foo, then tack it back onto
+ // the end of the flattened name.
+ def className = intp flatName path
+ def moduleName = (intp flatName path.stripSuffix(MODULE_SUFFIX_STRING)) + MODULE_SUFFIX_STRING
+
+ val bytes = super.tryClass(className)
+ if (bytes.nonEmpty) bytes
+ else super.tryClass(moduleName)
+ }
+ }
+ }
+ // private lazy val javap = substituteAndLog[Javap]("javap", NoJavap)(newJavap())
+ private lazy val javap =
+ try newJavap()
+ catch { case _: Exception => null }
+
+ // Still todo: modules.
+ private def typeCommand(line0: String): Result = {
+ line0.trim match {
+ case "" => ":type [-v] <expression>"
+ case s if s startsWith "-v " => typeCommandInternal(s stripPrefix "-v " trim, true)
+ case s => typeCommandInternal(s, false)
+ }
+ }
+
+ private def warningsCommand(): Result = {
+ if (intp.lastWarnings.isEmpty)
+ "Can't find any cached warnings."
+ else
+ intp.lastWarnings foreach { case (pos, msg) => intp.reporter.warning(pos, msg) }
+ }
+
+ private def javapCommand(line: String): Result = {
+ if (javap == null)
+ ":javap unavailable, no tools.jar at %s. Set JDK_HOME.".format(jdkHome)
+ else if (javaVersion startsWith "1.7")
+ ":javap not yet working with java 1.7"
+ else if (line == "")
+ ":javap [-lcsvp] [path1 path2 ...]"
+ else
+ javap(words(line)) foreach { res =>
+ if (res.isError) return "Failed: " + res.value
+ else res.show()
+ }
+ }
+
+ private def wrapCommand(line: String): Result = {
+ def failMsg = "Argument to :wrap must be the name of a method with signature [T](=> T): T"
+ onIntp { intp =>
+ import intp._
+ import global._
+
+ words(line) match {
+ case Nil =>
+ intp.executionWrapper match {
+ case "" => "No execution wrapper is set."
+ case s => "Current execution wrapper: " + s
+ }
+ case "clear" :: Nil =>
+ intp.executionWrapper match {
+ case "" => "No execution wrapper is set."
+ case s => intp.clearExecutionWrapper() ; "Cleared execution wrapper."
+ }
+ case wrapper :: Nil =>
+ intp.typeOfExpression(wrapper) match {
+ case PolyType(List(targ), MethodType(List(arg), restpe)) =>
+ intp setExecutionWrapper intp.pathToTerm(wrapper)
+ "Set wrapper to '" + wrapper + "'"
+ case tp =>
+ failMsg + "\nFound: <unknown>"
+ }
+ case _ => failMsg
+ }
+ }
+ }
+
+ private def pathToPhaseWrapper = intp.pathToTerm("$r") + ".phased.atCurrent"
+ // private def phaseCommand(name: String): Result = {
+ // val phased: Phased = power.phased
+ // import phased.NoPhaseName
+
+ // if (name == "clear") {
+ // phased.set(NoPhaseName)
+ // intp.clearExecutionWrapper()
+ // "Cleared active phase."
+ // }
+ // else if (name == "") phased.get match {
+ // case NoPhaseName => "Usage: :phase (e.g. typer, erasure.next, erasure+3)"
+ // case ph => "Active phase is '%s'. (To clear, :phase clear)".format(phased.get)
+ // }
+ // else {
+ // val what = phased.parse(name)
+ // if (what.isEmpty || !phased.set(what))
+ // "'" + name + "' does not appear to represent a valid phase."
+ // else {
+ // intp.setExecutionWrapper(pathToPhaseWrapper)
+ // val activeMessage =
+ // if (what.toString.length == name.length) "" + what
+ // else "%s (%s)".format(what, name)
+
+ // "Active phase is now: " + activeMessage
+ // }
+ // }
+ // }
+
+ /**
+ * Provides a list of available commands.
+ *
+ * @return The list of commands
+ */
+ @DeveloperApi
+ def commands: List[LoopCommand] = standardCommands /*++ (
+ if (isReplPower) powerCommands else Nil
+ )*/
+
+ private val replayQuestionMessage =
+ """|That entry seems to have slain the compiler. Shall I replay
+ |your session? I can re-run each line except the last one.
+ |[y/n]
+ """.trim.stripMargin
+
+ private def crashRecovery(ex: Throwable): Boolean = {
+ echo(ex.toString)
+ ex match {
+ case _: NoSuchMethodError | _: NoClassDefFoundError =>
+ echo("\nUnrecoverable error.")
+ throw ex
+ case _ =>
+ def fn(): Boolean =
+ try in.readYesOrNo(replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() })
+ catch { case _: RuntimeException => false }
+
+ if (fn()) replay()
+ else echo("\nAbandoning crashed session.")
+ }
+ true
+ }
+
+ /** The main read-eval-print loop for the repl. It calls
+ * command() for each line of input, and stops when
+ * command() returns false.
+ */
+ private def loop() {
+ def readOneLine() = {
+ out.flush()
+ in readLine prompt
+ }
+ // return false if repl should exit
+ def processLine(line: String): Boolean = {
+ if (isAsync) {
+ if (!awaitInitialized()) return false
+ runThunks()
+ }
+ if (line eq null) false // assume null means EOF
+ else command(line) match {
+ case Result(false, _) => false
+ case Result(_, Some(finalLine)) => addReplay(finalLine) ; true
+ case _ => true
+ }
+ }
+ def innerLoop() {
+ val shouldContinue = try {
+ processLine(readOneLine())
+ } catch {case t: Throwable => crashRecovery(t)}
+ if (shouldContinue)
+ innerLoop()
+ }
+ innerLoop()
+ }
+
+ /** interpret all lines from a specified file */
+ private def interpretAllFrom(file: File) {
+ savingReader {
+ savingReplayStack {
+ file applyReader { reader =>
+ in = SimpleReader(reader, out, false)
+ echo("Loading " + file + "...")
+ loop()
+ }
+ }
+ }
+ }
+
+ /** create a new interpreter and replay the given commands */
+ private def replay() {
+ reset()
+ if (replayCommandStack.isEmpty)
+ echo("Nothing to replay.")
+ else for (cmd <- replayCommands) {
+ echo("Replaying: " + cmd) // flush because maybe cmd will have its own output
+ command(cmd)
+ echo("")
+ }
+ }
+ private def resetCommand() {
+ echo("Resetting repl state.")
+ if (replayCommandStack.nonEmpty) {
+ echo("Forgetting this session history:\n")
+ replayCommands foreach echo
+ echo("")
+ replayCommandStack = Nil
+ }
+ if (intp.namedDefinedTerms.nonEmpty)
+ echo("Forgetting all expression results and named terms: " + intp.namedDefinedTerms.mkString(", "))
+ if (intp.definedTypes.nonEmpty)
+ echo("Forgetting defined types: " + intp.definedTypes.mkString(", "))
+
+ reset()
+ }
+
+ private def reset() {
+ intp.reset()
+ // unleashAndSetPhase()
+ }
+
+ /** fork a shell and run a command */
+ private lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") {
+ override def usage = "<command line>"
+ def apply(line: String): Result = line match {
+ case "" => showUsage()
+ case _ =>
+ val toRun = classOf[ProcessResult].getName + "(" + string2codeQuoted(line) + ")"
+ intp interpret toRun
+ ()
+ }
+ }
+
+ private def withFile(filename: String)(action: File => Unit) {
+ val f = File(filename)
+
+ if (f.exists) action(f)
+ else echo("That file does not exist")
+ }
+
+ private def loadCommand(arg: String) = {
+ var shouldReplay: Option[String] = None
+ withFile(arg)(f => {
+ interpretAllFrom(f)
+ shouldReplay = Some(":load " + arg)
+ })
+ Result(true, shouldReplay)
+ }
+
+ private def addAllClasspath(args: Seq[String]): Unit = {
+ var added = false
+ var totalClasspath = ""
+ for (arg <- args) {
+ val f = File(arg).normalize
+ if (f.exists) {
+ added = true
+ addedClasspath = ClassPath.join(addedClasspath, f.path)
+ totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath)
+ intp.addUrlsToClassPath(f.toURI.toURL)
+ sparkContext.addJar(f.toURI.toURL.getPath)
+ }
+ }
+ }
+
+ private def addClasspath(arg: String): Unit = {
+ val f = File(arg).normalize
+ if (f.exists) {
+ addedClasspath = ClassPath.join(addedClasspath, f.path)
+ intp.addUrlsToClassPath(f.toURI.toURL)
+ sparkContext.addJar(f.toURI.toURL.getPath)
+ echo("Added '%s'. Your new classpath is:\n\"%s\"".format(f.path, intp.global.classPath.asClasspathString))
+ }
+ else echo("The path '" + f + "' doesn't seem to exist.")
+ }
+
+
+ private def powerCmd(): Result = {
+ if (isReplPower) "Already in power mode."
+ else enablePowerMode(false)
+ }
+
+ private[repl] def enablePowerMode(isDuringInit: Boolean) = {
+ // replProps.power setValue true
+ // unleashAndSetPhase()
+ // asyncEcho(isDuringInit, power.banner)
+ }
+ // private def unleashAndSetPhase() {
+// if (isReplPower) {
+// // power.unleash()
+// // Set the phase to "typer"
+// intp beSilentDuring phaseCommand("typer")
+// }
+// }
+
+ private def asyncEcho(async: Boolean, msg: => String) {
+ if (async) asyncMessage(msg)
+ else echo(msg)
+ }
+
+ private def verbosity() = {
+ // val old = intp.printResults
+ // intp.printResults = !old
+ // echo("Switched " + (if (old) "off" else "on") + " result printing.")
+ }
+
+ /** Run one command submitted by the user. Two values are returned:
+ * (1) whether to keep running, (2) the line to record for replay,
+ * if any. */
+ private[repl] def command(line: String): Result = {
+ if (line startsWith ":") {
+ val cmd = line.tail takeWhile (x => !x.isWhitespace)
+ uniqueCommand(cmd) match {
+ case Some(lc) => lc(line.tail stripPrefix cmd dropWhile (_.isWhitespace))
+ case _ => ambiguousError(cmd)
+ }
+ }
+ else if (intp.global == null) Result(false, None) // Notice failure to create compiler
+ else Result(true, interpretStartingWith(line))
+ }
+
+ private def readWhile(cond: String => Boolean) = {
+ Iterator continually in.readLine("") takeWhile (x => x != null && cond(x))
+ }
+
+ private def pasteCommand(): Result = {
+ echo("// Entering paste mode (ctrl-D to finish)\n")
+ val code = readWhile(_ => true) mkString "\n"
+ echo("\n// Exiting paste mode, now interpreting.\n")
+ intp interpret code
+ ()
+ }
+
+ private object paste extends Pasted {
+ val ContinueString = " | "
+ val PromptString = "scala> "
+
+ def interpret(line: String): Unit = {
+ echo(line.trim)
+ intp interpret line
+ echo("")
+ }
+
+ def transcript(start: String) = {
+ echo("\n// Detected repl transcript paste: ctrl-D to finish.\n")
+ apply(Iterator(start) ++ readWhile(_.trim != PromptString.trim))
+ }
+ }
+ import paste.{ ContinueString, PromptString }
+
+ /** Interpret expressions starting with the first line.
+ * Read lines until a complete compilation unit is available
+ * or until a syntax error has been seen. If a full unit is
+ * read, go ahead and interpret it. Return the full string
+ * to be recorded for replay, if any.
+ */
+ private def interpretStartingWith(code: String): Option[String] = {
+ // signal to the completion handler that non-completion input has been received
+ in.completion.resetVerbosity()
+
+ def reallyInterpret = {
+ val reallyResult = intp.interpret(code)
+ (reallyResult, reallyResult match {
+ case IR.Error => None
+ case IR.Success => Some(code)
+ case IR.Incomplete =>
+ if (in.interactive && code.endsWith("\n\n")) {
+ echo("You typed two blank lines. Starting a new command.")
+ None
+ }
+ else in.readLine(ContinueString) match {
+ case null =>
+ // we know compilation is going to fail since we're at EOF and the
+ // parser thinks the input is still incomplete, but since this is
+ // a file being read non-interactively we want to fail. So we send
+ // it straight to the compiler for the nice error message.
+ intp.compileString(code)
+ None
+
+ case line => interpretStartingWith(code + "\n" + line)
+ }
+ })
+ }
+
+ /** Here we place ourselves between the user and the interpreter and examine
+ * the input they are ostensibly submitting. We intervene in several cases:
+ *
+ * 1) If the line starts with "scala> " it is assumed to be an interpreter paste.
+ * 2) If the line starts with "." (but not ".." or "./") it is treated as an invocation
+ * on the previous result.
+ * 3) If the Completion object's execute returns Some(_), we inject that value
+ * and avoid the interpreter, as it's likely not valid scala code.
+ */
+ if (code == "") None
+ else if (!paste.running && code.trim.startsWith(PromptString)) {
+ paste.transcript(code)
+ None
+ }
+ else if (Completion.looksLikeInvocation(code) && intp.mostRecentVar != "") {
+ interpretStartingWith(intp.mostRecentVar + code)
+ }
+ else if (code.trim startsWith "//") {
+ // line comment, do nothing
+ None
+ }
+ else
+ reallyInterpret._2
+ }
+
+ // runs :load `file` on any files passed via -i
+ private def loadFiles(settings: Settings) = settings match {
+ case settings: SparkRunnerSettings =>
+ for (filename <- settings.loadfiles.value) {
+ val cmd = ":load " + filename
+ command(cmd)
+ addReplay(cmd)
+ echo("")
+ }
+ case _ =>
+ }
+
+ /** Tries to create a JLineReader, falling back to SimpleReader:
+ * unless settings or properties are such that it should start
+ * with SimpleReader.
+ */
+ private def chooseReader(settings: Settings): InteractiveReader = {
+ if (settings.Xnojline.value || Properties.isEmacsShell)
+ SimpleReader()
+ else try new SparkJLineReader(
+ if (settings.noCompletion.value) NoCompletion
+ else new SparkJLineCompletion(intp)
+ )
+ catch {
+ case ex @ (_: Exception | _: NoClassDefFoundError) =>
+ echo("Failed to create SparkJLineReader: " + ex + "\nFalling back to SimpleReader.")
+ SimpleReader()
+ }
+ }
+
+ private val u: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe
+ private val m = u.runtimeMirror(Utils.getSparkClassLoader)
+ private def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] =
+ u.TypeTag[T](
+ m,
+ new TypeCreator {
+ def apply[U <: ApiUniverse with Singleton](m: Mirror[U]): U # Type =
+ m.staticClass(classTag[T].runtimeClass.getName).toTypeConstructor.asInstanceOf[U # Type]
+ })
+
+ private def process(settings: Settings): Boolean = savingContextLoader {
+ if (getMaster() == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true")
+
+ this.settings = settings
+ createInterpreter()
+
+ // sets in to some kind of reader depending on environmental cues
+ in = in0 match {
+ case Some(reader) => SimpleReader(reader, out, true)
+ case None =>
+ // some post-initialization
+ chooseReader(settings) match {
+ case x: SparkJLineReader => addThunk(x.consoleReader.postInit) ; x
+ case x => x
+ }
+ }
+ lazy val tagOfSparkIMain = tagOfStaticClass[org.apache.spark.repl.SparkIMain]
+ // Bind intp somewhere out of the regular namespace where
+ // we can get at it in generated code.
+ addThunk(intp.quietBind(NamedParam[SparkIMain]("$intp", intp)(tagOfSparkIMain, classTag[SparkIMain])))
+ addThunk({
+ import scala.tools.nsc.io._
+ import Properties.userHome
+ import scala.compat.Platform.EOL
+ val autorun = replProps.replAutorunCode.option flatMap (f => io.File(f).safeSlurp())
+ if (autorun.isDefined) intp.quietRun(autorun.get)
+ })
+
+ addThunk(printWelcome())
+ addThunk(initializeSpark())
+
+ // it is broken on startup; go ahead and exit
+ if (intp.reporter.hasErrors)
+ return false
+
+ // This is about the illusion of snappiness. We call initialize()
+ // which spins off a separate thread, then print the prompt and try
+ // our best to look ready. The interlocking lazy vals tend to
+ // inter-deadlock, so we break the cycle with a single asynchronous
+ // message to an actor.
+ if (isAsync) {
+ intp initialize initializedCallback()
+ createAsyncListener() // listens for signal to run postInitialization
+ }
+ else {
+ intp.initializeSynchronous()
+ postInitialization()
+ }
+ // printWelcome()
+
+ loadFiles(settings)
+
+ try loop()
+ catch AbstractOrMissingHandler()
+ finally closeInterpreter()
+
+ true
+ }
+
+ // NOTE: Must be public for visibility
+ @DeveloperApi
+ def createSparkContext(): SparkContext = {
+ val execUri = System.getenv("SPARK_EXECUTOR_URI")
+ val jars = SparkILoop.getAddedJars
+ val conf = new SparkConf()
+ .setMaster(getMaster())
+ .setAppName("Spark shell")
+ .setJars(jars)
+ .set("spark.repl.class.uri", intp.classServerUri)
+ if (execUri != null) {
+ conf.set("spark.executor.uri", execUri)
+ }
+ sparkContext = new SparkContext(conf)
+ logInfo("Created spark context..")
+ sparkContext
+ }
+
+ private def getMaster(): String = {
+ val master = this.master match {
+ case Some(m) => m
+ case None =>
+ val envMaster = sys.env.get("MASTER")
+ val propMaster = sys.props.get("spark.master")
+ propMaster.orElse(envMaster).getOrElse("local[*]")
+ }
+ master
+ }
+
+ /** process command-line arguments and do as they request */
+ def process(args: Array[String]): Boolean = {
+ val command = new SparkCommandLine(args.toList, msg => echo(msg))
+ def neededHelp(): String =
+ (if (command.settings.help.value) command.usageMsg + "\n" else "") +
+ (if (command.settings.Xhelp.value) command.xusageMsg + "\n" else "")
+
+ // if they asked for no help and command is valid, we call the real main
+ neededHelp() match {
+ case "" => command.ok && process(command.settings)
+ case help => echoNoNL(help) ; true
+ }
+ }
+
+ @deprecated("Use `process` instead", "2.9.0")
+ private def main(settings: Settings): Unit = process(settings)
+}
+
+object SparkILoop {
+ implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp
+ private def echo(msg: String) = Console println msg
+
+ def getAddedJars: Array[String] = {
+ val envJars = sys.env.get("ADD_JARS")
+ val propJars = sys.props.get("spark.jars").flatMap { p =>
+ if (p == "") None else Some(p)
+ }
+ val jars = propJars.orElse(envJars).getOrElse("")
+ Utils.resolveURIs(jars).split(",").filter(_.nonEmpty)
+ }
+
+ // Designed primarily for use by test code: takes a String with a
+ // bunch of code, and prints out a transcript of what it would look
+ // like if you'd just typed it into the repl.
+ private[repl] def runForTranscript(code: String, settings: Settings): String = {
+ import java.io.{ BufferedReader, StringReader, OutputStreamWriter }
+
+ stringFromStream { ostream =>
+ Console.withOut(ostream) {
+ val output = new JPrintWriter(new OutputStreamWriter(ostream), true) {
+ override def write(str: String) = {
+ // completely skip continuation lines
+ if (str forall (ch => ch.isWhitespace || ch == '|')) ()
+ // print a newline on empty scala prompts
+ else if ((str contains '\n') && (str.trim == "scala> ")) super.write("\n")
+ else super.write(str)
+ }
+ }
+ val input = new BufferedReader(new StringReader(code)) {
+ override def readLine(): String = {
+ val s = super.readLine()
+ // helping out by printing the line being interpreted.
+ if (s != null)
+ output.println(s)
+ s
+ }
+ }
+ val repl = new SparkILoop(input, output)
+
+ if (settings.classpath.isDefault)
+ settings.classpath.value = sys.props("java.class.path")
+
+ getAddedJars.foreach(settings.classpath.append(_))
+
+ repl process settings
+ }
+ }
+ }
+
+ /** Creates an interpreter loop with default settings and feeds
+ * the given code to it as input.
+ */
+ private[repl] def run(code: String, sets: Settings = new Settings): String = {
+ import java.io.{ BufferedReader, StringReader, OutputStreamWriter }
+
+ stringFromStream { ostream =>
+ Console.withOut(ostream) {
+ val input = new BufferedReader(new StringReader(code))
+ val output = new JPrintWriter(new OutputStreamWriter(ostream), true)
+ val repl = new ILoop(input, output)
+
+ if (sets.classpath.isDefault)
+ sets.classpath.value = sys.props("java.class.path")
+
+ repl process sets
+ }
+ }
+ }
+ private[repl] def run(lines: List[String]): String = run(lines map (_ + "\n") mkString)
+}
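
SparkILoop's public entry point is process(args), which parses the arguments through SparkCommandLine and then runs the loop; jars named in spark.jars or ADD_JARS are folded into the classpath by getAddedJars. A minimal sketch of driving it programmatically follows; a real session normally starts via bin/spark-shell, and -usejavacp is just an ordinary scalac option used for illustration.

    import org.apache.spark.repl.SparkILoop

    object ReplLauncherSketch {
      def main(args: Array[String]): Unit = {
        // spark.jars / ADD_JARS are picked up by SparkILoop.getAddedJars inside process().
        val ok = new SparkILoop().process(Array("-usejavacp"))
        sys.exit(if (ok) 0 else 1)
      }
    }
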
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
similarity index 95%
rename from repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
index 7667a9c11979e..99bd777c04fdb 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
@@ -19,7 +19,7 @@ import org.apache.spark.SPARK_VERSION
/**
* Machinery for the asynchronous initialization of the repl.
*/
-trait SparkILoopInit {
+private[repl] trait SparkILoopInit {
self: SparkILoop =>
/** Print a welcome message */
@@ -121,11 +121,14 @@ trait SparkILoopInit {
def initializeSpark() {
intp.beQuietDuring {
command("""
- @transient val sc = org.apache.spark.repl.Main.interp.createSparkContext();
+ @transient val sc = {
+ val _sc = org.apache.spark.repl.Main.interp.createSparkContext()
+ println("Spark context available as sc.")
+ _sc
+ }
""")
command("import org.apache.spark.SparkContext._")
}
- echo("Spark context available as sc.")
}
// code to be executed only after the interpreter is initialized
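
The change above moves the "Spark context available as sc." message into the interpreted block itself, so it is printed at the moment the context is created inside the quiet command() call instead of being echoed afterwards. For reference, a tiny sketch of the beQuietDuring-plus-interpret pattern this relies on, written against a hypothetical, already-built interpreter; it is not part of the patch.

    import org.apache.spark.repl.SparkIMain

    object QuietInitSketch {
      def runSetup(intp: SparkIMain): Unit = intp.beQuietDuring {
        // Result echoes ("res0: ...") are suppressed here, but println from inside
        // the interpreted snippet still reaches the console, as with the sc banner.
        intp.interpret("""val ready = { println("setup finished"); true }""")
      }
    }
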
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
new file mode 100644
index 0000000000000..35fb625645022
--- /dev/null
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
@@ -0,0 +1,1817 @@
+// scalastyle:off
+
+/* NSC -- new Scala compiler
+ * Copyright 2005-2013 LAMP/EPFL
+ * @author Martin Odersky
+ */
+
+package org.apache.spark.repl
+
+import java.io.File
+
+import scala.tools.nsc._
+import scala.tools.nsc.backend.JavaPlatform
+import scala.tools.nsc.interpreter._
+
+import Predef.{ println => _, _ }
+import scala.tools.nsc.util.{MergedClassPath, stringFromWriter, ScalaClassLoader, stackTraceString}
+import scala.reflect.internal.util._
+import java.net.URL
+import scala.sys.BooleanProp
+import io.{AbstractFile, PlainFile, VirtualDirectory}
+
+import reporters._
+import symtab.Flags
+import scala.reflect.internal.Names
+import scala.tools.util.PathResolver
+import ScalaClassLoader.URLClassLoader
+import scala.tools.nsc.util.Exceptional.unwrap
+import scala.collection.{ mutable, immutable }
+import scala.util.control.Exception.{ ultimately }
+import SparkIMain._
+import java.util.concurrent.Future
+import typechecker.Analyzer
+import scala.language.implicitConversions
+import scala.reflect.runtime.{ universe => ru }
+import scala.reflect.{ ClassTag, classTag }
+import scala.tools.reflect.StdRuntimeTags._
+import scala.util.control.ControlThrowable
+
+import org.apache.spark.{Logging, HttpServer, SecurityManager, SparkConf}
+import org.apache.spark.util.Utils
+import org.apache.spark.annotation.DeveloperApi
+
+// /** directory to save .class files to */
+// private class ReplVirtualDirectory(out: JPrintWriter) extends VirtualDirectory("((memory))", None) {
+// private def pp(root: AbstractFile, indentLevel: Int) {
+// val spaces = " " * indentLevel
+// out.println(spaces + root.name)
+// if (root.isDirectory)
+// root.toList sortBy (_.name) foreach (x => pp(x, indentLevel + 1))
+// }
+// // print the contents hierarchically
+// def show() = pp(this, 0)
+// }
+
+ /** An interpreter for Scala code.
+ *
+ * The main public entry points are compile(), interpret(), and bind().
+ * The compile() method loads a complete Scala file. The interpret() method
+ * executes one line of Scala code at the request of the user. The bind()
+ * method binds an object to a variable that can then be used by later
+ * interpreted code.
+ *
+ * The overall approach is based on compiling the requested code and then
+ * using a Java classloader and Java reflection to run the code
+ * and access its results.
+ *
+ * In more detail, a single compiler instance is used
+ * to accumulate all successfully compiled or interpreted Scala code. To
+ * "interpret" a line of code, the compiler generates a fresh object that
+ * includes the line of code and which has public member(s) to export
+ * all variables defined by that code. To extract the result of an
+ * interpreted line to show the user, a second "result object" is created
+ * which imports the variables exported by the above object and then
+ * exports members called "$eval" and "$print". To accommodate user expressions
+ * that read from variables or methods defined in previous statements, "import"
+ * statements are used.
+ *
+ * This interpreter shares the strengths and weaknesses of using the
+ * full compiler-to-Java. The main strength is that interpreted code
+ * behaves exactly as does compiled code, including running at full speed.
+ * The main weakness is that redefining classes and methods is not handled
+ * properly, because rebinding at the Java level is technically difficult.
+ *
+ * @author Moez A. Abdel-Gawad
+ * @author Lex Spoon
+ */
+ @DeveloperApi
+ class SparkIMain(
+ initialSettings: Settings,
+ val out: JPrintWriter,
+ propagateExceptions: Boolean = false)
+ extends SparkImports with Logging { imain =>
+
+ private val conf = new SparkConf()
+
+ private val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1")
+ /** Local directory to save .class files to */
+ private lazy val outputDir = {
+ val tmp = System.getProperty("java.io.tmpdir")
+ val rootDir = conf.get("spark.repl.classdir", tmp)
+ Utils.createTempDir(rootDir)
+ }
+ if (SPARK_DEBUG_REPL) {
+ echo("Output directory: " + outputDir)
+ }
+
+ /**
+ * Returns the path to the output directory containing all generated
+ * class files that will be served by the REPL class server.
+ */
+ @DeveloperApi
+ lazy val getClassOutputDirectory = outputDir
+
+ private val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles
+ /** Jetty server that will serve our classes to worker nodes */
+ private val classServerPort = conf.getInt("spark.replClassServer.port", 0)
+ private val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf), classServerPort, "HTTP class server")
+ private var currentSettings: Settings = initialSettings
+ private var printResults = true // whether to print result lines
+ private var totalSilence = false // whether to print anything
+ private var _initializeComplete = false // compiler is initialized
+ private var _isInitialized: Future[Boolean] = null // set up initialization future
+ private var bindExceptions = true // whether to bind the lastException variable
+ private var _executionWrapper = "" // code to be wrapped around all lines
+
+
+ // Start the classServer and store its URI in a spark system property
+ // (which will be passed to executors so that they can connect to it)
+ classServer.start()
+ if (SPARK_DEBUG_REPL) {
+ echo("Class server started, URI = " + classServer.uri)
+ }
+
+ /**
+ * URI of the class server used to feed REPL compiled classes.
+ *
+ * @return The string representing the class server uri
+ */
+ @DeveloperApi
+ def classServerUri = classServer.uri
+
+ /** We're going to go to some trouble to initialize the compiler asynchronously.
+ * It's critical that nothing call into it until it's been initialized or we will
+ * run into unrecoverable issues, but the perceived repl startup time goes
+ * through the roof if we wait for it. So we initialize it with a future and
+ * use a lazy val to ensure that any attempt to use the compiler object waits
+ * on the future.
+ */
+ private var _classLoader: AbstractFileClassLoader = null // active classloader
+ private val _compiler: Global = newCompiler(settings, reporter) // our private compiler
+
+ private trait ExposeAddUrl extends URLClassLoader { def addNewUrl(url: URL) = this.addURL(url) }
+ private var _runtimeClassLoader: URLClassLoader with ExposeAddUrl = null // wrapper exposing addURL
+
+ private val nextReqId = {
+ var counter = 0
+ () => { counter += 1 ; counter }
+ }
+
+ private def compilerClasspath: Seq[URL] = (
+ if (isInitializeComplete) global.classPath.asURLs
+ else new PathResolver(settings).result.asURLs // the compiler's classpath
+ )
+ // NOTE: Exposed to repl package since accessed indirectly from SparkIMain
+ private[repl] def settings = currentSettings
+ private def mostRecentLine = prevRequestList match {
+ case Nil => ""
+ case req :: _ => req.originalLine
+ }
+ // Run the code body with the given boolean settings flipped to true.
+ private def withoutWarnings[T](body: => T): T = beQuietDuring {
+ val saved = settings.nowarn.value
+ if (!saved)
+ settings.nowarn.value = true
+
+ try body
+ finally if (!saved) settings.nowarn.value = false
+ }
+
+ /** construct an interpreter that reports to Console */
+ def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true))
+ def this() = this(new Settings())
+
+ private lazy val repllog: Logger = new Logger {
+ val out: JPrintWriter = imain.out
+ val isInfo: Boolean = BooleanProp keyExists "scala.repl.info"
+ val isDebug: Boolean = BooleanProp keyExists "scala.repl.debug"
+ val isTrace: Boolean = BooleanProp keyExists "scala.repl.trace"
+ }
+ private[repl] lazy val formatting: Formatting = new Formatting {
+ val prompt = Properties.shellPromptString
+ }
+
+ // NOTE: Exposed to repl package since used by SparkExprTyper and SparkILoop
+ private[repl] lazy val reporter: ConsoleReporter = new SparkIMain.ReplReporter(this)
+
+ /**
+ * Determines if errors were reported (typically during compilation).
+ *
+ * @note This is not for runtime errors
+ *
+ * @return True if had errors, otherwise false
+ */
+ @DeveloperApi
+ def isReportingErrors = reporter.hasErrors
+
+ import formatting._
+ import reporter.{ printMessage, withoutTruncating }
+
+ // This exists mostly because using the reporter too early leads to deadlock.
+ private def echo(msg: String) { Console println msg }
+ private def _initSources = List(new BatchSourceFile("<init>", "class $repl_$init { }"))
+ private def _initialize() = {
+ try {
+ // todo. if this crashes, REPL will hang
+ new _compiler.Run() compileSources _initSources
+ _initializeComplete = true
+ true
+ }
+ catch AbstractOrMissingHandler()
+ }
+ private def tquoted(s: String) = "\"\"\"" + s + "\"\"\""
+
+ // argument is a thunk to execute after init is done
+ // NOTE: Exposed to repl package since used by SparkILoop
+ private[repl] def initialize(postInitSignal: => Unit) {
+ synchronized {
+ if (_isInitialized == null) {
+ _isInitialized = io.spawn {
+ try _initialize()
+ finally postInitSignal
+ }
+ }
+ }
+ }
+
+ /**
+ * Initializes the underlying compiler/interpreter in a blocking fashion.
+ *
+ * @note Must be executed before using SparkIMain!
+ */
+ @DeveloperApi
+ def initializeSynchronous(): Unit = {
+ if (!isInitializeComplete) {
+ _initialize()
+ assert(global != null, global)
+ }
+ }
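+
+ // Usage sketch (illustrative, not part of this change; the `intp` name and Settings
+ // wiring are assumptions):
+ //
+ //   val intp = new SparkIMain(new Settings())
+ //   intp.initializeSynchronous()
+ //   intp.interpret("val x = 1 + 1")
+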
+ private def isInitializeComplete = _initializeComplete
+
+ // The public accessor below goes through the asynchronously initialized compiler.
+
+ /**
+ * The underlying compiler used to generate ASTs and execute code.
+ */
+ @DeveloperApi
+ lazy val global: Global = {
+ if (isInitializeComplete) _compiler
+ else {
+ // If init hasn't been called yet you're on your own.
+ if (_isInitialized == null) {
+ logWarning("Warning: compiler accessed before init set up. Assuming no postInit code.")
+ initialize(())
+ }
+ // Blocks until initialization is complete; false means catastrophic failure.
+ if (_isInitialized.get()) _compiler
+ else null
+ }
+ }
+ @deprecated("Use `global` for access to the compiler instance.", "2.9.0")
+ private lazy val compiler: global.type = global
+
+ import global._
+ import definitions.{ScalaPackage, JavaLangPackage, termMember, typeMember}
+ import rootMirror.{RootClass, getClassIfDefined, getModuleIfDefined, getRequiredModule, getRequiredClass}
+
+ private implicit class ReplTypeOps(tp: Type) {
+ def orElse(other: => Type): Type = if (tp ne NoType) tp else other
+ def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp)
+ }
+
+ // TODO: If we try to make naming a lazy val, we run into big time
+ // scalac unhappiness with what look like cycles. It has not been easy to
+ // reduce, but name resolution clearly takes different paths.
+ // NOTE: Exposed to repl package since used by SparkExprTyper
+ private[repl] object naming extends {
+ val global: imain.global.type = imain.global
+ } with Naming {
+ // make sure we don't overwrite their unwisely named res3 etc.
+ def freshUserTermName(): TermName = {
+ val name = newTermName(freshUserVarName())
+ if (definedNameMap contains name) freshUserTermName()
+ else name
+ }
+ def isUserTermName(name: Name) = isUserVarName("" + name)
+ def isInternalTermName(name: Name) = isInternalVarName("" + name)
+ }
+ import naming._
+
+ // NOTE: Exposed to repl package since used by SparkILoop
+ private[repl] object deconstruct extends {
+ val global: imain.global.type = imain.global
+ } with StructuredTypeStrings
+
+ // NOTE: Exposed to repl package since used by SparkImports
+ private[repl] lazy val memberHandlers = new {
+ val intp: imain.type = imain
+ } with SparkMemberHandlers
+ import memberHandlers._
+
+ /**
+ * Suppresses overwriting print results during the operation.
+ *
+ * @param body The block to execute
+ * @tparam T The return type of the block
+ *
+ * @return The result from executing the block
+ */
+ @DeveloperApi
+ def beQuietDuring[T](body: => T): T = {
+ val saved = printResults
+ printResults = false
+ try body
+ finally printResults = saved
+ }
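+
+ // For example (illustrative; `intp` denotes an already-initialized instance):
+ //
+ //   intp.beQuietDuring { intp.interpret("val hidden = 42") } // result is not echoed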
+
+ /**
+ * Completely masks all output during the operation (minus JVM standard
+ * out and error).
+ *
+ * @param operation The block to execute
+ * @tparam T The return type of the block
+ *
+ * @return The result from executing the block
+ */
+ @DeveloperApi
+ def beSilentDuring[T](operation: => T): T = {
+ val saved = totalSilence
+ totalSilence = true
+ try operation
+ finally totalSilence = saved
+ }
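+
+ // Illustrative counterpart to beQuietDuring, masking interpreter output entirely:
+ //
+ //   intp.beSilentDuring { intp.interpret("val tmp = 1") }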
+
+ // NOTE: Exposed to repl package since used by SparkILoop
+ private[repl] def quietRun[T](code: String) = beQuietDuring(interpret(code))
+
+ private def logAndDiscard[T](label: String, alt: => T): PartialFunction[Throwable, T] = {
+ case t: ControlThrowable => throw t
+ case t: Throwable =>
+ logDebug(label + ": " + unwrap(t))
+ logDebug(stackTraceString(unwrap(t)))
+ alt
+ }
+ /** takes AnyRef because it may be binding a Throwable or an Exceptional */
+
+ private def withLastExceptionLock[T](body: => T, alt: => T): T = {
+ assert(bindExceptions, "withLastExceptionLock called incorrectly.")
+ bindExceptions = false
+
+ try beQuietDuring(body)
+ catch logAndDiscard("withLastExceptionLock", alt)
+ finally bindExceptions = true
+ }
+
+ /**
+ * Contains the code (in string form) representing a wrapper around all
+ * code executed by this instance.
+ *
+ * @return The wrapper code as a string
+ */
+ @DeveloperApi
+ def executionWrapper = _executionWrapper
+
+ /**
+ * Sets the code to use as a wrapper around all code executed by this
+ * instance.
+ *
+ * @param code The wrapper code as a string
+ */
+ @DeveloperApi
+ def setExecutionWrapper(code: String) = _executionWrapper = code
+
+ /**
+ * Clears the code used as a wrapper around all code executed by
+ * this instance.
+ */
+ @DeveloperApi
+ def clearExecutionWrapper() = _executionWrapper = ""
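+
+ // Call-pattern sketch only (the wrapper string and the surrounding code are hypothetical;
+ // the wrapper is applied around code executed by this instance, as described above):
+ //
+ //   intp.setExecutionWrapper(myWrapperCode)
+ //   try runSomething() finally intp.clearExecutionWrapper()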
+
+ /** interpreter settings */
+ private lazy val isettings = new SparkISettings(this)
+
+ /**
+ * Instantiates a new compiler used by SparkIMain. Overridable to provide
+ * own instance of a compiler.
+ *
+ * @param settings The settings to provide the compiler
+ * @param reporter The reporter to use for compiler output
+ *
+ * @return The compiler as a Global
+ */
+ @DeveloperApi
+ protected def newCompiler(settings: Settings, reporter: Reporter): ReplGlobal = {
+ settings.outputDirs setSingleOutput virtualDirectory
+ settings.exposeEmptyPackage.value = true
+ new Global(settings, reporter) with ReplGlobal {
+ override def toString: String = "<global>"
+ }
+ }
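+
+ // Overriding sketch (hypothetical subclass; the constructor arguments shown are assumptions):
+ //
+ //   class VerboseIMain(settings: Settings, out: JPrintWriter) extends SparkIMain(settings, out) {
+ //     override protected def newCompiler(settings: Settings, reporter: Reporter): ReplGlobal = {
+ //       settings.verbose.value = true // tweak settings before constructing the Global
+ //       super.newCompiler(settings, reporter)
+ //     }
+ //   }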
+
+ /**
+ * Adds any specified jars to the compile and runtime classpaths.
+ *
+ * @note Currently only supports jars, not directories
+ * @param urls The list of items to add to the compile and runtime classpaths
+ */
+ @DeveloperApi
+ def addUrlsToClassPath(urls: URL*): Unit = {
+ new Run // Needed to force initialization of "something" to correctly load Scala classes from jars
+ urls.foreach(_runtimeClassLoader.addNewUrl) // Add jars/classes to runtime for execution
+ updateCompilerClassPath(urls: _*) // Add jars/classes to compile time for compiling
+ }
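+
+ // Illustrative call (the jar path is hypothetical):
+ //
+ //   intp.addUrlsToClassPath(new java.io.File("/tmp/extra-lib.jar").toURI.toURL)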
+
+ private def updateCompilerClassPath(urls: URL*): Unit = {
+ require(!global.forMSIL) // Only support JavaPlatform
+
+ val platform = global.platform.asInstanceOf[JavaPlatform]
+
+ val newClassPath = mergeUrlsIntoClassPath(platform, urls: _*)
+
+ // NOTE: Must use reflection until this is exposed/fixed upstream in Scala
+ val fieldSetter = platform.getClass.getMethods
+ .find(_.getName.endsWith("currentClassPath_$eq")).get
+ fieldSetter.invoke(platform, Some(newClassPath))
+
+ // Reload all jars specified into our compiler
+ global.invalidateClassPathEntries(urls.map(_.getPath): _*)
+ }
+
+ private def mergeUrlsIntoClassPath(platform: JavaPlatform, urls: URL*): MergedClassPath[AbstractFile] = {
+ // Collect our new jars/directories and add them to the existing set of classpaths
+ val allClassPaths = (
+ platform.classPath.asInstanceOf[MergedClassPath[AbstractFile]].entries ++
+ urls.map(url => {
+ platform.classPath.context.newClassPath(
+ if (url.getProtocol == "file") {
+ val f = new File(url.getPath)
+ if (f.isDirectory)
+ io.AbstractFile.getDirectory(f)
+ else
+ io.AbstractFile.getFile(f)
+ } else {
+ io.AbstractFile.getURL(url)
+ }
+ )
+ })
+ ).distinct
+
+ // Combine all of our classpaths (old and new) into one merged classpath
+ new MergedClassPath(allClassPaths, platform.classPath.context)
+ }
+
+ /**
+ * Represents the parent classloader used by this instance. Can be
+ * overridden to provide alternative classloader.
+ *
+ * @return The classloader used as the parent loader of this instance
+ */
+ @DeveloperApi
+ protected def parentClassLoader: ClassLoader =
+ SparkHelper.explicitParentLoader(settings).getOrElse( this.getClass.getClassLoader() )
+
+ /* A single class loader is used for all commands interpreted by this Interpreter.
+ It would also be possible to create a new class loader for each command
+ to interpret. The advantages of the current approach are:
+
+ - Expressions are only evaluated one time. This is especially
+ significant for I/O, e.g. "val x = Console.readLine"
+
+ The main disadvantage is:
+
+ - Objects, classes, and methods cannot be rebound. Instead, definitions
+ shadow the old ones, and old code objects refer to the old
+ definitions.
+ */
+ private def resetClassLoader() = {
+ logDebug("Setting new classloader: was " + _classLoader)
+ _classLoader = null
+ ensureClassLoader()
+ }
+ private final def ensureClassLoader() {
+ if (_classLoader == null)
+ _classLoader = makeClassLoader()
+ }
+
+ // NOTE: Exposed to repl package since used by SparkILoop
+ private[repl] def classLoader: AbstractFileClassLoader = {
+ ensureClassLoader()
+ _classLoader
+ }
+ private class TranslatingClassLoader(parent: ClassLoader) extends AbstractFileClassLoader(virtualDirectory, parent) {
+ /** Overridden here to try translating a simple name to the generated
+ * class name if the original attempt fails. This method is used by
+ * getResourceAsStream as well as findClass.
+ */
+ override protected def findAbstractFile(name: String): AbstractFile = {
+ super.findAbstractFile(name) match {
+ // deadlocks on startup if we try to translate names too early
+ case null if isInitializeComplete =>
+ generatedName(name) map (x => super.findAbstractFile(x)) orNull
+ case file =>
+ file
+ }
+ }
+ }
+ private def makeClassLoader(): AbstractFileClassLoader =
+ new TranslatingClassLoader(parentClassLoader match {
+ case null => ScalaClassLoader fromURLs compilerClasspath
+ case p =>
+ _runtimeClassLoader = new URLClassLoader(compilerClasspath, p) with ExposeAddUrl
+ _runtimeClassLoader
+ })
+
+ private def getInterpreterClassLoader() = classLoader
+
+ // Set the current Java "context" class loader to this interpreter's class loader
+ // NOTE: Exposed to repl package since used by SparkILoopInit
+ private[repl] def setContextClassLoader() = classLoader.setAsContext()
+
+ /**
+ * Returns the real name of a class based on its repl-defined name.
+ *
+ * ==Example==
+ * Given a simple repl-defined name, returns the real name of
+ * the class representing it, e.g. for "Bippy" it may return
+ * {{{
+ * $line19.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Bippy
+ * }}}
+ *
+ * @param simpleName The repl-defined name whose real name to retrieve
+ *
+ * @return Some real name if the simple name exists, else None
+ */
+ @DeveloperApi
+ def generatedName(simpleName: String): Option[String] = {
+ if (simpleName endsWith nme.MODULE_SUFFIX_STRING) optFlatName(simpleName.init) map (_ + nme.MODULE_SUFFIX_STRING)
+ else optFlatName(simpleName)
+ }
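+
+ // For example (illustrative; assumes a class Bippy was defined earlier in the session):
+ //
+ //   intp.generatedName("Bippy") // e.g. Some("$line19.$read$$iw$$iw$Bippy"); exact shape varies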
+
+ // NOTE: Exposed to repl package since used by SparkILoop
+ private[repl] def flatName(id: String) = optFlatName(id) getOrElse id
+ // NOTE: Exposed to repl package since used by SparkILoop
+ private[repl] def optFlatName(id: String) = requestForIdent(id) map (_ fullFlatName id)
+
+ /**
+ * Retrieves all simple names contained in the current instance.
+ *
+ * @return A list of sorted names
+ */
+ @DeveloperApi
+ def allDefinedNames = definedNameMap.keys.toList.sorted
+
+ private def pathToType(id: String): String = pathToName(newTypeName(id))
+ // NOTE: Exposed to repl package since used by SparkILoop
+ private[repl] def pathToTerm(id: String): String = pathToName(newTermName(id))
+
+ /**
+ * Retrieves the full code path to access the specified simple name
+ * content.
+ *
+ * @param name The simple name of the target whose path to determine
+ *
+ * @return The full path used to access the specified target (name)
+ */
+ @DeveloperApi
+ def pathToName(name: Name): String = {
+ if (definedNameMap contains name)
+ definedNameMap(name) fullPath name
+ else name.toString
+ }
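+
+ // For example (illustrative; assumes "x" was defined earlier in the session):
+ //
+ //   intp.pathToName(intp.global.newTermName("x")) // e.g. "$line3.$read.$iw.$iw.x"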
+
+ /** Most recent tree handled which wasn't wholly synthetic. */
+ private def mostRecentlyHandledTree: Option[Tree] = {
+ prevRequests.reverse foreach { req =>
+ req.handlers.reverse foreach {
+ case x: MemberDefHandler if x.definesValue && !isInternalTermName(x.name) => return Some(x.member)
+ case _ => ()
+ }
+ }
+ None
+ }
+
+ /** Stubs for work in progress. */
+ private def handleTypeRedefinition(name: TypeName, old: Request, req: Request) = {
+ for (t1 <- old.simpleNameOfType(name) ; t2 <- req.simpleNameOfType(name)) {
+ logDebug("Redefining type '%s'\n %s -> %s".format(name, t1, t2))
+ }
+ }
+
+ private def handleTermRedefinition(name: TermName, old: Request, req: Request) = {
+ for (t1 <- old.compilerTypeOf get name ; t2 <- req.compilerTypeOf get name) {
+ // Printing the types here has a tendency to cause assertion errors, like
+ // assertion failed: fatal: has owner value x, but a class owner is required
+ // so DBG is by-name now to keep it in the family. (It also traps the assertion error,
+ // but we don't want to unnecessarily risk hosing the compiler's internal state.)
+ logDebug("Redefining term '%s'\n %s -> %s".format(name, t1, t2))
+ }
+ }
+
+ private def recordRequest(req: Request) {
+ if (req == null || referencedNameMap == null)
+ return
+
+ prevRequests += req
+ req.referencedNames foreach (x => referencedNameMap(x) = req)
+
+ // warning about serially defining companions. It'd be easy
+ // enough to just redefine them together but that may not always
+ // be what people want so I'm waiting until I can do it better.
+ for {
+ name <- req.definedNames filterNot (x => req.definedNames contains x.companionName)
+ oldReq <- definedNameMap get name.companionName
+ newSym <- req.definedSymbols get name
+ oldSym <- oldReq.definedSymbols get name.companionName
+ if Seq(oldSym, newSym).permutations exists { case Seq(s1, s2) => s1.isClass && s2.isModule }
+ } {
+ afterTyper(replwarn(s"warning: previously defined $oldSym is not a companion to $newSym."))
+ replwarn("Companions must be defined together; you may wish to use :paste mode for this.")
+ }
+
+ // Updating the defined name map
+ req.definedNames foreach { name =>
+ if (definedNameMap contains name) {
+ if (name.isTypeName) handleTypeRedefinition(name.toTypeName, definedNameMap(name), req)
+ else handleTermRedefinition(name.toTermName, definedNameMap(name), req)
+ }
+ definedNameMap(name) = req
+ }
+ }
+
+ private def replwarn(msg: => String) {
+ if (!settings.nowarnings.value)
+ printMessage(msg)
+ }
+
+ private def isParseable(line: String): Boolean = {
+ beSilentDuring {
+ try parse(line) match {
+ case Some(xs) => xs.nonEmpty // parses as-is
+ case None => true // incomplete
+ }
+ catch { case x: Exception => // crashed the compiler
+ replwarn("Exception in isParseable(\"" + line + "\"): " + x)
+ false
+ }
+ }
+ }
+
+ private def compileSourcesKeepingRun(sources: SourceFile*) = {
+ val run = new Run()
+ reporter.reset()
+ run compileSources sources.toList
+ (!reporter.hasErrors, run)
+ }
+
+ /**
+ * Compiles specified source files.
+ *
+ * @param sources The sequence of source files to compile
+ *
+ * @return True if successful, otherwise false
+ */
+ @DeveloperApi
+ def compileSources(sources: SourceFile*): Boolean =
+ compileSourcesKeepingRun(sources: _*)._1
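+
+ // Illustrative call (the source name and snippet contents are hypothetical):
+ //
+ //   val ok = intp.compileSources(new BatchSourceFile("Probe.scala", "object Probe { val n = 1 }"))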
+
+ /**
+ * Compiles a string of code.
+ *
+ * @param code The string of code to compile
+ *
+ * @return True if successful, otherwise false
+ */
+ @DeveloperApi
+ def compileString(code: String): Boolean =
+ compileSources(new BatchSourceFile("