diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000..2b65f6fe3cc80 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +*.bat text eol=crlf +*.cmd text eol=crlf diff --git a/.rat-excludes b/.rat-excludes index b14ad53720f32..d8bee1f8e49c9 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -1,5 +1,6 @@ target .gitignore +.gitattributes .project .classpath .mima-excludes @@ -43,11 +44,13 @@ SparkImports.scala SparkJLineCompletion.scala SparkJLineReader.scala SparkMemberHandlers.scala +SparkReplReporter.scala sbt sbt-launch-lib.bash plugins.sbt work .*\.q +.*\.qv golden test.out/* .*iml diff --git a/LICENSE b/LICENSE index 0517dfb0ab53d..4f2f0e7a7006a 100644 --- a/LICENSE +++ b/LICENSE @@ -712,18 +712,6 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -======================================================================== -For colt: -======================================================================== - -Copyright (c) 1999 CERN - European Organization for Nuclear Research. -Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose is hereby granted without fee, provided that the above copyright notice appear in all copies and that both that copyright notice and this permission notice appear in supporting documentation. CERN makes no representations about the suitability of this software for any purpose. It is provided "as is" without expressed or implied warranty. - -Packages hep.aida.* - -Written by Pavel Binko, Dino Ferrero Merlino, Wolfgang Hoschek, Tony Johnson, Andreas Pfeiffer, and others. Check the FreeHEP home page for more info. Permission to use and/or redistribute this work is granted under the terms of the LGPL License, with the exception that any usage related to military applications is expressly forbidden. The software and documentation made available under the terms of this license are provided with no warranty. - - ======================================================================== For SnapTree: ======================================================================== @@ -766,7 +754,7 @@ SUCH DAMAGE. ======================================================================== -For Timsort (core/src/main/java/org/apache/spark/util/collection/Sorter.java): +For Timsort (core/src/main/java/org/apache/spark/util/collection/TimSort.java): ======================================================================== Copyright (C) 2008 The Android Open Source Project @@ -783,6 +771,25 @@ See the License for the specific language governing permissions and limitations under the License. +======================================================================== +For LimitedInputStream + (network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java): +======================================================================== +Copyright (C) 2007 The Guava Authors + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. + + ======================================================================== BSD-style licenses ======================================================================== diff --git a/README.md b/README.md index 8dd8b70696aa2..8d57d50da96c9 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,8 @@ and Spark Streaming for stream processing. ## Online Documentation You can find the latest Spark documentation, including a programming -guide, on the [project web page](http://spark.apache.org/documentation.html). +guide, on the [project web page](http://spark.apache.org/documentation.html) +and [project wiki](https://cwiki.apache.org/confluence/display/SPARK). This README file only contains basic setup instructions. ## Building Spark @@ -25,7 +26,7 @@ To build Spark and its example programs, run: (You do not need to do this if you downloaded a pre-built package.) More detailed documentation is available from the project site, at -["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html). +["Building Spark with Maven"](http://spark.apache.org/docs/latest/building-with-maven.html). ## Interactive Scala Shell @@ -84,7 +85,7 @@ storage systems. Because the protocols have changed in different versions of Hadoop, you must build Spark against the same version that your cluster runs. Please refer to the build documentation at -["Specifying the Hadoop Version"](http://spark.apache.org/docs/latest/building-spark.html#specifying-the-hadoop-version) +["Specifying the Hadoop Version"](http://spark.apache.org/docs/latest/building-with-maven.html#specifying-the-hadoop-version) for detailed guidance on building for a particular distribution of Hadoop, including building for particular Hive and Hive Thriftserver distributions. See also ["Third Party Hadoop Distributions"](http://spark.apache.org/docs/latest/hadoop-third-party-distributions.html) diff --git a/assembly/pom.xml b/assembly/pom.xml index 31a01e4d8e1de..4e2b773e7d2f3 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml @@ -66,22 +66,22 @@ org.apache.spark - spark-repl_${scala.binary.version} + spark-streaming_${scala.binary.version} ${project.version} org.apache.spark - spark-streaming_${scala.binary.version} + spark-graphx_${scala.binary.version} ${project.version} org.apache.spark - spark-graphx_${scala.binary.version} + spark-sql_${scala.binary.version} ${project.version} org.apache.spark - spark-sql_${scala.binary.version} + spark-repl_${scala.binary.version} ${project.version} @@ -197,6 +197,11 @@ spark-hive_${scala.binary.version} ${project.version} + + + + hive-thriftserver + org.apache.spark spark-hive-thriftserver_${scala.binary.version} diff --git a/bagel/pom.xml b/bagel/pom.xml index 93db0d5efda5f..0327ffa402671 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd index 9b9e40321ea93..a4c099fb45b14 100644 --- a/bin/compute-classpath.cmd +++ b/bin/compute-classpath.cmd @@ -1,117 +1,117 @@ -@echo off - -rem -rem Licensed to the Apache Software Foundation (ASF) under one or more -rem contributor license agreements. See the NOTICE file distributed with -rem this work for additional information regarding copyright ownership. 
-rem The ASF licenses this file to You under the Apache License, Version 2.0 -rem (the "License"); you may not use this file except in compliance with -rem the License. You may obtain a copy of the License at -rem -rem http://www.apache.org/licenses/LICENSE-2.0 -rem -rem Unless required by applicable law or agreed to in writing, software -rem distributed under the License is distributed on an "AS IS" BASIS, -rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -rem See the License for the specific language governing permissions and -rem limitations under the License. -rem - -rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run" -rem script and the ExecutorRunner in standalone cluster mode. - -rem If we're called from spark-class2.cmd, it already set enabledelayedexpansion and setting -rem it here would stop us from affecting its copy of the CLASSPATH variable; otherwise we -rem need to set it here because we use !datanucleus_jars! below. -if "%DONT_PRINT_CLASSPATH%"=="1" goto skip_delayed_expansion -setlocal enabledelayedexpansion -:skip_delayed_expansion - -set SCALA_VERSION=2.10 - -rem Figure out where the Spark framework is installed -set FWDIR=%~dp0..\ - -rem Load environment variables from conf\spark-env.cmd, if it exists -if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" - -rem Build up classpath -set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH% - -if "x%SPARK_CONF_DIR%"!="x" ( - set CLASSPATH=%CLASSPATH%;%SPARK_CONF_DIR% -) else ( - set CLASSPATH=%CLASSPATH%;%FWDIR%conf -) - -if exist "%FWDIR%RELEASE" ( - for %%d in ("%FWDIR%lib\spark-assembly*.jar") do ( - set ASSEMBLY_JAR=%%d - ) -) else ( - for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*.jar") do ( - set ASSEMBLY_JAR=%%d - ) -) - -set CLASSPATH=%CLASSPATH%;%ASSEMBLY_JAR% - -rem When Hive support is needed, Datanucleus jars must be included on the classpath. -rem Datanucleus jars do not work if only included in the uber jar as plugin.xml metadata is lost. -rem Both sbt and maven will populate "lib_managed/jars/" with the datanucleus jars when Spark is -rem built with Hive, so look for them there. 
-if exist "%FWDIR%RELEASE" ( - set datanucleus_dir=%FWDIR%lib -) else ( - set datanucleus_dir=%FWDIR%lib_managed\jars -) -set "datanucleus_jars=" -for %%d in ("%datanucleus_dir%\datanucleus-*.jar") do ( - set datanucleus_jars=!datanucleus_jars!;%%d -) -set CLASSPATH=%CLASSPATH%;%datanucleus_jars% - -set SPARK_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%tools\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\classes -set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\classes - -set SPARK_TEST_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\test-classes -set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\test-classes - -if "x%SPARK_TESTING%"=="x1" ( - rem Add test clases to path - note, add SPARK_CLASSES and SPARK_TEST_CLASSES before CLASSPATH - rem so that local compilation takes precedence over assembled jar - set CLASSPATH=%SPARK_CLASSES%;%SPARK_TEST_CLASSES%;%CLASSPATH% -) - -rem Add hadoop conf dir - else FileSystem.*, etc fail -rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts -rem the configurtion files. -if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir - set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR% -:no_hadoop_conf_dir - -if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir - set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR% -:no_yarn_conf_dir - -rem A bit of a hack to allow calling this script within run2.cmd without seeing output -if "%DONT_PRINT_CLASSPATH%"=="1" goto exit - -echo %CLASSPATH% - -:exit +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. 
You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run" +rem script and the ExecutorRunner in standalone cluster mode. + +rem If we're called from spark-class2.cmd, it already set enabledelayedexpansion and setting +rem it here would stop us from affecting its copy of the CLASSPATH variable; otherwise we +rem need to set it here because we use !datanucleus_jars! below. +if "%DONT_PRINT_CLASSPATH%"=="1" goto skip_delayed_expansion +setlocal enabledelayedexpansion +:skip_delayed_expansion + +set SCALA_VERSION=2.10 + +rem Figure out where the Spark framework is installed +set FWDIR=%~dp0..\ + +rem Load environment variables from conf\spark-env.cmd, if it exists +if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd" + +rem Build up classpath +set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH% + +if not "x%SPARK_CONF_DIR%"=="x" ( + set CLASSPATH=%CLASSPATH%;%SPARK_CONF_DIR% +) else ( + set CLASSPATH=%CLASSPATH%;%FWDIR%conf +) + +if exist "%FWDIR%RELEASE" ( + for %%d in ("%FWDIR%lib\spark-assembly*.jar") do ( + set ASSEMBLY_JAR=%%d + ) +) else ( + for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*.jar") do ( + set ASSEMBLY_JAR=%%d + ) +) + +set CLASSPATH=%CLASSPATH%;%ASSEMBLY_JAR% + +rem When Hive support is needed, Datanucleus jars must be included on the classpath. +rem Datanucleus jars do not work if only included in the uber jar as plugin.xml metadata is lost. +rem Both sbt and maven will populate "lib_managed/jars/" with the datanucleus jars when Spark is +rem built with Hive, so look for them there. 
+if exist "%FWDIR%RELEASE" ( + set datanucleus_dir=%FWDIR%lib +) else ( + set datanucleus_dir=%FWDIR%lib_managed\jars +) +set "datanucleus_jars=" +for %%d in ("%datanucleus_dir%\datanucleus-*.jar") do ( + set datanucleus_jars=!datanucleus_jars!;%%d +) +set CLASSPATH=%CLASSPATH%;%datanucleus_jars% + +set SPARK_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%tools\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\classes +set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\classes + +set SPARK_TEST_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\test-classes +set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\test-classes + +if "x%SPARK_TESTING%"=="x1" ( + rem Add test clases to path - note, add SPARK_CLASSES and SPARK_TEST_CLASSES before CLASSPATH + rem so that local compilation takes precedence over assembled jar + set CLASSPATH=%SPARK_CLASSES%;%SPARK_TEST_CLASSES%;%CLASSPATH% +) + +rem Add hadoop conf dir - else FileSystem.*, etc fail +rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts +rem the configurtion files. +if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir + set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR% +:no_hadoop_conf_dir + +if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir + set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR% +:no_yarn_conf_dir + +rem A bit of a hack to allow calling this script within run2.cmd without seeing output +if "%DONT_PRINT_CLASSPATH%"=="1" goto exit + +echo %CLASSPATH% + +:exit diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index 905bbaf99b374..298641f2684de 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -20,8 +20,6 @@ # This script computes Spark's classpath and prints it to stdout; it's used by both the "run" # script and the ExecutorRunner in standalone cluster mode. 
-SCALA_VERSION=2.10 - # Figure out where Spark is installed FWDIR="$(cd "`dirname "$0"`"/..; pwd)" @@ -36,7 +34,7 @@ else CLASSPATH="$CLASSPATH:$FWDIR/conf" fi -ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SCALA_VERSION" +ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SPARK_SCALA_VERSION" if [ -n "$JAVA_HOME" ]; then JAR_CMD="$JAVA_HOME/bin/jar" @@ -48,19 +46,19 @@ fi if [ -n "$SPARK_PREPEND_CLASSES" ]; then echo "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark"\ "classes ahead of assembly." >&2 - CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/core/target/jars/*" - CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/tools/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SCALA_VERSION/classes" - CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SPARK_SCALA_VERSION/classes" + CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SPARK_SCALA_VERSION/classes" fi # Use spark-assembly jar from either RELEASE or assembly directory @@ -123,15 +121,15 @@ fi # Add test classes if we're running from SBT or Maven with SPARK_TESTING set to 1 if [[ $SPARK_TESTING == 1 ]]; then - CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/test-classes" - CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/test-classes" + 
CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/test-classes" + CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/test-classes" fi # Add hadoop conf dir if given -- otherwise FileSystem.*, etc fail ! diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index 6d4231b204595..356b3d49b2ffe 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -36,3 +36,23 @@ if [ -z "$SPARK_ENV_LOADED" ]; then set +a fi fi + +# Setting SPARK_SCALA_VERSION if not already set. + +if [ -z "$SPARK_SCALA_VERSION" ]; then + + ASSEMBLY_DIR2="$FWDIR/assembly/target/scala-2.11" + ASSEMBLY_DIR1="$FWDIR/assembly/target/scala-2.10" + + if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then + echo -e "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." 1>&2 + echo -e 'Either clean one of them or, export SPARK_SCALA_VERSION=2.11 in spark-env.sh.' 1>&2 + exit 1 + fi + + if [ -d "$ASSEMBLY_DIR2" ]; then + export SPARK_SCALA_VERSION="2.11" + else + export SPARK_SCALA_VERSION="2.10" + fi +fi diff --git a/bin/pyspark b/bin/pyspark index 6655725ef8e8e..0b4f695dd06dd 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -25,7 +25,7 @@ export SPARK_HOME="$FWDIR" source "$FWDIR/bin/utils.sh" -SCALA_VERSION=2.10 +source "$FWDIR"/bin/load-spark-env.sh function usage() { echo "Usage: ./bin/pyspark [options]" 1>&2 @@ -40,7 +40,7 @@ fi # Exit if the user hasn't compiled Spark if [ ! -f "$FWDIR/RELEASE" ]; then # Exit if the user hasn't compiled Spark - ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar >& /dev/null + ls "$FWDIR"/assembly/target/scala-$SPARK_SCALA_VERSION/spark-assembly*hadoop*.jar >& /dev/null if [[ $? != 0 ]]; then echo "Failed to find Spark assembly in $FWDIR/assembly/target" 1>&2 echo "You need to build Spark before running this program" 1>&2 @@ -48,24 +48,47 @@ if [ ! -f "$FWDIR/RELEASE" ]; then fi fi -. "$FWDIR"/bin/load-spark-env.sh +# In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython` +# executable, while the worker would still be launched using PYSPARK_PYTHON. +# +# In Spark 1.2, we removed the documentation of the IPYTHON and IPYTHON_OPTS variables and added +# PYSPARK_DRIVER_PYTHON and PYSPARK_DRIVER_PYTHON_OPTS to allow IPython to be used for the driver. +# Now, users can simply set PYSPARK_DRIVER_PYTHON=ipython to use IPython and set +# PYSPARK_DRIVER_PYTHON_OPTS to pass options when starting the Python driver +# (e.g. PYSPARK_DRIVER_PYTHON_OPTS='notebook'). This supports full customization of the IPython +# and executor Python executables. +# +# For backwards-compatibility, we retain the old IPYTHON and IPYTHON_OPTS variables. 
+ +# Determine the Python executable to use if PYSPARK_PYTHON or PYSPARK_DRIVER_PYTHON isn't set: +if hash python2.7 2>/dev/null; then + # Attempt to use Python 2.7, if installed: + DEFAULT_PYTHON="python2.7" +else + DEFAULT_PYTHON="python" +fi -# Figure out which Python executable to use +# Determine the Python executable to use for the driver: +if [[ -n "$IPYTHON_OPTS" || "$IPYTHON" == "1" ]]; then + # If IPython options are specified, assume user wants to run IPython + # (for backwards-compatibility) + PYSPARK_DRIVER_PYTHON_OPTS="$PYSPARK_DRIVER_PYTHON_OPTS $IPYTHON_OPTS" + PYSPARK_DRIVER_PYTHON="ipython" +elif [[ -z "$PYSPARK_DRIVER_PYTHON" ]]; then + PYSPARK_DRIVER_PYTHON="${PYSPARK_PYTHON:-"$DEFAULT_PYTHON"}" +fi + +# Determine the Python executable to use for the executors: if [[ -z "$PYSPARK_PYTHON" ]]; then - if [[ "$IPYTHON" = "1" || -n "$IPYTHON_OPTS" ]]; then - # for backward compatibility - PYSPARK_PYTHON="ipython" + if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && $DEFAULT_PYTHON != "python2.7" ]]; then + echo "IPython requires Python 2.7+; please install python2.7 or set PYSPARK_PYTHON" 1>&2 + exit 1 else - PYSPARK_PYTHON="python" + PYSPARK_PYTHON="$DEFAULT_PYTHON" fi fi export PYSPARK_PYTHON -if [[ -z "$PYSPARK_PYTHON_OPTS" && -n "$IPYTHON_OPTS" ]]; then - # for backward compatibility - PYSPARK_PYTHON_OPTS="$IPYTHON_OPTS" -fi - # Add the PySpark classes to the Python path: export PYTHONPATH="$SPARK_HOME/python/:$PYTHONPATH" export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" @@ -93,9 +116,9 @@ if [[ -n "$SPARK_TESTING" ]]; then unset YARN_CONF_DIR unset HADOOP_CONF_DIR if [[ -n "$PYSPARK_DOC_TEST" ]]; then - exec "$PYSPARK_PYTHON" -m doctest $1 + exec "$PYSPARK_DRIVER_PYTHON" -m doctest $1 else - exec "$PYSPARK_PYTHON" $1 + exec "$PYSPARK_DRIVER_PYTHON" $1 fi exit fi @@ -109,7 +132,5 @@ if [[ "$1" =~ \.py$ ]]; then gatherSparkSubmitOpts "$@" exec "$FWDIR"/bin/spark-submit "${SUBMISSION_OPTS[@]}" "$primary" "${APPLICATION_OPTS[@]}" else - # PySpark shell requires special handling downstream - export PYSPARK_SHELL=1 - exec "$PYSPARK_PYTHON" $PYSPARK_PYTHON_OPTS + exec "$PYSPARK_DRIVER_PYTHON" $PYSPARK_DRIVER_PYTHON_OPTS fi diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index a0e66abcc26c9..a542ec80b49d6 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -59,7 +59,11 @@ for /f %%i in ('echo %1^| findstr /R "\.py"') do ( ) if [%PYTHON_FILE%] == [] ( - %PYSPARK_PYTHON% + if [%IPYTHON%] == [1] ( + ipython %IPYTHON_OPTS% + ) else ( + %PYSPARK_PYTHON% + ) ) else ( echo. echo WARNING: Running python applications through ./bin/pyspark.cmd is deprecated as of Spark 1.0. diff --git a/bin/run-example b/bin/run-example index 34dd71c71880e..3d932509426fc 100755 --- a/bin/run-example +++ b/bin/run-example @@ -17,12 +17,12 @@ # limitations under the License. # -SCALA_VERSION=2.10 - FWDIR="$(cd "`dirname "$0"`"/..; pwd)" export SPARK_HOME="$FWDIR" EXAMPLES_DIR="$FWDIR"/examples +. 
"$FWDIR"/bin/load-spark-env.sh + if [ -n "$1" ]; then EXAMPLE_CLASS="$1" shift @@ -36,8 +36,8 @@ fi if [ -f "$FWDIR/RELEASE" ]; then export SPARK_EXAMPLES_JAR="`ls "$FWDIR"/lib/spark-examples-*hadoop*.jar`" -elif [ -e "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/spark-examples-*hadoop*.jar ]; then - export SPARK_EXAMPLES_JAR="`ls "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/spark-examples-*hadoop*.jar`" +elif [ -e "$EXAMPLES_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-examples-*hadoop*.jar ]; then + export SPARK_EXAMPLES_JAR="`ls "$EXAMPLES_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-examples-*hadoop*.jar`" fi if [[ -z "$SPARK_EXAMPLES_JAR" ]]; then diff --git a/bin/spark-class b/bin/spark-class index e8201c18d52de..0d58d95c1aee3 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -24,8 +24,6 @@ case "`uname`" in CYGWIN*) cygwin=true;; esac -SCALA_VERSION=2.10 - # Figure out where Spark is installed FWDIR="$(cd "`dirname "$0"`"/..; pwd)" @@ -81,7 +79,11 @@ case "$1" in OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_SUBMIT_OPTS" OUR_JAVA_MEM=${SPARK_DRIVER_MEMORY:-$DEFAULT_MEM} if [ -n "$SPARK_SUBMIT_LIBRARY_PATH" ]; then - OUR_JAVA_OPTS="$OUR_JAVA_OPTS -Djava.library.path=$SPARK_SUBMIT_LIBRARY_PATH" + if [[ $OSTYPE == darwin* ]]; then + export DYLD_LIBRARY_PATH="$SPARK_SUBMIT_LIBRARY_PATH:$DYLD_LIBRARY_PATH" + else + export LD_LIBRARY_PATH="$SPARK_SUBMIT_LIBRARY_PATH:$LD_LIBRARY_PATH" + fi fi if [ -n "$SPARK_SUBMIT_DRIVER_MEMORY" ]; then OUR_JAVA_MEM="$SPARK_SUBMIT_DRIVER_MEMORY" @@ -105,7 +107,7 @@ else exit 1 fi fi -JAVA_VERSION=$("$RUNNER" -version 2>&1 | sed 's/.* version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q') +JAVA_VERSION=$("$RUNNER" -version 2>&1 | grep 'version' | sed 's/.* version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q') # Set JAVA_OPTS to be able to load native libraries and to set heap size if [ "$JAVA_VERSION" -ge 18 ]; then @@ -124,9 +126,9 @@ fi TOOLS_DIR="$FWDIR"/tools SPARK_TOOLS_JAR="" -if [ -e "$TOOLS_DIR"/target/scala-$SCALA_VERSION/spark-tools*[0-9Tg].jar ]; then +if [ -e "$TOOLS_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-tools*[0-9Tg].jar ]; then # Use the JAR from the SBT build - export SPARK_TOOLS_JAR="`ls "$TOOLS_DIR"/target/scala-$SCALA_VERSION/spark-tools*[0-9Tg].jar`" + export SPARK_TOOLS_JAR="`ls "$TOOLS_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-tools*[0-9Tg].jar`" fi if [ -e "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar ]; then # Use the JAR from the Maven build @@ -145,7 +147,7 @@ fi if [[ "$1" =~ org.apache.spark.tools.* ]]; then if test -z "$SPARK_TOOLS_JAR"; then - echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SCALA_VERSION/" 1>&2 + echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/" 1>&2 echo "You need to build Spark before running $1." 1>&2 exit 1 fi diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd index 2ee60b4e2a2b3..8f90ba5a0b3b8 100755 --- a/bin/spark-shell.cmd +++ b/bin/spark-shell.cmd @@ -17,6 +17,7 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -set SPARK_HOME=%~dp0.. +rem This is the entry point for running Spark shell. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. 
-cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd --class org.apache.spark.repl.Main %* spark-shell +cmd /V /E /C %~dp0spark-shell2.cmd %* diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd new file mode 100644 index 0000000000000..2ee60b4e2a2b3 --- /dev/null +++ b/bin/spark-shell2.cmd @@ -0,0 +1,22 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +set SPARK_HOME=%~dp0.. + +cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd --class org.apache.spark.repl.Main %* spark-shell diff --git a/bin/spark-submit b/bin/spark-submit index c557311b4b20e..f92d90c3a66b0 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -22,6 +22,9 @@ export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" ORIG_ARGS=("$@") +# Set COLUMNS for progress bar +export COLUMNS=`tput cols` + while (($#)); do if [ "$1" = "--deploy-mode" ]; then SPARK_SUBMIT_DEPLOY_MODE=$2 diff --git a/bin/spark-submit.cmd b/bin/spark-submit.cmd index cf6046d1547ad..8f3b84c7b971d 100644 --- a/bin/spark-submit.cmd +++ b/bin/spark-submit.cmd @@ -17,52 +17,7 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -rem NOTE: Any changes in this file must be reflected in SparkSubmitDriverBootstrapper.scala! +rem This is the entry point for running Spark submit. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. -set SPARK_HOME=%~dp0.. -set ORIG_ARGS=%* - -rem Reset the values of all variables used -set SPARK_SUBMIT_DEPLOY_MODE=client -set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_HOME%\conf\spark-defaults.conf -set SPARK_SUBMIT_DRIVER_MEMORY= -set SPARK_SUBMIT_LIBRARY_PATH= -set SPARK_SUBMIT_CLASSPATH= -set SPARK_SUBMIT_OPTS= -set SPARK_SUBMIT_BOOTSTRAP_DRIVER= - -:loop -if [%1] == [] goto continue - if [%1] == [--deploy-mode] ( - set SPARK_SUBMIT_DEPLOY_MODE=%2 - ) else if [%1] == [--properties-file] ( - set SPARK_SUBMIT_PROPERTIES_FILE=%2 - ) else if [%1] == [--driver-memory] ( - set SPARK_SUBMIT_DRIVER_MEMORY=%2 - ) else if [%1] == [--driver-library-path] ( - set SPARK_SUBMIT_LIBRARY_PATH=%2 - ) else if [%1] == [--driver-class-path] ( - set SPARK_SUBMIT_CLASSPATH=%2 - ) else if [%1] == [--driver-java-options] ( - set SPARK_SUBMIT_OPTS=%2 - ) - shift -goto loop -:continue - -rem For client mode, the driver will be launched in the same JVM that launches -rem SparkSubmit, so we may need to read the properties file for any extra class -rem paths, library paths, java options and memory early on. Otherwise, it will -rem be too late by the time the driver JVM has started. 
- -if [%SPARK_SUBMIT_DEPLOY_MODE%] == [client] ( - if exist %SPARK_SUBMIT_PROPERTIES_FILE% ( - rem Parse the properties file only if the special configs exist - for /f %%i in ('findstr /r /c:"^[\t ]*spark.driver.memory" /c:"^[\t ]*spark.driver.extra" ^ - %SPARK_SUBMIT_PROPERTIES_FILE%') do ( - set SPARK_SUBMIT_BOOTSTRAP_DRIVER=1 - ) - ) -) - -cmd /V /E /C %SPARK_HOME%\bin\spark-class.cmd org.apache.spark.deploy.SparkSubmit %ORIG_ARGS% +cmd /V /E /C %~dp0spark-submit2.cmd %* diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd new file mode 100644 index 0000000000000..cf6046d1547ad --- /dev/null +++ b/bin/spark-submit2.cmd @@ -0,0 +1,68 @@ +@echo off + +rem +rem Licensed to the Apache Software Foundation (ASF) under one or more +rem contributor license agreements. See the NOTICE file distributed with +rem this work for additional information regarding copyright ownership. +rem The ASF licenses this file to You under the Apache License, Version 2.0 +rem (the "License"); you may not use this file except in compliance with +rem the License. You may obtain a copy of the License at +rem +rem http://www.apache.org/licenses/LICENSE-2.0 +rem +rem Unless required by applicable law or agreed to in writing, software +rem distributed under the License is distributed on an "AS IS" BASIS, +rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +rem See the License for the specific language governing permissions and +rem limitations under the License. +rem + +rem NOTE: Any changes in this file must be reflected in SparkSubmitDriverBootstrapper.scala! + +set SPARK_HOME=%~dp0.. +set ORIG_ARGS=%* + +rem Reset the values of all variables used +set SPARK_SUBMIT_DEPLOY_MODE=client +set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_HOME%\conf\spark-defaults.conf +set SPARK_SUBMIT_DRIVER_MEMORY= +set SPARK_SUBMIT_LIBRARY_PATH= +set SPARK_SUBMIT_CLASSPATH= +set SPARK_SUBMIT_OPTS= +set SPARK_SUBMIT_BOOTSTRAP_DRIVER= + +:loop +if [%1] == [] goto continue + if [%1] == [--deploy-mode] ( + set SPARK_SUBMIT_DEPLOY_MODE=%2 + ) else if [%1] == [--properties-file] ( + set SPARK_SUBMIT_PROPERTIES_FILE=%2 + ) else if [%1] == [--driver-memory] ( + set SPARK_SUBMIT_DRIVER_MEMORY=%2 + ) else if [%1] == [--driver-library-path] ( + set SPARK_SUBMIT_LIBRARY_PATH=%2 + ) else if [%1] == [--driver-class-path] ( + set SPARK_SUBMIT_CLASSPATH=%2 + ) else if [%1] == [--driver-java-options] ( + set SPARK_SUBMIT_OPTS=%2 + ) + shift +goto loop +:continue + +rem For client mode, the driver will be launched in the same JVM that launches +rem SparkSubmit, so we may need to read the properties file for any extra class +rem paths, library paths, java options and memory early on. Otherwise, it will +rem be too late by the time the driver JVM has started. + +if [%SPARK_SUBMIT_DEPLOY_MODE%] == [client] ( + if exist %SPARK_SUBMIT_PROPERTIES_FILE% ( + rem Parse the properties file only if the special configs exist + for /f %%i in ('findstr /r /c:"^[\t ]*spark.driver.memory" /c:"^[\t ]*spark.driver.extra" ^ + %SPARK_SUBMIT_PROPERTIES_FILE%') do ( + set SPARK_SUBMIT_BOOTSTRAP_DRIVER=1 + ) + ) +) + +cmd /V /E /C %SPARK_HOME%\bin\spark-class.cmd org.apache.spark.deploy.SparkSubmit %ORIG_ARGS% diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template index f8ffbf64278fb..0886b0276fb90 100755 --- a/conf/spark-env.sh.template +++ b/conf/spark-env.sh.template @@ -28,7 +28,7 @@ # - SPARK_YARN_DIST_FILES, Comma separated list of files to be distributed with the job. 
# - SPARK_YARN_DIST_ARCHIVES, Comma separated list of archives to be distributed with the job. -# Options for the daemons used in the standalone deploy mode: +# Options for the daemons used in the standalone deploy mode # - SPARK_MASTER_IP, to bind the master to a different IP address or hostname # - SPARK_MASTER_PORT / SPARK_MASTER_WEBUI_PORT, to use non-default ports for the master # - SPARK_MASTER_OPTS, to set config properties only for the master (e.g. "-Dx=y") @@ -41,3 +41,10 @@ # - SPARK_HISTORY_OPTS, to set config properties only for the history server (e.g. "-Dx=y") # - SPARK_DAEMON_JAVA_OPTS, to set config properties for all daemons (e.g. "-Dx=y") # - SPARK_PUBLIC_DNS, to set the public dns name of the master or workers + +# Generic options for the daemons used in the standalone deploy mode +# - SPARK_CONF_DIR Alternate conf dir. (Default: ${SPARK_HOME}/conf) +# - SPARK_LOG_DIR Where log files are stored. (Default: ${SPARK_HOME}/logs) +# - SPARK_PID_DIR Where the pid file is stored. (Default: /tmp) +# - SPARK_IDENT_STRING A string representing this instance of spark. (Default: $USER) +# - SPARK_NICENESS The scheduling priority for daemons. (Default: 0) diff --git a/core/pom.xml b/core/pom.xml index a5a178079bc57..1feb00b3a7fb8 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml @@ -34,6 +34,34 @@ Spark Project Core http://spark.apache.org/ + + com.twitter + chill_${scala.binary.version} + + + org.ow2.asm + asm + + + org.ow2.asm + asm-commons + + + + + com.twitter + chill-java + + + org.ow2.asm + asm + + + org.ow2.asm + asm-commons + + + org.apache.hadoop hadoop-client @@ -44,6 +72,16 @@ + + org.apache.spark + spark-network-common_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-network-shuffle_${scala.binary.version} + ${project.version} + net.java.dev.jets3t jets3t @@ -85,8 +123,6 @@ org.apache.commons commons-math3 - 3.3 - test com.google.code.findbugs @@ -125,12 +161,8 @@ lz4 - com.twitter - chill_${scala.binary.version} - - - com.twitter - chill-java + org.roaringbitmap + RoaringBitmap commons-net @@ -158,10 +190,6 @@ json4s-jackson_${scala.binary.version} 3.2.10 - - colt - colt - org.apache.mesos mesos @@ -243,6 +271,11 @@ + + org.seleniumhq.selenium + selenium-java + test + org.scalatest scalatest_${scala.binary.version} @@ -296,14 +329,16 @@ org.scalatest scalatest-maven-plugin - - - ${basedir}/.. - 1 - ${spark.classpath} - - + + + test + + test + + + + org.apache.maven.plugins @@ -411,4 +446,5 @@ + diff --git a/core/src/main/java/org/apache/spark/JobExecutionStatus.java b/core/src/main/java/org/apache/spark/JobExecutionStatus.java new file mode 100644 index 0000000000000..6e161313702bb --- /dev/null +++ b/core/src/main/java/org/apache/spark/JobExecutionStatus.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +public enum JobExecutionStatus { + RUNNING, + SUCCEEDED, + FAILED, + UNKNOWN +} diff --git a/core/src/main/java/org/apache/spark/SparkJobInfo.java b/core/src/main/java/org/apache/spark/SparkJobInfo.java new file mode 100644 index 0000000000000..4e3c983b1170a --- /dev/null +++ b/core/src/main/java/org/apache/spark/SparkJobInfo.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +/** + * Exposes information about Spark Jobs. + * + * This interface is not designed to be implemented outside of Spark. We may add additional methods + * which may break binary compatibility with outside implementations. + */ +public interface SparkJobInfo { + int jobId(); + int[] stageIds(); + JobExecutionStatus status(); +} diff --git a/core/src/main/java/org/apache/spark/SparkStageInfo.java b/core/src/main/java/org/apache/spark/SparkStageInfo.java new file mode 100644 index 0000000000000..fd74321093658 --- /dev/null +++ b/core/src/main/java/org/apache/spark/SparkStageInfo.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark; + +/** + * Exposes information about Spark Stages. + * + * This interface is not designed to be implemented outside of Spark. We may add additional methods + * which may break binary compatibility with outside implementations. 
+ */ +public interface SparkStageInfo { + int stageId(); + int currentAttemptId(); + long submissionTime(); + String name(); + int numTasks(); + int numActiveTasks(); + int numCompletedTasks(); + int numFailedTasks(); +} diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java index 4e6d708af0ea7..0d6973203eba1 100644 --- a/core/src/main/java/org/apache/spark/TaskContext.java +++ b/core/src/main/java/org/apache/spark/TaskContext.java @@ -18,252 +18,89 @@ package org.apache.spark; import java.io.Serializable; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; import scala.Function0; import scala.Function1; import scala.Unit; -import scala.collection.JavaConversions; import org.apache.spark.annotation.DeveloperApi; import org.apache.spark.executor.TaskMetrics; import org.apache.spark.util.TaskCompletionListener; -import org.apache.spark.util.TaskCompletionListenerException; /** -* :: DeveloperApi :: -* Contextual information about a task which can be read or mutated during execution. -*/ -@DeveloperApi -public class TaskContext implements Serializable { - - private int stageId; - private int partitionId; - private long attemptId; - private boolean runningLocally; - private TaskMetrics taskMetrics; - - /** - * :: DeveloperApi :: - * Contextual information about a task which can be read or mutated during execution. - * - * @param stageId stage id - * @param partitionId index of the partition - * @param attemptId the number of attempts to execute this task - * @param runningLocally whether the task is running locally in the driver JVM - * @param taskMetrics performance metrics of the task - */ - @DeveloperApi - public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally, - TaskMetrics taskMetrics) { - this.attemptId = attemptId; - this.partitionId = partitionId; - this.runningLocally = runningLocally; - this.stageId = stageId; - this.taskMetrics = taskMetrics; - } - - /** - * :: DeveloperApi :: - * Contextual information about a task which can be read or mutated during execution. - * - * @param stageId stage id - * @param partitionId index of the partition - * @param attemptId the number of attempts to execute this task - * @param runningLocally whether the task is running locally in the driver JVM - */ - @DeveloperApi - public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally) { - this.attemptId = attemptId; - this.partitionId = partitionId; - this.runningLocally = runningLocally; - this.stageId = stageId; - this.taskMetrics = TaskMetrics.empty(); - } - + * Contextual information about a task which can be read or mutated during + * execution. To access the TaskContext for a running task use + * TaskContext.get(). + */ +public abstract class TaskContext implements Serializable { /** - * :: DeveloperApi :: - * Contextual information about a task which can be read or mutated during execution. - * - * @param stageId stage id - * @param partitionId index of the partition - * @param attemptId the number of attempts to execute this task + * Return the currently active TaskContext. This can be called inside of + * user functions to access contextual information about running tasks. 
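As a usage illustration for the new status API above, here is a minimal, hypothetical Java helper. It relies only on the methods declared in JobExecutionStatus, SparkJobInfo and SparkStageInfo; how the two info objects are obtained (for example from the SparkContext) is assumed and not shown.

    // Hypothetical helper, for illustration only: summarizes a job and one of its
    // stages using nothing beyond the interfaces introduced in this change.
    import org.apache.spark.JobExecutionStatus;
    import org.apache.spark.SparkJobInfo;
    import org.apache.spark.SparkStageInfo;

    public final class JobProgressSummary {
      private JobProgressSummary() {}

      /** One-line, human-readable summary of a job and one of its stages. */
      public static String summarize(SparkJobInfo job, SparkStageInfo stage) {
        StringBuilder sb = new StringBuilder();
        sb.append("job ").append(job.jobId()).append(" [").append(job.status()).append("]");
        if (job.status() == JobExecutionStatus.RUNNING) {
          // numTasks() can be 0 for stages that have not been submitted yet,
          // so guard the percentage computation.
          int total = stage.numTasks();
          int done = stage.numCompletedTasks();
          double pct = total == 0 ? 0.0 : 100.0 * done / total;
          sb.append(String.format(
              ", stage %d '%s': %d/%d tasks (%.1f%%), %d active, %d failed",
              stage.stageId(), stage.name(), done, total, pct,
              stage.numActiveTasks(), stage.numFailedTasks()));
        }
        return sb.toString();
      }
    }
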
*/ - @DeveloperApi - public TaskContext(int stageId, int partitionId, long attemptId) { - this.attemptId = attemptId; - this.partitionId = partitionId; - this.runningLocally = false; - this.stageId = stageId; - this.taskMetrics = TaskMetrics.empty(); + public static TaskContext get() { + return taskContext.get(); } private static ThreadLocal taskContext = new ThreadLocal(); - /** - * :: Internal API :: - * This is spark internal API, not intended to be called from user programs. - */ - public static void setTaskContext(TaskContext tc) { + static void setTaskContext(TaskContext tc) { taskContext.set(tc); } - public static TaskContext get() { - return taskContext.get(); - } - - /** :: Internal API :: */ - public static void unset() { + static void unset() { taskContext.remove(); } - // List of callback functions to execute when the task completes. - private transient List onCompleteCallbacks = - new ArrayList(); - - // Whether the corresponding task has been killed. - private volatile boolean interrupted = false; - - // Whether the task has completed. - private volatile boolean completed = false; - /** - * Checks whether the task has completed. + * Whether the task has completed. */ - public boolean isCompleted() { - return completed; - } + public abstract boolean isCompleted(); /** - * Checks whether the task has been killed. + * Whether the task has been killed. */ - public boolean isInterrupted() { - return interrupted; - } + public abstract boolean isInterrupted(); + + /** @deprecated: use isRunningLocally() */ + @Deprecated + public abstract boolean runningLocally(); + + public abstract boolean isRunningLocally(); /** * Add a (Java friendly) listener to be executed on task completion. * This will be called in all situation - success, failure, or cancellation. - *
* An example use is for HadoopRDD to register a callback to close the input stream. */ - public TaskContext addTaskCompletionListener(TaskCompletionListener listener) { - onCompleteCallbacks.add(listener); - return this; - } + public abstract TaskContext addTaskCompletionListener(TaskCompletionListener listener); /** * Add a listener in the form of a Scala closure to be executed on task completion. * This will be called in all situations - success, failure, or cancellation. - *
* An example use is for HadoopRDD to register a callback to close the input stream. */ - public TaskContext addTaskCompletionListener(final Function1 f) { - onCompleteCallbacks.add(new TaskCompletionListener() { - @Override - public void onTaskCompletion(TaskContext context) { - f.apply(context); - } - }); - return this; - } + public abstract TaskContext addTaskCompletionListener(final Function1 f); /** * Add a callback function to be executed on task completion. An example use * is for HadoopRDD to register a callback to close the input stream. * Will be called in any situation - success, failure, or cancellation. * - * Deprecated: use addTaskCompletionListener - * + * @deprecated: use addTaskCompletionListener + * * @param f Callback function. */ @Deprecated - public void addOnCompleteCallback(final Function0 f) { - onCompleteCallbacks.add(new TaskCompletionListener() { - @Override - public void onTaskCompletion(TaskContext context) { - f.apply(); - } - }); - } - - /** - * ::Internal API:: - * Marks the task as completed and triggers the listeners. - */ - public void markTaskCompleted() throws TaskCompletionListenerException { - completed = true; - List errorMsgs = new ArrayList(2); - // Process complete callbacks in the reverse order of registration - List revlist = - new ArrayList(onCompleteCallbacks); - Collections.reverse(revlist); - for (TaskCompletionListener tcl: revlist) { - try { - tcl.onTaskCompletion(this); - } catch (Throwable e) { - errorMsgs.add(e.getMessage()); - } - } - - if (!errorMsgs.isEmpty()) { - throw new TaskCompletionListenerException(JavaConversions.asScalaBuffer(errorMsgs)); - } - } - - /** - * ::Internal API:: - * Marks the task for interruption, i.e. cancellation. - */ - public void markInterrupted() { - interrupted = true; - } - - @Deprecated - /** Deprecated: use getStageId() */ - public int stageId() { - return stageId; - } - - @Deprecated - /** Deprecated: use getPartitionId() */ - public int partitionId() { - return partitionId; - } - - @Deprecated - /** Deprecated: use getAttemptId() */ - public long attemptId() { - return attemptId; - } - - @Deprecated - /** Deprecated: use isRunningLocally() */ - public boolean runningLocally() { - return runningLocally; - } - - public boolean isRunningLocally() { - return runningLocally; - } + public abstract void addOnCompleteCallback(final Function0 f); - public int getStageId() { - return stageId; - } + public abstract int stageId(); - public int getPartitionId() { - return partitionId; - } + public abstract int partitionId(); - public long getAttemptId() { - return attemptId; - } + public abstract long attemptId(); - /** ::Internal API:: */ - public TaskMetrics taskMetrics() { - return taskMetrics; - } + /** ::DeveloperApi:: */ + @DeveloperApi + public abstract TaskMetrics taskMetrics(); } diff --git a/core/src/main/java/org/apache/spark/api/java/JavaFutureAction.java b/core/src/main/java/org/apache/spark/api/java/JavaFutureAction.java new file mode 100644 index 0000000000000..0ad189633e427 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/JavaFutureAction.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
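To show how the reworked TaskContext above is meant to be used, here is an illustrative, self-contained Java sketch. The RDD plumbing around it is ordinary Java-API code; only TaskContext.get(), partitionId() and addTaskCompletionListener(...) come from this change, and the class and app names are made up.

    // Illustration only: TaskContext.get() replaces passing a TaskContext around
    // explicitly; the listener is registered from inside the running task.
    import java.util.Arrays;

    import org.apache.spark.SparkConf;
    import org.apache.spark.TaskContext;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.util.TaskCompletionListener;

    public final class TaskContextSketch {
      public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("task-context-sketch").setMaster("local[2]");
        JavaSparkContext jsc = new JavaSparkContext(conf);

        jsc.parallelize(Arrays.asList(1, 2, 3, 4), 2)
          .map(new Function<Integer, String>() {
            @Override
            public String call(Integer x) {
              TaskContext tc = TaskContext.get();
              // Registered per record here only to keep the sketch short; a real
              // job would register once per task (e.g. from mapPartitions).
              tc.addTaskCompletionListener(new TaskCompletionListener() {
                @Override
                public void onTaskCompletion(TaskContext context) {
                  // close per-partition resources, flush metrics, etc.
                }
              });
              return x + " from partition " + tc.partitionId();
            }
          })
          .collect();

        jsc.stop();
      }
    }
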
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java; + + +import java.util.List; +import java.util.concurrent.Future; + +public interface JavaFutureAction extends Future { + + /** + * Returns the job IDs run by the underlying async operation. + * + * This returns the current snapshot of the job list. Certain operations may run multiple + * jobs, so multiple calls to this method may return different lists. + */ + List jobIds(); +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java index abd9bcc07ac61..99bf240a17225 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java @@ -22,7 +22,8 @@ import scala.Tuple2; /** - * A function that returns key-value pairs (Tuple2), and can be used to construct PairRDDs. + * A function that returns key-value pairs (Tuple2<K, V>), and can be used to + * construct PairRDDs. */ public interface PairFunction extends Serializable { public Tuple2 call(T t) throws Exception; diff --git a/core/src/main/java/org/apache/spark/api/java/function/package.scala b/core/src/main/java/org/apache/spark/api/java/function/package.scala index 7f91de653a64a..0f9bac7164162 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/package.scala +++ b/core/src/main/java/org/apache/spark/api/java/function/package.scala @@ -22,4 +22,4 @@ package org.apache.spark.api.java * these interfaces to pass functions to various Java API methods for Spark. Please visit Spark's * Java programming guide for more details. */ -package object function \ No newline at end of file +package object function diff --git a/core/src/main/java/org/apache/spark/util/collection/Sorter.java b/core/src/main/java/org/apache/spark/util/collection/Sorter.java deleted file mode 100644 index 64ad18c0e463a..0000000000000 --- a/core/src/main/java/org/apache/spark/util/collection/Sorter.java +++ /dev/null @@ -1,915 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util.collection; - -import java.util.Comparator; - -/** - * A port of the Android Timsort class, which utilizes a "stable, adaptive, iterative mergesort." - * See the method comment on sort() for more details. 
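A small, hypothetical helper showing how the JavaFutureAction interface above is intended to be consumed; obtaining the future itself (for example from one of the asynchronous RDD actions) is assumed and not shown here.

    // Illustration only: blocks on a JavaFutureAction and reports the job IDs it
    // spawned, using plain java.util.concurrent.Future semantics plus jobIds().
    import java.util.List;
    import java.util.concurrent.ExecutionException;

    import org.apache.spark.api.java.JavaFutureAction;

    public final class FutureActionSketch {
      private FutureActionSketch() {}

      public static <T> T awaitAndReport(JavaFutureAction<T> action)
          throws InterruptedException, ExecutionException {
        // jobIds() is only a snapshot: an operation backed by several jobs may
        // return a longer list when called again later.
        List<Integer> started = action.jobIds();
        System.out.println("Jobs started so far: " + started);
        T result = action.get();  // inherited from java.util.concurrent.Future
        System.out.println("Jobs after completion: " + action.jobIds());
        return result;
      }
    }
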
- * - * This has been kept in Java with the original style in order to match very closely with the - * Anroid source code, and thus be easy to verify correctness. - * - * The purpose of the port is to generalize the interface to the sort to accept input data formats - * besides simple arrays where every element is sorted individually. For instance, the AppendOnlyMap - * uses this to sort an Array with alternating elements of the form [key, value, key, value]. - * This generalization comes with minimal overhead -- see SortDataFormat for more information. - */ -class Sorter { - - /** - * This is the minimum sized sequence that will be merged. Shorter - * sequences will be lengthened by calling binarySort. If the entire - * array is less than this length, no merges will be performed. - * - * This constant should be a power of two. It was 64 in Tim Peter's C - * implementation, but 32 was empirically determined to work better in - * this implementation. In the unlikely event that you set this constant - * to be a number that's not a power of two, you'll need to change the - * minRunLength computation. - * - * If you decrease this constant, you must change the stackLen - * computation in the TimSort constructor, or you risk an - * ArrayOutOfBounds exception. See listsort.txt for a discussion - * of the minimum stack length required as a function of the length - * of the array being sorted and the minimum merge sequence length. - */ - private static final int MIN_MERGE = 32; - - private final SortDataFormat s; - - public Sorter(SortDataFormat sortDataFormat) { - this.s = sortDataFormat; - } - - /** - * A stable, adaptive, iterative mergesort that requires far fewer than - * n lg(n) comparisons when running on partially sorted arrays, while - * offering performance comparable to a traditional mergesort when run - * on random arrays. Like all proper mergesorts, this sort is stable and - * runs O(n log n) time (worst case). In the worst case, this sort requires - * temporary storage space for n/2 object references; in the best case, - * it requires only a small constant amount of space. - * - * This implementation was adapted from Tim Peters's list sort for - * Python, which is described in detail here: - * - * http://svn.python.org/projects/python/trunk/Objects/listsort.txt - * - * Tim's C code may be found here: - * - * http://svn.python.org/projects/python/trunk/Objects/listobject.c - * - * The underlying techniques are described in this paper (and may have - * even earlier origins): - * - * "Optimistic Sorting and Information Theoretic Complexity" - * Peter McIlroy - * SODA (Fourth Annual ACM-SIAM Symposium on Discrete Algorithms), - * pp 467-474, Austin, Texas, 25-27 January 1993. - * - * While the API to this class consists solely of static methods, it is - * (privately) instantiable; a TimSort instance holds the state of an ongoing - * sort, assuming the input array is large enough to warrant the full-blown - * TimSort. Small arrays are sorted in place, using a binary insertion sort. 
- * - * @author Josh Bloch - */ - void sort(Buffer a, int lo, int hi, Comparator c) { - assert c != null; - - int nRemaining = hi - lo; - if (nRemaining < 2) - return; // Arrays of size 0 and 1 are always sorted - - // If array is small, do a "mini-TimSort" with no merges - if (nRemaining < MIN_MERGE) { - int initRunLen = countRunAndMakeAscending(a, lo, hi, c); - binarySort(a, lo, hi, lo + initRunLen, c); - return; - } - - /** - * March over the array once, left to right, finding natural runs, - * extending short natural runs to minRun elements, and merging runs - * to maintain stack invariant. - */ - SortState sortState = new SortState(a, c, hi - lo); - int minRun = minRunLength(nRemaining); - do { - // Identify next run - int runLen = countRunAndMakeAscending(a, lo, hi, c); - - // If run is short, extend to min(minRun, nRemaining) - if (runLen < minRun) { - int force = nRemaining <= minRun ? nRemaining : minRun; - binarySort(a, lo, lo + force, lo + runLen, c); - runLen = force; - } - - // Push run onto pending-run stack, and maybe merge - sortState.pushRun(lo, runLen); - sortState.mergeCollapse(); - - // Advance to find next run - lo += runLen; - nRemaining -= runLen; - } while (nRemaining != 0); - - // Merge all remaining runs to complete sort - assert lo == hi; - sortState.mergeForceCollapse(); - assert sortState.stackSize == 1; - } - - /** - * Sorts the specified portion of the specified array using a binary - * insertion sort. This is the best method for sorting small numbers - * of elements. It requires O(n log n) compares, but O(n^2) data - * movement (worst case). - * - * If the initial part of the specified range is already sorted, - * this method can take advantage of it: the method assumes that the - * elements from index {@code lo}, inclusive, to {@code start}, - * exclusive are already sorted. - * - * @param a the array in which a range is to be sorted - * @param lo the index of the first element in the range to be sorted - * @param hi the index after the last element in the range to be sorted - * @param start the index of the first element in the range that is - * not already known to be sorted ({@code lo <= start <= hi}) - * @param c comparator to used for the sort - */ - @SuppressWarnings("fallthrough") - private void binarySort(Buffer a, int lo, int hi, int start, Comparator c) { - assert lo <= start && start <= hi; - if (start == lo) - start++; - - Buffer pivotStore = s.allocate(1); - for ( ; start < hi; start++) { - s.copyElement(a, start, pivotStore, 0); - K pivot = s.getKey(pivotStore, 0); - - // Set left (and right) to the index where a[start] (pivot) belongs - int left = lo; - int right = start; - assert left <= right; - /* - * Invariants: - * pivot >= all in [lo, left). - * pivot < all in [right, start). - */ - while (left < right) { - int mid = (left + right) >>> 1; - if (c.compare(pivot, s.getKey(a, mid)) < 0) - right = mid; - else - left = mid + 1; - } - assert left == right; - - /* - * The invariants still hold: pivot >= all in [lo, left) and - * pivot < all in [left, start), so pivot belongs at left. Note - * that if there are elements equal to pivot, left points to the - * first slot after them -- that's why this sort is stable. - * Slide elements over to make room for pivot. 
- */ - int n = start - left; // The number of elements to move - // Switch is just an optimization for arraycopy in default case - switch (n) { - case 2: s.copyElement(a, left + 1, a, left + 2); - case 1: s.copyElement(a, left, a, left + 1); - break; - default: s.copyRange(a, left, a, left + 1, n); - } - s.copyElement(pivotStore, 0, a, left); - } - } - - /** - * Returns the length of the run beginning at the specified position in - * the specified array and reverses the run if it is descending (ensuring - * that the run will always be ascending when the method returns). - * - * A run is the longest ascending sequence with: - * - * a[lo] <= a[lo + 1] <= a[lo + 2] <= ... - * - * or the longest descending sequence with: - * - * a[lo] > a[lo + 1] > a[lo + 2] > ... - * - * For its intended use in a stable mergesort, the strictness of the - * definition of "descending" is needed so that the call can safely - * reverse a descending sequence without violating stability. - * - * @param a the array in which a run is to be counted and possibly reversed - * @param lo index of the first element in the run - * @param hi index after the last element that may be contained in the run. - It is required that {@code lo < hi}. - * @param c the comparator to used for the sort - * @return the length of the run beginning at the specified position in - * the specified array - */ - private int countRunAndMakeAscending(Buffer a, int lo, int hi, Comparator c) { - assert lo < hi; - int runHi = lo + 1; - if (runHi == hi) - return 1; - - // Find end of run, and reverse range if descending - if (c.compare(s.getKey(a, runHi++), s.getKey(a, lo)) < 0) { // Descending - while (runHi < hi && c.compare(s.getKey(a, runHi), s.getKey(a, runHi - 1)) < 0) - runHi++; - reverseRange(a, lo, runHi); - } else { // Ascending - while (runHi < hi && c.compare(s.getKey(a, runHi), s.getKey(a, runHi - 1)) >= 0) - runHi++; - } - - return runHi - lo; - } - - /** - * Reverse the specified range of the specified array. - * - * @param a the array in which a range is to be reversed - * @param lo the index of the first element in the range to be reversed - * @param hi the index after the last element in the range to be reversed - */ - private void reverseRange(Buffer a, int lo, int hi) { - hi--; - while (lo < hi) { - s.swap(a, lo, hi); - lo++; - hi--; - } - } - - /** - * Returns the minimum acceptable run length for an array of the specified - * length. Natural runs shorter than this will be extended with - * {@link #binarySort}. - * - * Roughly speaking, the computation is: - * - * If n < MIN_MERGE, return n (it's too small to bother with fancy stuff). - * Else if n is an exact power of 2, return MIN_MERGE/2. - * Else return an int k, MIN_MERGE/2 <= k <= MIN_MERGE, such that n/k - * is close to, but strictly less than, an exact power of 2. - * - * For the rationale, see listsort.txt. - * - * @param n the length of the array to be sorted - * @return the length of the minimum run to be merged - */ - private int minRunLength(int n) { - assert n >= 0; - int r = 0; // Becomes 1 if any 1 bits are shifted off - while (n >= MIN_MERGE) { - r |= (n & 1); - n >>= 1; - } - return n + r; - } - - private class SortState { - - /** - * The Buffer being sorted. - */ - private final Buffer a; - - /** - * Length of the sort Buffer. - */ - private final int aLength; - - /** - * The comparator for this sort. - */ - private final Comparator c; - - /** - * When we get into galloping mode, we stay there until both runs win less - * often than MIN_GALLOP consecutive times. 
- */ - private static final int MIN_GALLOP = 7; - - /** - * This controls when we get *into* galloping mode. It is initialized - * to MIN_GALLOP. The mergeLo and mergeHi methods nudge it higher for - * random data, and lower for highly structured data. - */ - private int minGallop = MIN_GALLOP; - - /** - * Maximum initial size of tmp array, which is used for merging. The array - * can grow to accommodate demand. - * - * Unlike Tim's original C version, we do not allocate this much storage - * when sorting smaller arrays. This change was required for performance. - */ - private static final int INITIAL_TMP_STORAGE_LENGTH = 256; - - /** - * Temp storage for merges. - */ - private Buffer tmp; // Actual runtime type will be Object[], regardless of T - - /** - * Length of the temp storage. - */ - private int tmpLength = 0; - - /** - * A stack of pending runs yet to be merged. Run i starts at - * address base[i] and extends for len[i] elements. It's always - * true (so long as the indices are in bounds) that: - * - * runBase[i] + runLen[i] == runBase[i + 1] - * - * so we could cut the storage for this, but it's a minor amount, - * and keeping all the info explicit simplifies the code. - */ - private int stackSize = 0; // Number of pending runs on stack - private final int[] runBase; - private final int[] runLen; - - /** - * Creates a TimSort instance to maintain the state of an ongoing sort. - * - * @param a the array to be sorted - * @param c the comparator to determine the order of the sort - */ - private SortState(Buffer a, Comparator c, int len) { - this.aLength = len; - this.a = a; - this.c = c; - - // Allocate temp storage (which may be increased later if necessary) - tmpLength = len < 2 * INITIAL_TMP_STORAGE_LENGTH ? len >>> 1 : INITIAL_TMP_STORAGE_LENGTH; - tmp = s.allocate(tmpLength); - - /* - * Allocate runs-to-be-merged stack (which cannot be expanded). The - * stack length requirements are described in listsort.txt. The C - * version always uses the same stack length (85), but this was - * measured to be too expensive when sorting "mid-sized" arrays (e.g., - * 100 elements) in Java. Therefore, we use smaller (but sufficiently - * large) stack lengths for smaller arrays. The "magic numbers" in the - * computation below must be changed if MIN_MERGE is decreased. See - * the MIN_MERGE declaration above for more information. - */ - int stackLen = (len < 120 ? 5 : - len < 1542 ? 10 : - len < 119151 ? 19 : 40); - runBase = new int[stackLen]; - runLen = new int[stackLen]; - } - - /** - * Pushes the specified run onto the pending-run stack. - * - * @param runBase index of the first element in the run - * @param runLen the number of elements in the run - */ - private void pushRun(int runBase, int runLen) { - this.runBase[stackSize] = runBase; - this.runLen[stackSize] = runLen; - stackSize++; - } - - /** - * Examines the stack of runs waiting to be merged and merges adjacent runs - * until the stack invariants are reestablished: - * - * 1. runLen[i - 3] > runLen[i - 2] + runLen[i - 1] - * 2. runLen[i - 2] > runLen[i - 1] - * - * This method is called each time a new run is pushed onto the stack, - * so the invariants are guaranteed to hold for i < stackSize upon - * entry to the method. 
- */ - private void mergeCollapse() { - while (stackSize > 1) { - int n = stackSize - 2; - if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) { - if (runLen[n - 1] < runLen[n + 1]) - n--; - mergeAt(n); - } else if (runLen[n] <= runLen[n + 1]) { - mergeAt(n); - } else { - break; // Invariant is established - } - } - } - - /** - * Merges all runs on the stack until only one remains. This method is - * called once, to complete the sort. - */ - private void mergeForceCollapse() { - while (stackSize > 1) { - int n = stackSize - 2; - if (n > 0 && runLen[n - 1] < runLen[n + 1]) - n--; - mergeAt(n); - } - } - - /** - * Merges the two runs at stack indices i and i+1. Run i must be - * the penultimate or antepenultimate run on the stack. In other words, - * i must be equal to stackSize-2 or stackSize-3. - * - * @param i stack index of the first of the two runs to merge - */ - private void mergeAt(int i) { - assert stackSize >= 2; - assert i >= 0; - assert i == stackSize - 2 || i == stackSize - 3; - - int base1 = runBase[i]; - int len1 = runLen[i]; - int base2 = runBase[i + 1]; - int len2 = runLen[i + 1]; - assert len1 > 0 && len2 > 0; - assert base1 + len1 == base2; - - /* - * Record the length of the combined runs; if i is the 3rd-last - * run now, also slide over the last run (which isn't involved - * in this merge). The current run (i+1) goes away in any case. - */ - runLen[i] = len1 + len2; - if (i == stackSize - 3) { - runBase[i + 1] = runBase[i + 2]; - runLen[i + 1] = runLen[i + 2]; - } - stackSize--; - - /* - * Find where the first element of run2 goes in run1. Prior elements - * in run1 can be ignored (because they're already in place). - */ - int k = gallopRight(s.getKey(a, base2), a, base1, len1, 0, c); - assert k >= 0; - base1 += k; - len1 -= k; - if (len1 == 0) - return; - - /* - * Find where the last element of run1 goes in run2. Subsequent elements - * in run2 can be ignored (because they're already in place). - */ - len2 = gallopLeft(s.getKey(a, base1 + len1 - 1), a, base2, len2, len2 - 1, c); - assert len2 >= 0; - if (len2 == 0) - return; - - // Merge remaining runs, using tmp array with min(len1, len2) elements - if (len1 <= len2) - mergeLo(base1, len1, base2, len2); - else - mergeHi(base1, len1, base2, len2); - } - - /** - * Locates the position at which to insert the specified key into the - * specified sorted range; if the range contains an element equal to key, - * returns the index of the leftmost equal element. - * - * @param key the key whose insertion point to search for - * @param a the array in which to search - * @param base the index of the first element in the range - * @param len the length of the range; must be > 0 - * @param hint the index at which to begin the search, 0 <= hint < n. - * The closer hint is to the result, the faster this method will run. - * @param c the comparator used to order the range, and to search - * @return the int k, 0 <= k <= n such that a[b + k - 1] < key <= a[b + k], - * pretending that a[b - 1] is minus infinity and a[b + n] is infinity. - * In other words, key belongs at index b + k; or in other words, - * the first k elements of a should precede key, and the last n - k - * should follow it. 
- */ - private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator c) { - assert len > 0 && hint >= 0 && hint < len; - int lastOfs = 0; - int ofs = 1; - if (c.compare(key, s.getKey(a, base + hint)) > 0) { - // Gallop right until a[base+hint+lastOfs] < key <= a[base+hint+ofs] - int maxOfs = len - hint; - while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs)) > 0) { - lastOfs = ofs; - ofs = (ofs << 1) + 1; - if (ofs <= 0) // int overflow - ofs = maxOfs; - } - if (ofs > maxOfs) - ofs = maxOfs; - - // Make offsets relative to base - lastOfs += hint; - ofs += hint; - } else { // key <= a[base + hint] - // Gallop left until a[base+hint-ofs] < key <= a[base+hint-lastOfs] - final int maxOfs = hint + 1; - while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs)) <= 0) { - lastOfs = ofs; - ofs = (ofs << 1) + 1; - if (ofs <= 0) // int overflow - ofs = maxOfs; - } - if (ofs > maxOfs) - ofs = maxOfs; - - // Make offsets relative to base - int tmp = lastOfs; - lastOfs = hint - ofs; - ofs = hint - tmp; - } - assert -1 <= lastOfs && lastOfs < ofs && ofs <= len; - - /* - * Now a[base+lastOfs] < key <= a[base+ofs], so key belongs somewhere - * to the right of lastOfs but no farther right than ofs. Do a binary - * search, with invariant a[base + lastOfs - 1] < key <= a[base + ofs]. - */ - lastOfs++; - while (lastOfs < ofs) { - int m = lastOfs + ((ofs - lastOfs) >>> 1); - - if (c.compare(key, s.getKey(a, base + m)) > 0) - lastOfs = m + 1; // a[base + m] < key - else - ofs = m; // key <= a[base + m] - } - assert lastOfs == ofs; // so a[base + ofs - 1] < key <= a[base + ofs] - return ofs; - } - - /** - * Like gallopLeft, except that if the range contains an element equal to - * key, gallopRight returns the index after the rightmost equal element. - * - * @param key the key whose insertion point to search for - * @param a the array in which to search - * @param base the index of the first element in the range - * @param len the length of the range; must be > 0 - * @param hint the index at which to begin the search, 0 <= hint < n. - * The closer hint is to the result, the faster this method will run. - * @param c the comparator used to order the range, and to search - * @return the int k, 0 <= k <= n such that a[b + k - 1] <= key < a[b + k] - */ - private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator c) { - assert len > 0 && hint >= 0 && hint < len; - - int ofs = 1; - int lastOfs = 0; - if (c.compare(key, s.getKey(a, base + hint)) < 0) { - // Gallop left until a[b+hint - ofs] <= key < a[b+hint - lastOfs] - int maxOfs = hint + 1; - while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs)) < 0) { - lastOfs = ofs; - ofs = (ofs << 1) + 1; - if (ofs <= 0) // int overflow - ofs = maxOfs; - } - if (ofs > maxOfs) - ofs = maxOfs; - - // Make offsets relative to b - int tmp = lastOfs; - lastOfs = hint - ofs; - ofs = hint - tmp; - } else { // a[b + hint] <= key - // Gallop right until a[b+hint + lastOfs] <= key < a[b+hint + ofs] - int maxOfs = len - hint; - while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs)) >= 0) { - lastOfs = ofs; - ofs = (ofs << 1) + 1; - if (ofs <= 0) // int overflow - ofs = maxOfs; - } - if (ofs > maxOfs) - ofs = maxOfs; - - // Make offsets relative to b - lastOfs += hint; - ofs += hint; - } - assert -1 <= lastOfs && lastOfs < ofs && ofs <= len; - - /* - * Now a[b + lastOfs] <= key < a[b + ofs], so key belongs somewhere to - * the right of lastOfs but no farther right than ofs. 
Do a binary - * search, with invariant a[b + lastOfs - 1] <= key < a[b + ofs]. - */ - lastOfs++; - while (lastOfs < ofs) { - int m = lastOfs + ((ofs - lastOfs) >>> 1); - - if (c.compare(key, s.getKey(a, base + m)) < 0) - ofs = m; // key < a[b + m] - else - lastOfs = m + 1; // a[b + m] <= key - } - assert lastOfs == ofs; // so a[b + ofs - 1] <= key < a[b + ofs] - return ofs; - } - - /** - * Merges two adjacent runs in place, in a stable fashion. The first - * element of the first run must be greater than the first element of the - * second run (a[base1] > a[base2]), and the last element of the first run - * (a[base1 + len1-1]) must be greater than all elements of the second run. - * - * For performance, this method should be called only when len1 <= len2; - * its twin, mergeHi should be called if len1 >= len2. (Either method - * may be called if len1 == len2.) - * - * @param base1 index of first element in first run to be merged - * @param len1 length of first run to be merged (must be > 0) - * @param base2 index of first element in second run to be merged - * (must be aBase + aLen) - * @param len2 length of second run to be merged (must be > 0) - */ - private void mergeLo(int base1, int len1, int base2, int len2) { - assert len1 > 0 && len2 > 0 && base1 + len1 == base2; - - // Copy first run into temp array - Buffer a = this.a; // For performance - Buffer tmp = ensureCapacity(len1); - s.copyRange(a, base1, tmp, 0, len1); - - int cursor1 = 0; // Indexes into tmp array - int cursor2 = base2; // Indexes int a - int dest = base1; // Indexes int a - - // Move first element of second run and deal with degenerate cases - s.copyElement(a, cursor2++, a, dest++); - if (--len2 == 0) { - s.copyRange(tmp, cursor1, a, dest, len1); - return; - } - if (len1 == 1) { - s.copyRange(a, cursor2, a, dest, len2); - s.copyElement(tmp, cursor1, a, dest + len2); // Last elt of run 1 to end of merge - return; - } - - Comparator c = this.c; // Use local variable for performance - int minGallop = this.minGallop; // " " " " " - outer: - while (true) { - int count1 = 0; // Number of times in a row that first run won - int count2 = 0; // Number of times in a row that second run won - - /* - * Do the straightforward thing until (if ever) one run starts - * winning consistently. - */ - do { - assert len1 > 1 && len2 > 0; - if (c.compare(s.getKey(a, cursor2), s.getKey(tmp, cursor1)) < 0) { - s.copyElement(a, cursor2++, a, dest++); - count2++; - count1 = 0; - if (--len2 == 0) - break outer; - } else { - s.copyElement(tmp, cursor1++, a, dest++); - count1++; - count2 = 0; - if (--len1 == 1) - break outer; - } - } while ((count1 | count2) < minGallop); - - /* - * One run is winning so consistently that galloping may be a - * huge win. So try that, and continue galloping until (if ever) - * neither run appears to be winning consistently anymore. 
- */ - do { - assert len1 > 1 && len2 > 0; - count1 = gallopRight(s.getKey(a, cursor2), tmp, cursor1, len1, 0, c); - if (count1 != 0) { - s.copyRange(tmp, cursor1, a, dest, count1); - dest += count1; - cursor1 += count1; - len1 -= count1; - if (len1 <= 1) // len1 == 1 || len1 == 0 - break outer; - } - s.copyElement(a, cursor2++, a, dest++); - if (--len2 == 0) - break outer; - - count2 = gallopLeft(s.getKey(tmp, cursor1), a, cursor2, len2, 0, c); - if (count2 != 0) { - s.copyRange(a, cursor2, a, dest, count2); - dest += count2; - cursor2 += count2; - len2 -= count2; - if (len2 == 0) - break outer; - } - s.copyElement(tmp, cursor1++, a, dest++); - if (--len1 == 1) - break outer; - minGallop--; - } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP); - if (minGallop < 0) - minGallop = 0; - minGallop += 2; // Penalize for leaving gallop mode - } // End of "outer" loop - this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field - - if (len1 == 1) { - assert len2 > 0; - s.copyRange(a, cursor2, a, dest, len2); - s.copyElement(tmp, cursor1, a, dest + len2); // Last elt of run 1 to end of merge - } else if (len1 == 0) { - throw new IllegalArgumentException( - "Comparison method violates its general contract!"); - } else { - assert len2 == 0; - assert len1 > 1; - s.copyRange(tmp, cursor1, a, dest, len1); - } - } - - /** - * Like mergeLo, except that this method should be called only if - * len1 >= len2; mergeLo should be called if len1 <= len2. (Either method - * may be called if len1 == len2.) - * - * @param base1 index of first element in first run to be merged - * @param len1 length of first run to be merged (must be > 0) - * @param base2 index of first element in second run to be merged - * (must be aBase + aLen) - * @param len2 length of second run to be merged (must be > 0) - */ - private void mergeHi(int base1, int len1, int base2, int len2) { - assert len1 > 0 && len2 > 0 && base1 + len1 == base2; - - // Copy second run into temp array - Buffer a = this.a; // For performance - Buffer tmp = ensureCapacity(len2); - s.copyRange(a, base2, tmp, 0, len2); - - int cursor1 = base1 + len1 - 1; // Indexes into a - int cursor2 = len2 - 1; // Indexes into tmp array - int dest = base2 + len2 - 1; // Indexes into a - - // Move last element of first run and deal with degenerate cases - s.copyElement(a, cursor1--, a, dest--); - if (--len1 == 0) { - s.copyRange(tmp, 0, a, dest - (len2 - 1), len2); - return; - } - if (len2 == 1) { - dest -= len1; - cursor1 -= len1; - s.copyRange(a, cursor1 + 1, a, dest + 1, len1); - s.copyElement(tmp, cursor2, a, dest); - return; - } - - Comparator c = this.c; // Use local variable for performance - int minGallop = this.minGallop; // " " " " " - outer: - while (true) { - int count1 = 0; // Number of times in a row that first run won - int count2 = 0; // Number of times in a row that second run won - - /* - * Do the straightforward thing until (if ever) one run - * appears to win consistently. - */ - do { - assert len1 > 0 && len2 > 1; - if (c.compare(s.getKey(tmp, cursor2), s.getKey(a, cursor1)) < 0) { - s.copyElement(a, cursor1--, a, dest--); - count1++; - count2 = 0; - if (--len1 == 0) - break outer; - } else { - s.copyElement(tmp, cursor2--, a, dest--); - count2++; - count1 = 0; - if (--len2 == 1) - break outer; - } - } while ((count1 | count2) < minGallop); - - /* - * One run is winning so consistently that galloping may be a - * huge win. So try that, and continue galloping until (if ever) - * neither run appears to be winning consistently anymore. 
- */ - do { - assert len1 > 0 && len2 > 1; - count1 = len1 - gallopRight(s.getKey(tmp, cursor2), a, base1, len1, len1 - 1, c); - if (count1 != 0) { - dest -= count1; - cursor1 -= count1; - len1 -= count1; - s.copyRange(a, cursor1 + 1, a, dest + 1, count1); - if (len1 == 0) - break outer; - } - s.copyElement(tmp, cursor2--, a, dest--); - if (--len2 == 1) - break outer; - - count2 = len2 - gallopLeft(s.getKey(a, cursor1), tmp, 0, len2, len2 - 1, c); - if (count2 != 0) { - dest -= count2; - cursor2 -= count2; - len2 -= count2; - s.copyRange(tmp, cursor2 + 1, a, dest + 1, count2); - if (len2 <= 1) // len2 == 1 || len2 == 0 - break outer; - } - s.copyElement(a, cursor1--, a, dest--); - if (--len1 == 0) - break outer; - minGallop--; - } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP); - if (minGallop < 0) - minGallop = 0; - minGallop += 2; // Penalize for leaving gallop mode - } // End of "outer" loop - this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field - - if (len2 == 1) { - assert len1 > 0; - dest -= len1; - cursor1 -= len1; - s.copyRange(a, cursor1 + 1, a, dest + 1, len1); - s.copyElement(tmp, cursor2, a, dest); // Move first elt of run2 to front of merge - } else if (len2 == 0) { - throw new IllegalArgumentException( - "Comparison method violates its general contract!"); - } else { - assert len1 == 0; - assert len2 > 0; - s.copyRange(tmp, 0, a, dest - (len2 - 1), len2); - } - } - - /** - * Ensures that the external array tmp has at least the specified - * number of elements, increasing its size if necessary. The size - * increases exponentially to ensure amortized linear time complexity. - * - * @param minCapacity the minimum required capacity of the tmp array - * @return tmp, whether or not it grew - */ - private Buffer ensureCapacity(int minCapacity) { - if (tmpLength < minCapacity) { - // Compute smallest power of 2 > minCapacity - int newSize = minCapacity; - newSize |= newSize >> 1; - newSize |= newSize >> 2; - newSize |= newSize >> 4; - newSize |= newSize >> 8; - newSize |= newSize >> 16; - newSize++; - - if (newSize < 0) // Not bloody likely! - newSize = minCapacity; - else - newSize = Math.min(newSize, aLength >>> 1); - - tmp = s.allocate(newSize); - tmpLength = newSize; - } - return tmp; - } - } -} diff --git a/core/src/main/java/org/apache/spark/util/collection/TimSort.java b/core/src/main/java/org/apache/spark/util/collection/TimSort.java new file mode 100644 index 0000000000000..409e1a41c5d49 --- /dev/null +++ b/core/src/main/java/org/apache/spark/util/collection/TimSort.java @@ -0,0 +1,940 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util.collection; + +import java.util.Comparator; + +/** + * A port of the Android TimSort class, which utilizes a "stable, adaptive, iterative mergesort." + * See the method comment on sort() for more details. + * + * This has been kept in Java with the original style in order to match very closely with the + * Android source code, and thus be easy to verify correctness. The class is package private. We put + * a simple Scala wrapper {@link org.apache.spark.util.collection.Sorter}, which is available to + * package org.apache.spark. + * + * The purpose of the port is to generalize the interface to the sort to accept input data formats + * besides simple arrays where every element is sorted individually. For instance, the AppendOnlyMap + * uses this to sort an Array with alternating elements of the form [key, value, key, value]. + * This generalization comes with minimal overhead -- see SortDataFormat for more information. + * + * We allow key reuse to prevent creating many key objects -- see SortDataFormat. + * + * @see org.apache.spark.util.collection.SortDataFormat + * @see org.apache.spark.util.collection.Sorter + */ +class TimSort { + + /** + * This is the minimum sized sequence that will be merged. Shorter + * sequences will be lengthened by calling binarySort. If the entire + * array is less than this length, no merges will be performed. + * + * This constant should be a power of two. It was 64 in Tim Peter's C + * implementation, but 32 was empirically determined to work better in + * this implementation. In the unlikely event that you set this constant + * to be a number that's not a power of two, you'll need to change the + * minRunLength computation. + * + * If you decrease this constant, you must change the stackLen + * computation in the TimSort constructor, or you risk an + * ArrayOutOfBounds exception. See listsort.txt for a discussion + * of the minimum stack length required as a function of the length + * of the array being sorted and the minimum merge sequence length. + */ + private static final int MIN_MERGE = 32; + + private final SortDataFormat s; + + public TimSort(SortDataFormat sortDataFormat) { + this.s = sortDataFormat; + } + + /** + * A stable, adaptive, iterative mergesort that requires far fewer than + * n lg(n) comparisons when running on partially sorted arrays, while + * offering performance comparable to a traditional mergesort when run + * on random arrays. Like all proper mergesorts, this sort is stable and + * runs O(n log n) time (worst case). In the worst case, this sort requires + * temporary storage space for n/2 object references; in the best case, + * it requires only a small constant amount of space. + * + * This implementation was adapted from Tim Peters's list sort for + * Python, which is described in detail here: + * + * http://svn.python.org/projects/python/trunk/Objects/listsort.txt + * + * Tim's C code may be found here: + * + * http://svn.python.org/projects/python/trunk/Objects/listobject.c + * + * The underlying techniques are described in this paper (and may have + * even earlier origins): + * + * "Optimistic Sorting and Information Theoretic Complexity" + * Peter McIlroy + * SODA (Fourth Annual ACM-SIAM Symposium on Discrete Algorithms), + * pp 467-474, Austin, Texas, 25-27 January 1993. 
+ * + * While the API to this class consists solely of static methods, it is + * (privately) instantiable; a TimSort instance holds the state of an ongoing + * sort, assuming the input array is large enough to warrant the full-blown + * TimSort. Small arrays are sorted in place, using a binary insertion sort. + * + * @author Josh Bloch + */ + public void sort(Buffer a, int lo, int hi, Comparator c) { + assert c != null; + + int nRemaining = hi - lo; + if (nRemaining < 2) + return; // Arrays of size 0 and 1 are always sorted + + // If array is small, do a "mini-TimSort" with no merges + if (nRemaining < MIN_MERGE) { + int initRunLen = countRunAndMakeAscending(a, lo, hi, c); + binarySort(a, lo, hi, lo + initRunLen, c); + return; + } + + /** + * March over the array once, left to right, finding natural runs, + * extending short natural runs to minRun elements, and merging runs + * to maintain stack invariant. + */ + SortState sortState = new SortState(a, c, hi - lo); + int minRun = minRunLength(nRemaining); + do { + // Identify next run + int runLen = countRunAndMakeAscending(a, lo, hi, c); + + // If run is short, extend to min(minRun, nRemaining) + if (runLen < minRun) { + int force = nRemaining <= minRun ? nRemaining : minRun; + binarySort(a, lo, lo + force, lo + runLen, c); + runLen = force; + } + + // Push run onto pending-run stack, and maybe merge + sortState.pushRun(lo, runLen); + sortState.mergeCollapse(); + + // Advance to find next run + lo += runLen; + nRemaining -= runLen; + } while (nRemaining != 0); + + // Merge all remaining runs to complete sort + assert lo == hi; + sortState.mergeForceCollapse(); + assert sortState.stackSize == 1; + } + + /** + * Sorts the specified portion of the specified array using a binary + * insertion sort. This is the best method for sorting small numbers + * of elements. It requires O(n log n) compares, but O(n^2) data + * movement (worst case). + * + * If the initial part of the specified range is already sorted, + * this method can take advantage of it: the method assumes that the + * elements from index {@code lo}, inclusive, to {@code start}, + * exclusive are already sorted. + * + * @param a the array in which a range is to be sorted + * @param lo the index of the first element in the range to be sorted + * @param hi the index after the last element in the range to be sorted + * @param start the index of the first element in the range that is + * not already known to be sorted ({@code lo <= start <= hi}) + * @param c comparator to used for the sort + */ + @SuppressWarnings("fallthrough") + private void binarySort(Buffer a, int lo, int hi, int start, Comparator c) { + assert lo <= start && start <= hi; + if (start == lo) + start++; + + K key0 = s.newKey(); + K key1 = s.newKey(); + + Buffer pivotStore = s.allocate(1); + for ( ; start < hi; start++) { + s.copyElement(a, start, pivotStore, 0); + K pivot = s.getKey(pivotStore, 0, key0); + + // Set left (and right) to the index where a[start] (pivot) belongs + int left = lo; + int right = start; + assert left <= right; + /* + * Invariants: + * pivot >= all in [lo, left). + * pivot < all in [right, start). + */ + while (left < right) { + int mid = (left + right) >>> 1; + if (c.compare(pivot, s.getKey(a, mid, key1)) < 0) + right = mid; + else + left = mid + 1; + } + assert left == right; + + /* + * The invariants still hold: pivot >= all in [lo, left) and + * pivot < all in [left, start), so pivot belongs at left. 
Note + * that if there are elements equal to pivot, left points to the + * first slot after them -- that's why this sort is stable. + * Slide elements over to make room for pivot. + */ + int n = start - left; // The number of elements to move + // Switch is just an optimization for arraycopy in default case + switch (n) { + case 2: s.copyElement(a, left + 1, a, left + 2); + case 1: s.copyElement(a, left, a, left + 1); + break; + default: s.copyRange(a, left, a, left + 1, n); + } + s.copyElement(pivotStore, 0, a, left); + } + } + + /** + * Returns the length of the run beginning at the specified position in + * the specified array and reverses the run if it is descending (ensuring + * that the run will always be ascending when the method returns). + * + * A run is the longest ascending sequence with: + * + * a[lo] <= a[lo + 1] <= a[lo + 2] <= ... + * + * or the longest descending sequence with: + * + * a[lo] > a[lo + 1] > a[lo + 2] > ... + * + * For its intended use in a stable mergesort, the strictness of the + * definition of "descending" is needed so that the call can safely + * reverse a descending sequence without violating stability. + * + * @param a the array in which a run is to be counted and possibly reversed + * @param lo index of the first element in the run + * @param hi index after the last element that may be contained in the run. + It is required that {@code lo < hi}. + * @param c the comparator to used for the sort + * @return the length of the run beginning at the specified position in + * the specified array + */ + private int countRunAndMakeAscending(Buffer a, int lo, int hi, Comparator c) { + assert lo < hi; + int runHi = lo + 1; + if (runHi == hi) + return 1; + + K key0 = s.newKey(); + K key1 = s.newKey(); + + // Find end of run, and reverse range if descending + if (c.compare(s.getKey(a, runHi++, key0), s.getKey(a, lo, key1)) < 0) { // Descending + while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0) + runHi++; + reverseRange(a, lo, runHi); + } else { // Ascending + while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0) + runHi++; + } + + return runHi - lo; + } + + /** + * Reverse the specified range of the specified array. + * + * @param a the array in which a range is to be reversed + * @param lo the index of the first element in the range to be reversed + * @param hi the index after the last element in the range to be reversed + */ + private void reverseRange(Buffer a, int lo, int hi) { + hi--; + while (lo < hi) { + s.swap(a, lo, hi); + lo++; + hi--; + } + } + + /** + * Returns the minimum acceptable run length for an array of the specified + * length. Natural runs shorter than this will be extended with + * {@link #binarySort}. + * + * Roughly speaking, the computation is: + * + * If n < MIN_MERGE, return n (it's too small to bother with fancy stuff). + * Else if n is an exact power of 2, return MIN_MERGE/2. + * Else return an int k, MIN_MERGE/2 <= k <= MIN_MERGE, such that n/k + * is close to, but strictly less than, an exact power of 2. + * + * For the rationale, see listsort.txt. + * + * @param n the length of the array to be sorted + * @return the length of the minimum run to be merged + */ + private int minRunLength(int n) { + assert n >= 0; + int r = 0; // Becomes 1 if any 1 bits are shifted off + while (n >= MIN_MERGE) { + r |= (n & 1); + n >>= 1; + } + return n + r; + } + + private class SortState { + + /** + * The Buffer being sorted. 
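As a worked illustration of the minRunLength computation defined above (with MIN_MERGE = 32): for n = 10000 the surviving high bits are 19 and at least one shifted-off bit is 1, so the minimum run length is 20, giving at most 500 runs, just under 2^9. The arithmetic can be checked in isolation:

    public class MinRunSketch {
      private static final int MIN_MERGE = 32;

      // Mirrors the computation in minRunLength: keep the top bits of n,
      // adding 1 if any lower bit shifted off was set.
      static int minRunLength(int n) {
        int r = 0;
        while (n >= MIN_MERGE) {
          r |= (n & 1);
          n >>= 1;
        }
        return n + r;
      }

      public static void main(String[] args) {
        System.out.println(minRunLength(10000)); // 20: 10000 / 20 = 500 runs, just under 512 = 2^9
        System.out.println(minRunLength(31));    // 31: below MIN_MERGE, returned unchanged
        System.out.println(minRunLength(1024));  // 16: exact power of two gives MIN_MERGE / 2
      }
    }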
+ */ + private final Buffer a; + + /** + * Length of the sort Buffer. + */ + private final int aLength; + + /** + * The comparator for this sort. + */ + private final Comparator c; + + /** + * When we get into galloping mode, we stay there until both runs win less + * often than MIN_GALLOP consecutive times. + */ + private static final int MIN_GALLOP = 7; + + /** + * This controls when we get *into* galloping mode. It is initialized + * to MIN_GALLOP. The mergeLo and mergeHi methods nudge it higher for + * random data, and lower for highly structured data. + */ + private int minGallop = MIN_GALLOP; + + /** + * Maximum initial size of tmp array, which is used for merging. The array + * can grow to accommodate demand. + * + * Unlike Tim's original C version, we do not allocate this much storage + * when sorting smaller arrays. This change was required for performance. + */ + private static final int INITIAL_TMP_STORAGE_LENGTH = 256; + + /** + * Temp storage for merges. + */ + private Buffer tmp; // Actual runtime type will be Object[], regardless of T + + /** + * Length of the temp storage. + */ + private int tmpLength = 0; + + /** + * A stack of pending runs yet to be merged. Run i starts at + * address base[i] and extends for len[i] elements. It's always + * true (so long as the indices are in bounds) that: + * + * runBase[i] + runLen[i] == runBase[i + 1] + * + * so we could cut the storage for this, but it's a minor amount, + * and keeping all the info explicit simplifies the code. + */ + private int stackSize = 0; // Number of pending runs on stack + private final int[] runBase; + private final int[] runLen; + + /** + * Creates a TimSort instance to maintain the state of an ongoing sort. + * + * @param a the array to be sorted + * @param c the comparator to determine the order of the sort + */ + private SortState(Buffer a, Comparator c, int len) { + this.aLength = len; + this.a = a; + this.c = c; + + // Allocate temp storage (which may be increased later if necessary) + tmpLength = len < 2 * INITIAL_TMP_STORAGE_LENGTH ? len >>> 1 : INITIAL_TMP_STORAGE_LENGTH; + tmp = s.allocate(tmpLength); + + /* + * Allocate runs-to-be-merged stack (which cannot be expanded). The + * stack length requirements are described in listsort.txt. The C + * version always uses the same stack length (85), but this was + * measured to be too expensive when sorting "mid-sized" arrays (e.g., + * 100 elements) in Java. Therefore, we use smaller (but sufficiently + * large) stack lengths for smaller arrays. The "magic numbers" in the + * computation below must be changed if MIN_MERGE is decreased. See + * the MIN_MERGE declaration above for more information. + */ + int stackLen = (len < 120 ? 5 : + len < 1542 ? 10 : + len < 119151 ? 19 : 40); + runBase = new int[stackLen]; + runLen = new int[stackLen]; + } + + /** + * Pushes the specified run onto the pending-run stack. + * + * @param runBase index of the first element in the run + * @param runLen the number of elements in the run + */ + private void pushRun(int runBase, int runLen) { + this.runBase[stackSize] = runBase; + this.runLen[stackSize] = runLen; + stackSize++; + } + + /** + * Examines the stack of runs waiting to be merged and merges adjacent runs + * until the stack invariants are reestablished: + * + * 1. runLen[i - 3] > runLen[i - 2] + runLen[i - 1] + * 2. runLen[i - 2] > runLen[i - 1] + * + * This method is called each time a new run is pushed onto the stack, + * so the invariants are guaranteed to hold for i < stackSize upon + * entry to the method. 
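The two invariants above keep the stacked run lengths growing at least as fast as Fibonacci numbers from the top of the stack downward, so the stack depth stays logarithmic in the array length. The following self-contained simulation applies the same collapse rule to run lengths alone (merging is modelled as summing two adjacent lengths), which makes it easy to watch when merges fire:

    import java.util.ArrayList;
    import java.util.List;

    public class MergeCollapseSketch {
      // Simulates the collapse rule on run lengths only; index size-1 is the top of the stack.
      static void pushAndCollapse(List<Integer> runLen, int newRun) {
        runLen.add(newRun);
        while (runLen.size() > 1) {
          int n = runLen.size() - 2;
          if (n > 0 && runLen.get(n - 1) <= runLen.get(n) + runLen.get(n + 1)) {
            if (runLen.get(n - 1) < runLen.get(n + 1)) {
              n--;
            }
            mergeAt(runLen, n);
          } else if (runLen.get(n) <= runLen.get(n + 1)) {
            mergeAt(runLen, n);
          } else {
            break; // both invariants hold again
          }
        }
      }

      // "Merge" runs i and i+1 by replacing them with their combined length.
      static void mergeAt(List<Integer> runLen, int i) {
        runLen.set(i, runLen.get(i) + runLen.get(i + 1));
        runLen.remove(i + 1);
      }

      public static void main(String[] args) {
        List<Integer> runLen = new ArrayList<>();
        for (int len : new int[] {20, 25, 30, 120, 40}) {
          pushAndCollapse(runLen, len);
          System.out.println("after pushing " + len + ": " + runLen);
        }
      }
    }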
+ */ + private void mergeCollapse() { + while (stackSize > 1) { + int n = stackSize - 2; + if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) { + if (runLen[n - 1] < runLen[n + 1]) + n--; + mergeAt(n); + } else if (runLen[n] <= runLen[n + 1]) { + mergeAt(n); + } else { + break; // Invariant is established + } + } + } + + /** + * Merges all runs on the stack until only one remains. This method is + * called once, to complete the sort. + */ + private void mergeForceCollapse() { + while (stackSize > 1) { + int n = stackSize - 2; + if (n > 0 && runLen[n - 1] < runLen[n + 1]) + n--; + mergeAt(n); + } + } + + /** + * Merges the two runs at stack indices i and i+1. Run i must be + * the penultimate or antepenultimate run on the stack. In other words, + * i must be equal to stackSize-2 or stackSize-3. + * + * @param i stack index of the first of the two runs to merge + */ + private void mergeAt(int i) { + assert stackSize >= 2; + assert i >= 0; + assert i == stackSize - 2 || i == stackSize - 3; + + int base1 = runBase[i]; + int len1 = runLen[i]; + int base2 = runBase[i + 1]; + int len2 = runLen[i + 1]; + assert len1 > 0 && len2 > 0; + assert base1 + len1 == base2; + + /* + * Record the length of the combined runs; if i is the 3rd-last + * run now, also slide over the last run (which isn't involved + * in this merge). The current run (i+1) goes away in any case. + */ + runLen[i] = len1 + len2; + if (i == stackSize - 3) { + runBase[i + 1] = runBase[i + 2]; + runLen[i + 1] = runLen[i + 2]; + } + stackSize--; + + K key0 = s.newKey(); + + /* + * Find where the first element of run2 goes in run1. Prior elements + * in run1 can be ignored (because they're already in place). + */ + int k = gallopRight(s.getKey(a, base2, key0), a, base1, len1, 0, c); + assert k >= 0; + base1 += k; + len1 -= k; + if (len1 == 0) + return; + + /* + * Find where the last element of run1 goes in run2. Subsequent elements + * in run2 can be ignored (because they're already in place). + */ + len2 = gallopLeft(s.getKey(a, base1 + len1 - 1, key0), a, base2, len2, len2 - 1, c); + assert len2 >= 0; + if (len2 == 0) + return; + + // Merge remaining runs, using tmp array with min(len1, len2) elements + if (len1 <= len2) + mergeLo(base1, len1, base2, len2); + else + mergeHi(base1, len1, base2, len2); + } + + /** + * Locates the position at which to insert the specified key into the + * specified sorted range; if the range contains an element equal to key, + * returns the index of the leftmost equal element. + * + * @param key the key whose insertion point to search for + * @param a the array in which to search + * @param base the index of the first element in the range + * @param len the length of the range; must be > 0 + * @param hint the index at which to begin the search, 0 <= hint < n. + * The closer hint is to the result, the faster this method will run. + * @param c the comparator used to order the range, and to search + * @return the int k, 0 <= k <= n such that a[b + k - 1] < key <= a[b + k], + * pretending that a[b - 1] is minus infinity and a[b + n] is infinity. + * In other words, key belongs at index b + k; or in other words, + * the first k elements of a should precede key, and the last n - k + * should follow it. 
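gallopLeft amounts to an exponential ("galloping") probe for the leftmost insertion point, followed by a binary search over the narrowed window; gallopRight is the mirror image for the rightmost position. The sketch below shows the same idea on a plain int[] without the hint parameter, to make the returned index concrete:

    public class GallopSketch {
      // Returns the smallest k such that a[k] >= key (the leftmost insertion point),
      // by doubling the probe index and then binary searching the narrowed window.
      static int gallopLeft(int[] a, int key) {
        int lo = 0;
        int hi = 1;
        // Gallop: double the probe index until a[hi] >= key or we run off the end.
        while (hi < a.length && a[hi] < key) {
          lo = hi;
          hi = hi * 2;
        }
        hi = Math.min(hi, a.length);
        // Binary search in [lo, hi) for the leftmost index whose element is >= key.
        while (lo < hi) {
          int mid = (lo + hi) >>> 1;
          if (a[mid] < key) lo = mid + 1; else hi = mid;
        }
        return lo;
      }

      public static void main(String[] args) {
        int[] a = {1, 2, 2, 2, 5, 8, 13, 21, 34};
        System.out.println(gallopLeft(a, 2));  // 1: leftmost slot whose element is >= 2
        System.out.println(gallopLeft(a, 3));  // 4: 3 would be inserted before the 5
        System.out.println(gallopLeft(a, 50)); // 9: past the end
      }
    }

The real method additionally starts probing from a caller-supplied hint, which is what makes galloping pay off when one run keeps winning the merge.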
+ */ + private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator c) { + assert len > 0 && hint >= 0 && hint < len; + int lastOfs = 0; + int ofs = 1; + K key0 = s.newKey(); + + if (c.compare(key, s.getKey(a, base + hint, key0)) > 0) { + // Gallop right until a[base+hint+lastOfs] < key <= a[base+hint+ofs] + int maxOfs = len - hint; + while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key0)) > 0) { + lastOfs = ofs; + ofs = (ofs << 1) + 1; + if (ofs <= 0) // int overflow + ofs = maxOfs; + } + if (ofs > maxOfs) + ofs = maxOfs; + + // Make offsets relative to base + lastOfs += hint; + ofs += hint; + } else { // key <= a[base + hint] + // Gallop left until a[base+hint-ofs] < key <= a[base+hint-lastOfs] + final int maxOfs = hint + 1; + while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key0)) <= 0) { + lastOfs = ofs; + ofs = (ofs << 1) + 1; + if (ofs <= 0) // int overflow + ofs = maxOfs; + } + if (ofs > maxOfs) + ofs = maxOfs; + + // Make offsets relative to base + int tmp = lastOfs; + lastOfs = hint - ofs; + ofs = hint - tmp; + } + assert -1 <= lastOfs && lastOfs < ofs && ofs <= len; + + /* + * Now a[base+lastOfs] < key <= a[base+ofs], so key belongs somewhere + * to the right of lastOfs but no farther right than ofs. Do a binary + * search, with invariant a[base + lastOfs - 1] < key <= a[base + ofs]. + */ + lastOfs++; + while (lastOfs < ofs) { + int m = lastOfs + ((ofs - lastOfs) >>> 1); + + if (c.compare(key, s.getKey(a, base + m, key0)) > 0) + lastOfs = m + 1; // a[base + m] < key + else + ofs = m; // key <= a[base + m] + } + assert lastOfs == ofs; // so a[base + ofs - 1] < key <= a[base + ofs] + return ofs; + } + + /** + * Like gallopLeft, except that if the range contains an element equal to + * key, gallopRight returns the index after the rightmost equal element. + * + * @param key the key whose insertion point to search for + * @param a the array in which to search + * @param base the index of the first element in the range + * @param len the length of the range; must be > 0 + * @param hint the index at which to begin the search, 0 <= hint < n. + * The closer hint is to the result, the faster this method will run. 
+ * @param c the comparator used to order the range, and to search + * @return the int k, 0 <= k <= n such that a[b + k - 1] <= key < a[b + k] + */ + private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator c) { + assert len > 0 && hint >= 0 && hint < len; + + int ofs = 1; + int lastOfs = 0; + K key1 = s.newKey(); + + if (c.compare(key, s.getKey(a, base + hint, key1)) < 0) { + // Gallop left until a[b+hint - ofs] <= key < a[b+hint - lastOfs] + int maxOfs = hint + 1; + while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key1)) < 0) { + lastOfs = ofs; + ofs = (ofs << 1) + 1; + if (ofs <= 0) // int overflow + ofs = maxOfs; + } + if (ofs > maxOfs) + ofs = maxOfs; + + // Make offsets relative to b + int tmp = lastOfs; + lastOfs = hint - ofs; + ofs = hint - tmp; + } else { // a[b + hint] <= key + // Gallop right until a[b+hint + lastOfs] <= key < a[b+hint + ofs] + int maxOfs = len - hint; + while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key1)) >= 0) { + lastOfs = ofs; + ofs = (ofs << 1) + 1; + if (ofs <= 0) // int overflow + ofs = maxOfs; + } + if (ofs > maxOfs) + ofs = maxOfs; + + // Make offsets relative to b + lastOfs += hint; + ofs += hint; + } + assert -1 <= lastOfs && lastOfs < ofs && ofs <= len; + + /* + * Now a[b + lastOfs] <= key < a[b + ofs], so key belongs somewhere to + * the right of lastOfs but no farther right than ofs. Do a binary + * search, with invariant a[b + lastOfs - 1] <= key < a[b + ofs]. + */ + lastOfs++; + while (lastOfs < ofs) { + int m = lastOfs + ((ofs - lastOfs) >>> 1); + + if (c.compare(key, s.getKey(a, base + m, key1)) < 0) + ofs = m; // key < a[b + m] + else + lastOfs = m + 1; // a[b + m] <= key + } + assert lastOfs == ofs; // so a[b + ofs - 1] <= key < a[b + ofs] + return ofs; + } + + /** + * Merges two adjacent runs in place, in a stable fashion. The first + * element of the first run must be greater than the first element of the + * second run (a[base1] > a[base2]), and the last element of the first run + * (a[base1 + len1-1]) must be greater than all elements of the second run. + * + * For performance, this method should be called only when len1 <= len2; + * its twin, mergeHi should be called if len1 >= len2. (Either method + * may be called if len1 == len2.) 
+ * + * @param base1 index of first element in first run to be merged + * @param len1 length of first run to be merged (must be > 0) + * @param base2 index of first element in second run to be merged + * (must be aBase + aLen) + * @param len2 length of second run to be merged (must be > 0) + */ + private void mergeLo(int base1, int len1, int base2, int len2) { + assert len1 > 0 && len2 > 0 && base1 + len1 == base2; + + // Copy first run into temp array + Buffer a = this.a; // For performance + Buffer tmp = ensureCapacity(len1); + s.copyRange(a, base1, tmp, 0, len1); + + int cursor1 = 0; // Indexes into tmp array + int cursor2 = base2; // Indexes int a + int dest = base1; // Indexes int a + + // Move first element of second run and deal with degenerate cases + s.copyElement(a, cursor2++, a, dest++); + if (--len2 == 0) { + s.copyRange(tmp, cursor1, a, dest, len1); + return; + } + if (len1 == 1) { + s.copyRange(a, cursor2, a, dest, len2); + s.copyElement(tmp, cursor1, a, dest + len2); // Last elt of run 1 to end of merge + return; + } + + K key0 = s.newKey(); + K key1 = s.newKey(); + + Comparator c = this.c; // Use local variable for performance + int minGallop = this.minGallop; // " " " " " + outer: + while (true) { + int count1 = 0; // Number of times in a row that first run won + int count2 = 0; // Number of times in a row that second run won + + /* + * Do the straightforward thing until (if ever) one run starts + * winning consistently. + */ + do { + assert len1 > 1 && len2 > 0; + if (c.compare(s.getKey(a, cursor2, key0), s.getKey(tmp, cursor1, key1)) < 0) { + s.copyElement(a, cursor2++, a, dest++); + count2++; + count1 = 0; + if (--len2 == 0) + break outer; + } else { + s.copyElement(tmp, cursor1++, a, dest++); + count1++; + count2 = 0; + if (--len1 == 1) + break outer; + } + } while ((count1 | count2) < minGallop); + + /* + * One run is winning so consistently that galloping may be a + * huge win. So try that, and continue galloping until (if ever) + * neither run appears to be winning consistently anymore. + */ + do { + assert len1 > 1 && len2 > 0; + count1 = gallopRight(s.getKey(a, cursor2, key0), tmp, cursor1, len1, 0, c); + if (count1 != 0) { + s.copyRange(tmp, cursor1, a, dest, count1); + dest += count1; + cursor1 += count1; + len1 -= count1; + if (len1 <= 1) // len1 == 1 || len1 == 0 + break outer; + } + s.copyElement(a, cursor2++, a, dest++); + if (--len2 == 0) + break outer; + + count2 = gallopLeft(s.getKey(tmp, cursor1, key0), a, cursor2, len2, 0, c); + if (count2 != 0) { + s.copyRange(a, cursor2, a, dest, count2); + dest += count2; + cursor2 += count2; + len2 -= count2; + if (len2 == 0) + break outer; + } + s.copyElement(tmp, cursor1++, a, dest++); + if (--len1 == 1) + break outer; + minGallop--; + } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP); + if (minGallop < 0) + minGallop = 0; + minGallop += 2; // Penalize for leaving gallop mode + } // End of "outer" loop + this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field + + if (len1 == 1) { + assert len2 > 0; + s.copyRange(a, cursor2, a, dest, len2); + s.copyElement(tmp, cursor1, a, dest + len2); // Last elt of run 1 to end of merge + } else if (len1 == 0) { + throw new IllegalArgumentException( + "Comparison method violates its general contract!"); + } else { + assert len2 == 0; + assert len1 > 1; + s.copyRange(tmp, cursor1, a, dest, len1); + } + } + + /** + * Like mergeLo, except that this method should be called only if + * len1 >= len2; mergeLo should be called if len1 <= len2. 
(Either method + * may be called if len1 == len2.) + * + * @param base1 index of first element in first run to be merged + * @param len1 length of first run to be merged (must be > 0) + * @param base2 index of first element in second run to be merged + * (must be aBase + aLen) + * @param len2 length of second run to be merged (must be > 0) + */ + private void mergeHi(int base1, int len1, int base2, int len2) { + assert len1 > 0 && len2 > 0 && base1 + len1 == base2; + + // Copy second run into temp array + Buffer a = this.a; // For performance + Buffer tmp = ensureCapacity(len2); + s.copyRange(a, base2, tmp, 0, len2); + + int cursor1 = base1 + len1 - 1; // Indexes into a + int cursor2 = len2 - 1; // Indexes into tmp array + int dest = base2 + len2 - 1; // Indexes into a + + K key0 = s.newKey(); + K key1 = s.newKey(); + + // Move last element of first run and deal with degenerate cases + s.copyElement(a, cursor1--, a, dest--); + if (--len1 == 0) { + s.copyRange(tmp, 0, a, dest - (len2 - 1), len2); + return; + } + if (len2 == 1) { + dest -= len1; + cursor1 -= len1; + s.copyRange(a, cursor1 + 1, a, dest + 1, len1); + s.copyElement(tmp, cursor2, a, dest); + return; + } + + Comparator c = this.c; // Use local variable for performance + int minGallop = this.minGallop; // " " " " " + outer: + while (true) { + int count1 = 0; // Number of times in a row that first run won + int count2 = 0; // Number of times in a row that second run won + + /* + * Do the straightforward thing until (if ever) one run + * appears to win consistently. + */ + do { + assert len1 > 0 && len2 > 1; + if (c.compare(s.getKey(tmp, cursor2, key0), s.getKey(a, cursor1, key1)) < 0) { + s.copyElement(a, cursor1--, a, dest--); + count1++; + count2 = 0; + if (--len1 == 0) + break outer; + } else { + s.copyElement(tmp, cursor2--, a, dest--); + count2++; + count1 = 0; + if (--len2 == 1) + break outer; + } + } while ((count1 | count2) < minGallop); + + /* + * One run is winning so consistently that galloping may be a + * huge win. So try that, and continue galloping until (if ever) + * neither run appears to be winning consistently anymore. + */ + do { + assert len1 > 0 && len2 > 1; + count1 = len1 - gallopRight(s.getKey(tmp, cursor2, key0), a, base1, len1, len1 - 1, c); + if (count1 != 0) { + dest -= count1; + cursor1 -= count1; + len1 -= count1; + s.copyRange(a, cursor1 + 1, a, dest + 1, count1); + if (len1 == 0) + break outer; + } + s.copyElement(tmp, cursor2--, a, dest--); + if (--len2 == 1) + break outer; + + count2 = len2 - gallopLeft(s.getKey(a, cursor1, key0), tmp, 0, len2, len2 - 1, c); + if (count2 != 0) { + dest -= count2; + cursor2 -= count2; + len2 -= count2; + s.copyRange(tmp, cursor2 + 1, a, dest + 1, count2); + if (len2 <= 1) // len2 == 1 || len2 == 0 + break outer; + } + s.copyElement(a, cursor1--, a, dest--); + if (--len1 == 0) + break outer; + minGallop--; + } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP); + if (minGallop < 0) + minGallop = 0; + minGallop += 2; // Penalize for leaving gallop mode + } // End of "outer" loop + this.minGallop = minGallop < 1 ? 
1 : minGallop; // Write back to field + + if (len2 == 1) { + assert len1 > 0; + dest -= len1; + cursor1 -= len1; + s.copyRange(a, cursor1 + 1, a, dest + 1, len1); + s.copyElement(tmp, cursor2, a, dest); // Move first elt of run2 to front of merge + } else if (len2 == 0) { + throw new IllegalArgumentException( + "Comparison method violates its general contract!"); + } else { + assert len1 == 0; + assert len2 > 0; + s.copyRange(tmp, 0, a, dest - (len2 - 1), len2); + } + } + + /** + * Ensures that the external array tmp has at least the specified + * number of elements, increasing its size if necessary. The size + * increases exponentially to ensure amortized linear time complexity. + * + * @param minCapacity the minimum required capacity of the tmp array + * @return tmp, whether or not it grew + */ + private Buffer ensureCapacity(int minCapacity) { + if (tmpLength < minCapacity) { + // Compute smallest power of 2 > minCapacity + int newSize = minCapacity; + newSize |= newSize >> 1; + newSize |= newSize >> 2; + newSize |= newSize >> 4; + newSize |= newSize >> 8; + newSize |= newSize >> 16; + newSize++; + + if (newSize < 0) // Not bloody likely! + newSize = minCapacity; + else + newSize = Math.min(newSize, aLength >>> 1); + + tmp = s.allocate(newSize); + tmpLength = newSize; + } + return tmp; + } + } +} diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js new file mode 100644 index 0000000000000..14ba37d7c9bd9 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Register functions to show/hide columns based on checkboxes. These need + * to be registered after the page loads. */ +$(function() { + $("span.expand-additional-metrics").click(function(){ + // Expand the list of additional metrics. + var additionalMetricsDiv = $(this).parent().find('.additional-metrics'); + $(additionalMetricsDiv).toggleClass('collapsed'); + + // Switch the class of the arrow from open to closed. + $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-open'); + $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-closed'); + }); + + stripeSummaryTable(); + + $("input:checkbox").click(function() { + var column = "table ." + $(this).attr("name"); + $(column).toggle(); + stripeSummaryTable(); + }); + + $("#select-all-metrics").click(function() { + if (this.checked) { + // Toggle all un-checked options. + $('input:checkbox:not(:checked)').trigger('click'); + } else { + // Toggle all checked options. 
+ $('input:checkbox:checked').trigger('click'); + } + }); + + // Trigger a click on the checkbox if a user clicks the label next to it. + $("span.additional-metric-title").click(function() { + $(this).parent().find('input:checkbox').trigger('click'); + }); +}); diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js new file mode 100644 index 0000000000000..656147e40d13e --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/table.js @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* Adds background colors to stripe table rows in the summary table (on the stage page). This is + * necessary (instead of using css or the table striping provided by bootstrap) because the summary + * table has hidden rows. + * + * An ID selector (rather than a class selector) is used to ensure this runs quickly even on pages + * with thousands of task rows (ID selectors are much faster than class selectors). */ +function stripeSummaryTable() { + $("#task-summary-table").find("tr:not(:hidden)").each(function (index) { + if (index % 2 == 1) { + $(this).css("background-color", "#f9f9f9"); + } else { + $(this).css("background-color", "#ffffff"); + } + }); +} diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 445110d63e184..cdf85bfbf326f 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -51,6 +51,11 @@ table.sortable thead { cursor: pointer; } +table.sortable td { + word-wrap: break-word; + max-width: 600px; +} + .progress { margin-bottom: 0px; position: relative } @@ -115,7 +120,57 @@ pre { border: none; } +.stacktrace-details { + max-height: 300px; + overflow-y: auto; + margin: 0; + transition: max-height 0.5s ease-out, padding 0.5s ease-out; +} + +.stacktrace-details.collapsed { + max-height: 0; + padding-top: 0; + padding-bottom: 0; + border: none; +} + +span.expand-additional-metrics { + cursor: pointer; +} + +span.additional-metric-title { + cursor: pointer; +} + +.additional-metrics.collapsed { + display: none; +} + .tooltip { font-weight: normal; } +.arrow-open { + width: 0; + height: 0; + border-left: 5px solid transparent; + border-right: 5px solid transparent; + border-top: 5px solid black; + float: left; + margin-top: 6px; +} + +.arrow-closed { + width: 0; + height: 0; + border-top: 5px solid transparent; + border-bottom: 5px solid transparent; + border-left: 5px solid black; + display: inline-block; +} + +/* Hide all additional metrics by default. 
This is done here rather than using JavaScript to + * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */ +.scheduler_delay, .gc_time, .deserialization_time, .serialization_time, .getting_result_time { + display: none; +} diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 12f2fe031cb1d..000bbd6b532ad 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -18,12 +18,14 @@ package org.apache.spark import java.io.{ObjectInputStream, Serializable} +import java.util.concurrent.atomic.AtomicLong import scala.collection.generic.Growable import scala.collection.mutable.Map import scala.reflect.ClassTag import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.util.Utils /** * A data type that can be accumulated, ie has an commutative and associative "add" operation, @@ -126,7 +128,7 @@ class Accumulable[R, T] ( } // Called by Java when deserializing an object - private def readObject(in: ObjectInputStream) { + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { in.defaultReadObject() value_ = zero deserialized = true @@ -227,6 +229,7 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa */ class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String]) extends Accumulable[T,T](initialValue, param, name) { + def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None) } @@ -243,6 +246,36 @@ trait AccumulatorParam[T] extends AccumulableParam[T, T] { } } +object AccumulatorParam { + + // The following implicit objects were in SparkContext before 1.2 and users had to + // `import SparkContext._` to enable them. Now we move them here to make the compiler find + // them automatically. However, as there are duplicate codes in SparkContext for backward + // compatibility, please update them accordingly if you modify the following implicit objects. + + implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { + def addInPlace(t1: Double, t2: Double): Double = t1 + t2 + def zero(initialValue: Double) = 0.0 + } + + implicit object IntAccumulatorParam extends AccumulatorParam[Int] { + def addInPlace(t1: Int, t2: Int): Int = t1 + t2 + def zero(initialValue: Int) = 0 + } + + implicit object LongAccumulatorParam extends AccumulatorParam[Long] { + def addInPlace(t1: Long, t2: Long) = t1 + t2 + def zero(initialValue: Long) = 0L + } + + implicit object FloatAccumulatorParam extends AccumulatorParam[Float] { + def addInPlace(t1: Float, t2: Float) = t1 + t2 + def zero(initialValue: Float) = 0f + } + + // TODO: Add AccumulatorParams for other types, e.g. 
lists and strings +} + // TODO: The multi-thread support in accumulators is kind of lame; check // if there's a more intuitive way of doing it right private object Accumulators { @@ -251,7 +284,7 @@ private object Accumulators { val localAccums = Map[Thread, Map[Long, Accumulable[_, _]]]() var lastId: Long = 0 - def newId: Long = synchronized { + def newId(): Long = synchronized { lastId += 1 lastId } diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala index f8584b90cabe6..80da62c44edc5 100644 --- a/core/src/main/scala/org/apache/spark/CacheManager.scala +++ b/core/src/main/scala/org/apache/spark/CacheManager.scala @@ -61,7 +61,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { val computedValues = rdd.computeOrReadCheckpoint(partition, context) // If the task is running locally, do not persist the result - if (context.runningLocally) { + if (context.isRunningLocally) { return computedValues } @@ -168,8 +168,6 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging { arr.iterator.asInstanceOf[Iterator[T]] case Right(it) => // There is not enough space to cache this partition in memory - logWarning(s"Not enough space to cache partition $key in memory! " + - s"Free memory is ${blockManager.memoryStore.freeMemory} bytes.") val returnValues = it.asInstanceOf[Iterator[T]] if (putLevel.useDisk) { logWarning(s"Persisting partition $key to disk instead.") diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala new file mode 100644 index 0000000000000..88adb892998af --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -0,0 +1,516 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import scala.collection.mutable + +import org.apache.spark.scheduler._ + +/** + * An agent that dynamically allocates and removes executors based on the workload. + * + * The add policy depends on whether there are backlogged tasks waiting to be scheduled. If + * the scheduler queue is not drained in N seconds, then new executors are added. If the queue + * persists for another M seconds, then more executors are added and so on. The number added + * in each round increases exponentially from the previous round until an upper bound on the + * number of executors has been reached. The upper bound is based both on a configured property + * and on the number of tasks pending: the policy will never increase the number of executor + * requests past the number needed to handle all pending tasks. 
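The exponential add policy described here can be illustrated with a small standalone sketch. This is a toy model only; the object and variable names below are assumptions for illustration, not part of this patch or of Spark's API.

// Toy model of the add policy: the request size doubles after each fully granted
// round and is always capped by the configured maximum number of executors.
object RampUpSketch {
  def main(args: Array[String]): Unit = {
    val maxExecutors = 100
    var numExecutorsToAdd = 1
    var total = 0
    while (total < maxExecutors) {
      val granted = math.min(numExecutorsToAdd, maxExecutors - total)
      total += granted
      println(s"requested $granted executor(s), running total: $total")
      // Double only when the full request was granted; otherwise reset to 1
      numExecutorsToAdd = if (granted == numExecutorsToAdd) numExecutorsToAdd * 2 else 1
    }
  }
}

Running this prints requests of 1, 2, 4, 8, ... until the cap is hit, mirroring the ramp-up described in the class comment.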
+ * + * The rationale for the exponential increase is twofold: (1) Executors should be added slowly + * in the beginning in case the number of extra executors needed turns out to be small. Otherwise, + * we may add more executors than we need just to remove them later. (2) Executors should be added + * quickly over time in case the maximum number of executors is very high. Otherwise, it will take + * a long time to ramp up under heavy workloads. + * + * The remove policy is simpler: If an executor has been idle for K seconds, meaning it has not + * been scheduled to run any tasks, then it is removed. + * + * There is no retry logic in either case because we make the assumption that the cluster manager + * will eventually fulfill all requests it receives asynchronously. + * + * The relevant Spark properties include the following: + * + * spark.dynamicAllocation.enabled - Whether this feature is enabled + * spark.dynamicAllocation.minExecutors - Lower bound on the number of executors + * spark.dynamicAllocation.maxExecutors - Upper bound on the number of executors + * + * spark.dynamicAllocation.schedulerBacklogTimeout (M) - + * If there are backlogged tasks for this duration, add new executors + * + * spark.dynamicAllocation.sustainedSchedulerBacklogTimeout (N) - + * If the backlog is sustained for this duration, add more executors + * This is used only after the initial backlog timeout is exceeded + * + * spark.dynamicAllocation.executorIdleTimeout (K) - + * If an executor has been idle for this duration, remove it + */ +private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging { + import ExecutorAllocationManager._ + + private val conf = sc.conf + + // Lower and upper bounds on the number of executors. These are required. + private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", -1) + private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", -1) + + // How long there must be backlogged tasks for before an addition is triggered + private val schedulerBacklogTimeout = conf.getLong( + "spark.dynamicAllocation.schedulerBacklogTimeout", 60) + + // Same as above, but used only after `schedulerBacklogTimeout` is exceeded + private val sustainedSchedulerBacklogTimeout = conf.getLong( + "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", schedulerBacklogTimeout) + + // How long an executor must be idle for before it is removed + private val executorIdleTimeout = conf.getLong( + "spark.dynamicAllocation.executorIdleTimeout", 600) + + // During testing, the methods to actually kill and add executors are mocked out + private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false) + + // TODO: The default value of 1 for spark.executor.cores works right now because dynamic + // allocation is only supported for YARN and the default number of cores per executor in YARN is + // 1, but it might need to be attained differently for different cluster managers + private val tasksPerExecutor = + conf.getInt("spark.executor.cores", 1) / conf.getInt("spark.task.cpus", 1) + + validateSettings() + + // Number of executors to add in the next round + private var numExecutorsToAdd = 1 + + // Number of executors that have been requested but have not registered yet + private var numExecutorsPending = 0 + + // Executors that have been requested to be removed but have not been killed yet + private val executorsPendingToRemove = new mutable.HashSet[String] + + // All known executors + private val executorIds = new 
mutable.HashSet[String] + + // A timestamp of when an addition should be triggered, or NOT_SET if it is not set + // This is set when pending tasks are added but not scheduled yet + private var addTime: Long = NOT_SET + + // A timestamp for each executor of when the executor should be removed, indexed by the ID + // This is set when an executor is no longer running a task, or when it first registers + private val removeTimes = new mutable.HashMap[String, Long] + + // Polling loop interval (ms) + private val intervalMillis: Long = 100 + + // Clock used to schedule when executors should be added and removed + private var clock: Clock = new RealClock + + // Listener for Spark events that impact the allocation policy + private val listener = new ExecutorAllocationListener(this) + + /** + * Verify that the settings specified through the config are valid. + * If not, throw an appropriate exception. + */ + private def validateSettings(): Unit = { + if (minNumExecutors < 0 || maxNumExecutors < 0) { + throw new SparkException("spark.dynamicAllocation.{min/max}Executors must be set!") + } + if (minNumExecutors == 0 || maxNumExecutors == 0) { + throw new SparkException("spark.dynamicAllocation.{min/max}Executors cannot be 0!") + } + if (minNumExecutors > maxNumExecutors) { + throw new SparkException(s"spark.dynamicAllocation.minExecutors ($minNumExecutors) must " + + s"be less than or equal to spark.dynamicAllocation.maxExecutors ($maxNumExecutors)!") + } + if (schedulerBacklogTimeout <= 0) { + throw new SparkException("spark.dynamicAllocation.schedulerBacklogTimeout must be > 0!") + } + if (sustainedSchedulerBacklogTimeout <= 0) { + throw new SparkException( + "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout must be > 0!") + } + if (executorIdleTimeout <= 0) { + throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be > 0!") + } + // Require external shuffle service for dynamic allocation + // Otherwise, we may lose shuffle files when killing executors + if (!conf.getBoolean("spark.shuffle.service.enabled", false) && !testing) { + throw new SparkException("Dynamic allocation of executors requires the external " + + "shuffle service. You may enable this through spark.shuffle.service.enabled.") + } + if (tasksPerExecutor == 0) { + throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.cores") + } + } + + /** + * Use a different clock for this allocation manager. This is mainly used for testing. + */ + def setClock(newClock: Clock): Unit = { + clock = newClock + } + + /** + * Register for scheduler callbacks to decide when to add and remove executors. + */ + def start(): Unit = { + sc.addSparkListener(listener) + startPolling() + } + + /** + * Start the main polling thread that keeps track of when to add and remove executors. + */ + private def startPolling(): Unit = { + val t = new Thread { + override def run(): Unit = { + while (true) { + try { + schedule() + } catch { + case e: Exception => logError("Exception in dynamic executor allocation thread!", e) + } + Thread.sleep(intervalMillis) + } + } + } + t.setName("spark-dynamic-executor-allocation") + t.setDaemon(true) + t.start() + } + + /** + * If the add time has expired, request new executors and refresh the add time. + * If the remove time for an existing executor has expired, kill the executor. + * This is factored out into its own method for testing. 
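Taken together, the checks in validateSettings() above amount to a minimal configuration along these lines. This is a sketch only; the property values are arbitrary examples, not recommendations.

import org.apache.spark.SparkConf

// Sketch of a SparkConf that satisfies the validation above; values are examples only.
object DynamicAllocationConfSketch {
  def conf: SparkConf = new SparkConf()
    .set("spark.dynamicAllocation.enabled", "true")
    .set("spark.dynamicAllocation.minExecutors", "2")               // must be > 0
    .set("spark.dynamicAllocation.maxExecutors", "20")              // must be >= minExecutors
    .set("spark.dynamicAllocation.schedulerBacklogTimeout", "60")   // seconds, must be > 0
    .set("spark.shuffle.service.enabled", "true")                   // required so shuffle files outlive killed executors
}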
+ */ + private def schedule(): Unit = synchronized { + val now = clock.getTimeMillis + if (addTime != NOT_SET && now >= addTime) { + addExecutors() + logDebug(s"Starting timer to add more executors (to " + + s"expire in $sustainedSchedulerBacklogTimeout seconds)") + addTime += sustainedSchedulerBacklogTimeout * 1000 + } + + removeTimes.foreach { case (executorId, expireTime) => + if (now >= expireTime) { + removeExecutor(executorId) + removeTimes.remove(executorId) + } + } + } + + /** + * Request a number of executors from the cluster manager. + * If the cap on the number of executors is reached, give up and reset the + * number of executors to add next round instead of continuing to double it. + * Return the number actually requested. + */ + private def addExecutors(): Int = synchronized { + // Do not request more executors if we have already reached the upper bound + val numExistingExecutors = executorIds.size + numExecutorsPending + if (numExistingExecutors >= maxNumExecutors) { + logDebug(s"Not adding executors because there are already ${executorIds.size} " + + s"registered and $numExecutorsPending pending executor(s) (limit $maxNumExecutors)") + numExecutorsToAdd = 1 + return 0 + } + + // The number of executors needed to satisfy all pending tasks is the number of tasks pending + // divided by the number of tasks each executor can fit, rounded up. + val maxNumExecutorsPending = + (listener.totalPendingTasks() + tasksPerExecutor - 1) / tasksPerExecutor + if (numExecutorsPending >= maxNumExecutorsPending) { + logDebug(s"Not adding executors because there are already $numExecutorsPending " + + s"pending and pending tasks could only fill $maxNumExecutorsPending") + numExecutorsToAdd = 1 + return 0 + } + + // It's never useful to request more executors than could satisfy all the pending tasks, so + // cap request at that amount. + // Also cap request with respect to the configured upper bound. + val maxNumExecutorsToAdd = math.min( + maxNumExecutorsPending - numExecutorsPending, + maxNumExecutors - numExistingExecutors) + assert(maxNumExecutorsToAdd > 0) + + val actualNumExecutorsToAdd = math.min(numExecutorsToAdd, maxNumExecutorsToAdd) + + val newTotalExecutors = numExistingExecutors + actualNumExecutorsToAdd + val addRequestAcknowledged = testing || sc.requestExecutors(actualNumExecutorsToAdd) + if (addRequestAcknowledged) { + logInfo(s"Requesting $actualNumExecutorsToAdd new executor(s) because " + + s"tasks are backlogged (new desired total will be $newTotalExecutors)") + numExecutorsToAdd = + if (actualNumExecutorsToAdd == numExecutorsToAdd) numExecutorsToAdd * 2 else 1 + numExecutorsPending += actualNumExecutorsToAdd + actualNumExecutorsToAdd + } else { + logWarning(s"Unable to reach the cluster manager " + + s"to request $actualNumExecutorsToAdd executors!") + 0 + } + } + + /** + * Request the cluster manager to remove the given executor. + * Return whether the request is received. 
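The cap computed in addExecutors() above is a plain integer ceiling division of pending tasks by tasks per executor; a quick sketch with made-up numbers:

// Sketch of the ceiling division used to cap executor requests by pending tasks.
object PendingCapSketch {
  def maxExecutorsPending(pendingTasks: Int, tasksPerExecutor: Int): Int =
    (pendingTasks + tasksPerExecutor - 1) / tasksPerExecutor

  def main(args: Array[String]): Unit = {
    // 23 pending tasks at 4 tasks per executor never justify more than 6 pending executors
    println(maxExecutorsPending(23, 4))  // prints 6
  }
}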
+ */ + private def removeExecutor(executorId: String): Boolean = synchronized { + // Do not kill the executor if we are not aware of it (should never happen) + if (!executorIds.contains(executorId)) { + logWarning(s"Attempted to remove unknown executor $executorId!") + return false + } + + // Do not kill the executor again if it is already pending to be killed (should never happen) + if (executorsPendingToRemove.contains(executorId)) { + logWarning(s"Attempted to remove executor $executorId " + + s"when it is already pending to be removed!") + return false + } + + // Do not kill the executor if we have already reached the lower bound + val numExistingExecutors = executorIds.size - executorsPendingToRemove.size + if (numExistingExecutors - 1 < minNumExecutors) { + logInfo(s"Not removing idle executor $executorId because there are only " + + s"$numExistingExecutors executor(s) left (limit $minNumExecutors)") + return false + } + + // Send a request to the backend to kill this executor + val removeRequestAcknowledged = testing || sc.killExecutor(executorId) + if (removeRequestAcknowledged) { + logInfo(s"Removing executor $executorId because it has been idle for " + + s"$executorIdleTimeout seconds (new desired total will be ${numExistingExecutors - 1})") + executorsPendingToRemove.add(executorId) + true + } else { + logWarning(s"Unable to reach the cluster manager to kill executor $executorId!") + false + } + } + + /** + * Callback invoked when the specified executor has been added. + */ + private def onExecutorAdded(executorId: String): Unit = synchronized { + if (!executorIds.contains(executorId)) { + executorIds.add(executorId) + executorIds.foreach(onExecutorIdle) + logInfo(s"New executor $executorId has registered (new total is ${executorIds.size})") + if (numExecutorsPending > 0) { + numExecutorsPending -= 1 + logDebug(s"Decremented number of pending executors ($numExecutorsPending left)") + } + } else { + logWarning(s"Duplicate executor $executorId has registered") + } + } + + /** + * Callback invoked when the specified executor has been removed. + */ + private def onExecutorRemoved(executorId: String): Unit = synchronized { + if (executorIds.contains(executorId)) { + executorIds.remove(executorId) + removeTimes.remove(executorId) + logInfo(s"Existing executor $executorId has been removed (new total is ${executorIds.size})") + if (executorsPendingToRemove.contains(executorId)) { + executorsPendingToRemove.remove(executorId) + logDebug(s"Executor $executorId is no longer pending to " + + s"be removed (${executorsPendingToRemove.size} left)") + } + } else { + logWarning(s"Unknown executor $executorId has been removed!") + } + } + + /** + * Callback invoked when the scheduler receives new pending tasks. + * This sets a time in the future that decides when executors should be added + * if it is not already set. + */ + private def onSchedulerBacklogged(): Unit = synchronized { + if (addTime == NOT_SET) { + logDebug(s"Starting timer to add executors because pending tasks " + + s"are building up (to expire in $schedulerBacklogTimeout seconds)") + addTime = clock.getTimeMillis + schedulerBacklogTimeout * 1000 + } + } + + /** + * Callback invoked when the scheduler queue is drained. + * This resets all variables used for adding executors. 
+ */ + private def onSchedulerQueueEmpty(): Unit = synchronized { + logDebug(s"Clearing timer to add executors because there are no more pending tasks") + addTime = NOT_SET + numExecutorsToAdd = 1 + } + + /** + * Callback invoked when the specified executor is no longer running any tasks. + * This sets a time in the future that decides when this executor should be removed if + * the executor is not already marked as idle. + */ + private def onExecutorIdle(executorId: String): Unit = synchronized { + if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) { + logDebug(s"Starting idle timer for $executorId because there are no more tasks " + + s"scheduled to run on the executor (to expire in $executorIdleTimeout seconds)") + removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeout * 1000 + } + } + + /** + * Callback invoked when the specified executor is now running a task. + * This resets all variables used for removing this executor. + */ + private def onExecutorBusy(executorId: String): Unit = synchronized { + logDebug(s"Clearing idle timer for $executorId because it is now running a task") + removeTimes.remove(executorId) + } + + /** + * A listener that notifies the given allocation manager of when to add and remove executors. + * + * This class is intentionally conservative in its assumptions about the relative ordering + * and consistency of events returned by the listener. For simplicity, it does not account + * for speculated tasks. + */ + private class ExecutorAllocationListener(allocationManager: ExecutorAllocationManager) + extends SparkListener { + + private val stageIdToNumTasks = new mutable.HashMap[Int, Int] + private val stageIdToTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]] + private val executorIdToTaskIds = new mutable.HashMap[String, mutable.HashSet[Long]] + + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { + synchronized { + val stageId = stageSubmitted.stageInfo.stageId + val numTasks = stageSubmitted.stageInfo.numTasks + stageIdToNumTasks(stageId) = numTasks + allocationManager.onSchedulerBacklogged() + } + } + + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + synchronized { + val stageId = stageCompleted.stageInfo.stageId + stageIdToNumTasks -= stageId + stageIdToTaskIndices -= stageId + + // If this is the last stage with pending tasks, mark the scheduler queue as empty + // This is needed in case the stage is aborted for any reason + if (stageIdToNumTasks.isEmpty) { + allocationManager.onSchedulerQueueEmpty() + } + } + } + + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized { + val stageId = taskStart.stageId + val taskId = taskStart.taskInfo.taskId + val taskIndex = taskStart.taskInfo.index + val executorId = taskStart.taskInfo.executorId + + // If this is the last pending task, mark the scheduler queue as empty + stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex + val numTasksScheduled = stageIdToTaskIndices(stageId).size + val numTasksTotal = stageIdToNumTasks.getOrElse(stageId, -1) + if (numTasksScheduled == numTasksTotal) { + // No more pending tasks for this stage + stageIdToNumTasks -= stageId + if (stageIdToNumTasks.isEmpty) { + allocationManager.onSchedulerQueueEmpty() + } + } + + // Mark the executor on which this task is scheduled as busy + executorIdToTaskIds.getOrElseUpdate(executorId, new mutable.HashSet[Long]) += taskId + 
allocationManager.onExecutorBusy(executorId) + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { + val executorId = taskEnd.taskInfo.executorId + val taskId = taskEnd.taskInfo.taskId + + // If the executor is no longer running scheduled any tasks, mark it as idle + if (executorIdToTaskIds.contains(executorId)) { + executorIdToTaskIds(executorId) -= taskId + if (executorIdToTaskIds(executorId).isEmpty) { + executorIdToTaskIds -= executorId + allocationManager.onExecutorIdle(executorId) + } + } + } + + override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = { + val executorId = blockManagerAdded.blockManagerId.executorId + if (executorId != SparkContext.DRIVER_IDENTIFIER) { + allocationManager.onExecutorAdded(executorId) + } + } + + override def onBlockManagerRemoved( + blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = { + allocationManager.onExecutorRemoved(blockManagerRemoved.blockManagerId.executorId) + } + + /** + * An estimate of the total number of pending tasks remaining for currently running stages. Does + * not account for tasks which may have failed and been resubmitted. + */ + def totalPendingTasks(): Int = { + stageIdToNumTasks.map { case (stageId, numTasks) => + numTasks - stageIdToTaskIndices.get(stageId).map(_.size).getOrElse(0) + }.sum + } + } + +} + +private object ExecutorAllocationManager { + val NOT_SET = Long.MaxValue +} + +/** + * An abstract clock for measuring elapsed time. + */ +private trait Clock { + def getTimeMillis: Long +} + +/** + * A clock backed by a monotonically increasing time source. + * The time returned by this clock does not correspond to any notion of wall-clock time. + */ +private class RealClock extends Clock { + override def getTimeMillis: Long = System.nanoTime / (1000 * 1000) +} + +/** + * A clock that allows the caller to customize the time. + * This is used mainly for testing. + */ +private class TestClock(startTimeMillis: Long) extends Clock { + private var time: Long = startTimeMillis + override def getTimeMillis: Long = time + def tick(ms: Long): Unit = { time += ms } +} diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index e8f761eaa5799..e97a7375a267b 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -17,20 +17,21 @@ package org.apache.spark -import scala.concurrent._ -import scala.concurrent.duration.Duration -import scala.util.Try +import java.util.Collections +import java.util.concurrent.TimeUnit -import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaFutureAction import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter} +import scala.concurrent._ +import scala.concurrent.duration.Duration +import scala.util.{Failure, Try} + /** - * :: Experimental :: * A future for the result of an action to support cancellation. This is an extension of the * Scala Future interface to support cancellation. */ -@Experimental trait FutureAction[T] extends Future[T] { // Note that we redefine methods of the Future trait here explicitly so we can specify a different // documentation (with reference to the word "action"). @@ -69,6 +70,11 @@ trait FutureAction[T] extends Future[T] { */ override def isCompleted: Boolean + /** + * Returns whether the action has been cancelled. + */ + def isCancelled: Boolean + /** * The value of this Future. 
* @@ -96,15 +102,16 @@ trait FutureAction[T] extends Future[T] { /** - * :: Experimental :: * A [[FutureAction]] holding the result of an action that triggers a single job. Examples include * count, collect, reduce. */ -@Experimental class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T) extends FutureAction[T] { + @volatile private var _cancelled: Boolean = false + override def cancel() { + _cancelled = true jobWaiter.cancel() } @@ -143,6 +150,8 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: } override def isCompleted: Boolean = jobWaiter.jobFinished + + override def isCancelled: Boolean = _cancelled override def value: Option[Try[T]] = { if (jobWaiter.jobFinished) { @@ -164,12 +173,10 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: /** - * :: Experimental :: * A [[FutureAction]] for actions that could trigger multiple Spark jobs. Examples include take, * takeSample. Cancellation works by setting the cancelled flag to true and interrupting the * action thread if it is being blocked by a job. */ -@Experimental class ComplexFutureAction[T] extends FutureAction[T] { // Pointer to the thread that is executing the action. It is set when the action is run. @@ -203,7 +210,11 @@ class ComplexFutureAction[T] extends FutureAction[T] { } catch { case e: Exception => p.failure(e) } finally { - thread = null + // This lock guarantees when calling `thread.interrupt()` in `cancel`, + // thread won't be set to null. + ComplexFutureAction.this.synchronized { + thread = null + } } } this @@ -222,7 +233,7 @@ class ComplexFutureAction[T] extends FutureAction[T] { // If the action hasn't been cancelled yet, submit the job. The check and the submitJob // command need to be in an atomic block. val job = this.synchronized { - if (!cancelled) { + if (!isCancelled) { rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc) } else { throw new SparkException("Action has been cancelled") @@ -243,10 +254,7 @@ class ComplexFutureAction[T] extends FutureAction[T] { } } - /** - * Returns whether the promise has been cancelled. - */ - def cancelled: Boolean = _cancelled + override def isCancelled: Boolean = _cancelled @throws(classOf[InterruptedException]) @throws(classOf[scala.concurrent.TimeoutException]) @@ -271,3 +279,55 @@ class ComplexFutureAction[T] extends FutureAction[T] { def jobIds = jobs } + +private[spark] +class JavaFutureActionWrapper[S, T](futureAction: FutureAction[S], converter: S => T) + extends JavaFutureAction[T] { + + import scala.collection.JavaConverters._ + + override def isCancelled: Boolean = futureAction.isCancelled + + override def isDone: Boolean = { + // According to java.util.Future's Javadoc, this returns True if the task was completed, + // whether that completion was due to successful execution, an exception, or a cancellation. 
+ futureAction.isCancelled || futureAction.isCompleted + } + + override def jobIds(): java.util.List[java.lang.Integer] = { + Collections.unmodifiableList(futureAction.jobIds.map(Integer.valueOf).asJava) + } + + private def getImpl(timeout: Duration): T = { + // This will throw TimeoutException on timeout: + Await.ready(futureAction, timeout) + futureAction.value.get match { + case scala.util.Success(value) => converter(value) + case Failure(exception) => + if (isCancelled) { + throw new CancellationException("Job cancelled").initCause(exception) + } else { + // java.util.Future.get() wraps exceptions in ExecutionException + throw new ExecutionException("Exception thrown by job", exception) + } + } + } + + override def get(): T = getImpl(Duration.Inf) + + override def get(timeout: Long, unit: TimeUnit): T = + getImpl(Duration.fromNanos(unit.toNanos(timeout))) + + override def cancel(mayInterruptIfRunning: Boolean): Boolean = synchronized { + if (isDone) { + // According to java.util.Future's Javadoc, this should return false if the task is completed. + false + } else { + // We're limited in terms of the semantics we can provide here; our cancellation is + // asynchronous and doesn't provide a mechanism to not cancel if the job is running. + futureAction.cancel() + true + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 4cb0bd4142435..7d96962c4acd7 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -178,6 +178,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) } } else { + logError("Missing all output locations for shuffle " + shuffleId) throw new MetadataFetchFailedException( shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId) } @@ -348,7 +349,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr new ConcurrentHashMap[Int, Array[MapStatus]] } -private[spark] object MapOutputTracker { +private[spark] object MapOutputTracker extends Logging { // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. 
They will @@ -381,6 +382,7 @@ private[spark] object MapOutputTracker { statuses.map { status => if (status == null) { + logError("Missing an output location for shuffle " + shuffleId) throw new MetadataFetchFailedException( shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId) } else { diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 37053bb6f37ad..e53a78ead2c0e 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -204,7 +204,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( } @throws(classOf[IOException]) - private def writeObject(out: ObjectOutputStream) { + private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { val sfactory = SparkEnv.get.serializer sfactory match { case js: JavaSerializer => out.defaultWriteObject() @@ -222,7 +222,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( } @throws(classOf[IOException]) - private def readObject(in: ObjectInputStream) { + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { val sfactory = SparkEnv.get.serializer sfactory match { case js: JavaSerializer => in.defaultReadObject() diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 0e0f1a7b2377e..dbff9d12b5ad7 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -22,6 +22,7 @@ import java.net.{Authenticator, PasswordAuthentication} import org.apache.hadoop.io.Text import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.network.sasl.SecretKeyHolder /** * Spark class responsible for security. @@ -84,7 +85,7 @@ import org.apache.spark.deploy.SparkHadoopUtil * Authenticator installed in the SecurityManager to how it does the authentication * and in this case gets the user name and password from the request. * - * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously + * - BlockTransferService -> The Spark BlockTransferServices uses java nio to asynchronously * exchange messages. For this we use the Java SASL * (Simple Authentication and Security Layer) API and again use DIGEST-MD5 * as the authentication mechanism. This means the shared secret is not passed @@ -98,7 +99,7 @@ import org.apache.spark.deploy.SparkHadoopUtil * of protection they want. If we support those, the messages will also have to * be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's. * - * Since the connectionManager does asynchronous messages passing, the SASL + * Since the NioBlockTransferService does asynchronous messages passing, the SASL * authentication is a bit more complex. A ConnectionManager can be both a client * and a Server, so for a particular connection is has to determine what to do. * A ConnectionId was added to be able to track connections and is used to @@ -107,6 +108,10 @@ import org.apache.spark.deploy.SparkHadoopUtil * and waits for the response from the server and does the handshake before sending * the real message. * + * The NettyBlockTransferService ensures that SASL authentication is performed + * synchronously prior to any other communication on a connection. This is done in + * SaslClientBootstrap on the client side and SaslRpcHandler on the server side. 
+ * * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters * can be used. Yarn requires a specific AmIpFilter be installed for security to work * properly. For non-Yarn deployments, users can write a filter to go through a @@ -139,7 +144,7 @@ import org.apache.spark.deploy.SparkHadoopUtil * can take place. */ -private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { +private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with SecretKeyHolder { // key used to store the spark secret in the Hadoop UGI private val sparkSecretLookupKey = "sparkCookie" @@ -337,4 +342,8 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { * @return the secret key as a String if authentication is enabled, otherwise returns null */ def getSecretKey(): String = secretKey + + // Default SecurityManager only has a single secret key, so ignore appId. + override def getSaslUser(appId: String): String = getSaslUser() + override def getSecretKey(appId: String): String = getSecretKey() } diff --git a/core/src/main/scala/org/apache/spark/SerializableWritable.scala b/core/src/main/scala/org/apache/spark/SerializableWritable.scala index e50b9ac2291f9..55cb25946c2ad 100644 --- a/core/src/main/scala/org/apache/spark/SerializableWritable.scala +++ b/core/src/main/scala/org/apache/spark/SerializableWritable.scala @@ -24,18 +24,19 @@ import org.apache.hadoop.io.ObjectWritable import org.apache.hadoop.io.Writable import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.util.Utils @DeveloperApi class SerializableWritable[T <: Writable](@transient var t: T) extends Serializable { def value = t override def toString = t.toString - private def writeObject(out: ObjectOutputStream) { + private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { out.defaultWriteObject() new ObjectWritable(t).write(out) } - private def readObject(in: ObjectInputStream) { + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { in.defaultReadObject() val ow = new ObjectWritable() ow.setConf(new Configuration()) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 605df0e929faa..4c6c86c7bad78 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -18,7 +18,8 @@ package org.apache.spark import scala.collection.JavaConverters._ -import scala.collection.mutable.HashMap +import scala.collection.mutable.{HashMap, LinkedHashSet} +import org.apache.spark.serializer.KryoSerializer /** * Configuration for a Spark application. Used to set various Spark parameters as key-value pairs. @@ -140,6 +141,20 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { this } + /** + * Use Kryo serialization and register the given set of classes with Kryo. + * If called multiple times, this will append the classes from all calls together. 
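A typical call to the registerKryoClasses helper described here is sketched below; MyClass1 and MyClass2 stand in for application classes and are not part of this patch.

import org.apache.spark.SparkConf

// Placeholder application classes used only for this sketch.
class MyClass1
class MyClass2

object KryoRegistrationSketch {
  // Equivalent to setting spark.serializer to KryoSerializer and appending the
  // class names to spark.kryo.classesToRegister.
  val conf: SparkConf = new SparkConf()
    .setAppName("kryo-registration-sketch")
    .registerKryoClasses(Array(classOf[MyClass1], classOf[MyClass2]))
}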
+ */ + def registerKryoClasses(classes: Array[Class[_]]): SparkConf = { + val allClassNames = new LinkedHashSet[String]() + allClassNames ++= get("spark.kryo.classesToRegister", "").split(',').filter(!_.isEmpty) + allClassNames ++= classes.map(_.getName) + + set("spark.kryo.classesToRegister", allClassNames.mkString(",")) + set("spark.serializer", classOf[KryoSerializer].getName) + this + } + /** Remove a parameter from the configuration */ def remove(key: String): SparkConf = { settings.remove(key) @@ -202,6 +217,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { */ getAll.filter { case (k, _) => isAkkaConf(k) } + /** + * Returns the Spark application id, valid in the Driver after TaskScheduler registration and + * from the start in the Executor. + */ + def getAppId: String = get("spark.app.id") + /** Does the configuration contain a given parameter? */ def contains(key: String): Boolean = settings.contains(key) @@ -229,6 +250,19 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { val executorClasspathKey = "spark.executor.extraClassPath" val driverOptsKey = "spark.driver.extraJavaOptions" val driverClassPathKey = "spark.driver.extraClassPath" + val driverLibraryPathKey = "spark.driver.extraLibraryPath" + + // Used by Yarn in 1.1 and before + sys.props.get("spark.driver.libraryPath").foreach { value => + val warning = + s""" + |spark.driver.libraryPath was detected (set to '$value'). + |This is deprecated in Spark 1.2+. + | + |Please instead use: $driverLibraryPathKey + """.stripMargin + logWarning(warning) + } // Validate spark.executor.extraJavaOptions settings.get(executorOptsKey).map { javaOpts => diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 97109b9f41b60..9b0d5be7a7ab2 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -21,8 +21,8 @@ import scala.language.implicitConversions import java.io._ import java.net.URI +import java.util.{Arrays, Properties, UUID} import java.util.concurrent.atomic.AtomicInteger -import java.util.{Properties, UUID} import java.util.UUID.randomUUID import scala.collection.{Map, Set} import scala.collection.JavaConversions._ @@ -41,7 +41,8 @@ import akka.actor.Props import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} -import org.apache.spark.input.WholeTextFileInputFormat +import org.apache.spark.executor.TriggerThreadDump +import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat} import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ import org.apache.spark.scheduler._ @@ -49,24 +50,41 @@ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkD import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage._ -import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils} +import org.apache.spark.ui.{SparkUI, ConsoleProgressBar} +import org.apache.spark.ui.jobs.JobProgressListener +import org.apache.spark.util._ /** * Main entry point for Spark functionality. 
A SparkContext represents the connection to a Spark * cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster. * + * Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before + * creating a new one. This limitation may eventually be removed; see SPARK-2243 for more details. + * * @param config a Spark Config object describing the application configuration. Any settings in * this config overrides the default configs as well as system properties. */ - class SparkContext(config: SparkConf) extends Logging { + // The call site where this SparkContext was constructed. + private val creationSite: CallSite = Utils.getCallSite() + + // If true, log warnings instead of throwing exceptions when multiple SparkContexts are active + private val allowMultipleContexts: Boolean = + config.getBoolean("spark.driver.allowMultipleContexts", false) + + // In order to prevent multiple SparkContexts from being active at the same time, mark this + // context as having started construction. + // NOTE: this must be placed at the beginning of the SparkContext constructor. + SparkContext.markPartiallyConstructed(this, allowMultipleContexts) + // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, // etc) too. This is typically generated from InputFormatInfo.computePreferredLocations. It // contains a map from hostname to a list of input format splits on the host. private[spark] var preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map() + val startTime = System.currentTimeMillis() + /** * Create a SparkContext that loads settings from system properties (for instance, when * launching with ./bin/spark-submit). @@ -208,16 +226,10 @@ class SparkContext(config: SparkConf) extends Logging { // An asynchronous listener bus for Spark events private[spark] val listenerBus = new LiveListenerBus - // Create the Spark execution environment (cache, map output tracker, etc) conf.set("spark.executor.id", "driver") - private[spark] val env = SparkEnv.create( - conf, - "", - conf.get("spark.driver.host"), - conf.get("spark.driver.port").toInt, - isDriver = true, - isLocal = isLocal, - listenerBus = listenerBus) + + // Create the Spark execution environment (cache, map output tracker, etc) + private[spark] val env = SparkEnv.createDriverEnv(conf, isLocal, listenerBus) SparkEnv.set(env) // Used to store a URL for each static file/jar together with the file's local timestamp @@ -229,21 +241,36 @@ class SparkContext(config: SparkConf) extends Logging { private[spark] val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf) - // Initialize the Spark UI, registering all associated listeners + + private[spark] val jobProgressListener = new JobProgressListener(conf) + listenerBus.addListener(jobProgressListener) + + val statusTracker = new SparkStatusTracker(this) + + private[spark] val progressBar: Option[ConsoleProgressBar] = + if (conf.getBoolean("spark.ui.showConsoleProgress", true) && !log.isInfoEnabled) { + Some(new ConsoleProgressBar(this)) + } else { + None + } + + // Initialize the Spark UI private[spark] val ui: Option[SparkUI] = if (conf.getBoolean("spark.ui.enabled", true)) { - Some(new SparkUI(this)) + Some(SparkUI.createLiveUI(this, conf, listenerBus, jobProgressListener, + env.securityManager,appName)) } else { // For tests, do not enable the UI None } + + // Bind the UI before starting the task scheduler to communicate + // the bound port to the cluster manager properly 
ui.foreach(_.bind()) /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf) - val startTime = System.currentTimeMillis() - // Add each JAR given through the constructor if (jars != null) { jars.foreach(addJar) @@ -291,7 +318,8 @@ class SparkContext(config: SparkConf) extends Logging { executorEnvs("SPARK_USER") = sparkUser // Create and start the scheduler - private[spark] var taskScheduler = SparkContext.createTaskScheduler(this, master) + private[spark] var (schedulerBackend, taskScheduler) = + SparkContext.createTaskScheduler(this, master) private val heartbeatReceiver = env.actorSystem.actorOf( Props(new HeartbeatReceiver(taskScheduler)), "HeartbeatReceiver") @volatile private[spark] var dagScheduler: DAGScheduler = _ @@ -309,6 +337,8 @@ class SparkContext(config: SparkConf) extends Logging { val applicationId: String = taskScheduler.applicationId() conf.set("spark.app.id", applicationId) + env.blockManager.initialize(applicationId) + val metricsSystem = env.metricsSystem // The metrics system for Driver need to be set spark.app.id to app ID. @@ -326,6 +356,15 @@ class SparkContext(config: SparkConf) extends Logging { } else None } + // Optionally scale number of executors dynamically based on workload. Exposed for testing. + private[spark] val executorAllocationManager: Option[ExecutorAllocationManager] = + if (conf.getBoolean("spark.dynamicAllocation.enabled", false)) { + Some(new ExecutorAllocationManager(this)) + } else { + None + } + executorAllocationManager.foreach(_.start()) + // At this point, all relevant SparkListeners have been registered, so begin releasing events listenerBus.start() @@ -348,6 +387,29 @@ class SparkContext(config: SparkConf) extends Logging { override protected def childValue(parent: Properties): Properties = new Properties(parent) } + /** + * Called by the web UI to obtain executor thread dumps. This method may be expensive. + * Logs an error and returns None if we failed to obtain a thread dump, which could occur due + * to an executor being dead or unresponsive or due to network issues while sending the thread + * dump message back to the driver. + */ + private[spark] def getExecutorThreadDump(executorId: String): Option[Array[ThreadStackTrace]] = { + try { + if (executorId == SparkContext.DRIVER_IDENTIFIER) { + Some(Utils.getThreadDump()) + } else { + val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get + val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem) + Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef, + AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf))) + } + } catch { + case e: Exception => + logError(s"Exception getting thread dump from executor $executorId", e) + None + } + } + private[spark] def getLocalProperties: Properties = localProperties.get() private[spark] def setLocalProperties(props: Properties) { @@ -520,6 +582,73 @@ class SparkContext(config: SparkConf) extends Logging { minPartitions).setName(path) } + + /** + * :: Experimental :: + * + * Get an RDD for a Hadoop-readable dataset as PortableDataStream for each file + * (useful for binary data) + * + * For example, if you have the following files: + * {{{ + * hdfs://a-hdfs-path/part-00000 + * hdfs://a-hdfs-path/part-00001 + * ... 
+ * hdfs://a-hdfs-path/part-nnnnn + * }}} + * + * Do + * `val rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * + * then `rdd` contains + * {{{ + * (a-hdfs-path/part-00000, its content) + * (a-hdfs-path/part-00001, its content) + * ... + * (a-hdfs-path/part-nnnnn, its content) + * }}} + * + * @param minPartitions A suggestion value of the minimal splitting number for input data. + * + * @note Small files are preferred; very large files may cause bad performance. + */ + @Experimental + def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions): + RDD[(String, PortableDataStream)] = { + val job = new NewHadoopJob(hadoopConfiguration) + NewFileInputFormat.addInputPath(job, new Path(path)) + val updateConf = job.getConfiguration + new BinaryFileRDD( + this, + classOf[StreamInputFormat], + classOf[String], + classOf[PortableDataStream], + updateConf, + minPartitions).setName(path) + } + + /** + * :: Experimental :: + * + * Load data from a flat binary file, assuming the length of each record is constant. + * + * @param path Directory to the input data files + * @param recordLength The length at which to split the records + * @return An RDD of data with values, represented as byte arrays + */ + @Experimental + def binaryRecords(path: String, recordLength: Int, conf: Configuration = hadoopConfiguration) + : RDD[Array[Byte]] = { + conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength) + val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path, + classOf[FixedLengthBinaryInputFormat], + classOf[LongWritable], + classOf[BytesWritable], + conf=conf) + val data = br.map{ case (k, v) => v.getBytes} + data + } + /** * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and other * necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable), @@ -779,20 +908,20 @@ class SparkContext(config: SparkConf) extends Logging { /** * Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values * with `+=`. Only the driver can access the accumuable's `value`. - * @tparam T accumulator type - * @tparam R type that can be added to the accumulator + * @tparam R accumulator result type + * @tparam T type that can be added to the accumulator */ - def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) = + def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) = new Accumulable(initialValue, param) /** * Create an [[org.apache.spark.Accumulable]] shared variable, with a name for display in the * Spark UI. Tasks can add values to the accumuable using the `+=` operator. Only the driver can * access the accumuable's `value`. 
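The new binaryRecords API above lends itself to a short sketch. The path, the record length, and the assumption of an already-constructed SparkContext are all illustrative.

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

object BinaryRecordsSketch {
  // `sc` is an existing SparkContext; the path and record length are placeholders.
  def loadFixedWidth(sc: SparkContext): RDD[Array[Byte]] =
    sc.binaryRecords("hdfs://namenode/data/fixed-width.bin", recordLength = 512)
}

Each element of the returned RDD is one 512-byte record.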
- * @tparam T accumulator type - * @tparam R type that can be added to the accumulator + * @tparam R accumulator result type + * @tparam T type that can be added to the accumulator */ - def accumulable[T, R](initialValue: T, name: String)(implicit param: AccumulableParam[T, R]) = + def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) = new Accumulable(initialValue, param, Some(name)) /** @@ -814,6 +943,8 @@ class SparkContext(config: SparkConf) extends Logging { */ def broadcast[T: ClassTag](value: T): Broadcast[T] = { val bc = env.broadcastManager.newBroadcast[T](value, isLocal) + val callSite = getCallSite + logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm) cleaner.foreach(_.registerBroadcastForCleanup(bc)) bc } @@ -831,11 +962,12 @@ class SparkContext(config: SparkConf) extends Logging { case "local" => "file:" + uri.getPath case _ => path } - addedFiles(key) = System.currentTimeMillis + val timestamp = System.currentTimeMillis + addedFiles(key) = timestamp // Fetch the file locally in case a job is executed using DAGScheduler.runLocally(). Utils.fetchFile(path, new File(SparkFiles.getRootDirectory()), conf, env.securityManager, - hadoopConfiguration) + hadoopConfiguration, timestamp, useCache = false) logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) postEnvironmentUpdate() @@ -850,6 +982,46 @@ class SparkContext(config: SparkConf) extends Logging { listenerBus.addListener(listener) } + /** + * :: DeveloperApi :: + * Request an additional number of executors from the cluster manager. + * This is currently only supported in Yarn mode. Return whether the request is received. + */ + @DeveloperApi + def requestExecutors(numAdditionalExecutors: Int): Boolean = { + schedulerBackend match { + case b: CoarseGrainedSchedulerBackend => + b.requestExecutors(numAdditionalExecutors) + case _ => + logWarning("Requesting executors is only supported in coarse-grained mode") + false + } + } + + /** + * :: DeveloperApi :: + * Request that the cluster manager kill the specified executors. + * This is currently only supported in Yarn mode. Return whether the request is received. + */ + @DeveloperApi + def killExecutors(executorIds: Seq[String]): Boolean = { + schedulerBackend match { + case b: CoarseGrainedSchedulerBackend => + b.killExecutors(executorIds) + case _ => + logWarning("Killing executors is only supported in coarse-grained mode") + false + } + } + + /** + * :: DeveloperApi :: + * Request that cluster manager the kill the specified executor. + * This is currently only supported in Yarn mode. Return whether the request is received. + */ + @DeveloperApi + def killExecutor(executorId: String): Boolean = killExecutors(Seq(executorId)) + /** The version of Spark on which this application is running. */ def version = SPARK_VERSION @@ -1015,27 +1187,30 @@ class SparkContext(config: SparkConf) extends Logging { /** Shut down the SparkContext. */ def stop() { - postApplicationEnd() - ui.foreach(_.stop()) - // Do this only if not stopped already - best case effort. - // prevent NPE if stopped more than once. - val dagSchedulerCopy = dagScheduler - dagScheduler = null - if (dagSchedulerCopy != null) { - env.metricsSystem.report() - metadataCleaner.cancel() - env.actorSystem.stop(heartbeatReceiver) - cleaner.foreach(_.stop()) - dagSchedulerCopy.stop() - taskScheduler = null - // TODO: Cache.stop()? 
- env.stop() - SparkEnv.set(null) - listenerBus.stop() - eventLogger.foreach(_.stop()) - logInfo("Successfully stopped SparkContext") - } else { - logInfo("SparkContext already stopped") + SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + postApplicationEnd() + ui.foreach(_.stop()) + // Do this only if not stopped already - best case effort. + // prevent NPE if stopped more than once. + val dagSchedulerCopy = dagScheduler + dagScheduler = null + if (dagSchedulerCopy != null) { + env.metricsSystem.report() + metadataCleaner.cancel() + env.actorSystem.stop(heartbeatReceiver) + cleaner.foreach(_.stop()) + dagSchedulerCopy.stop() + taskScheduler = null + // TODO: Cache.stop()? + env.stop() + SparkEnv.set(null) + listenerBus.stop() + eventLogger.foreach(_.stop()) + logInfo("Successfully stopped SparkContext") + SparkContext.clearActiveContext() + } else { + logInfo("SparkContext already stopped") + } } } @@ -1106,6 +1281,7 @@ class SparkContext(config: SparkConf) extends Logging { logInfo("Starting job: " + callSite.shortForm) dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal, resultHandler, localProperties.get) + progressBar.foreach(_.finishAll()) rdd.doCheckpoint() } @@ -1324,6 +1500,11 @@ class SparkContext(config: SparkConf) extends Logging { private[spark] def cleanup(cleanupTime: Long) { persistentRdds.clearOldValues(cleanupTime) } + + // In order to prevent multiple SparkContexts from being active at the same time, mark this + // context as having finished construction. + // NOTE: this must be placed at the end of the SparkContext constructor. + SparkContext.setActiveContext(this, allowMultipleContexts) } /** @@ -1332,6 +1513,107 @@ class SparkContext(config: SparkConf) extends Logging { */ object SparkContext extends Logging { + /** + * Lock that guards access to global variables that track SparkContext construction. + */ + private val SPARK_CONTEXT_CONSTRUCTOR_LOCK = new Object() + + /** + * The active, fully-constructed SparkContext. If no SparkContext is active, then this is `None`. + * + * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK + */ + private var activeContext: Option[SparkContext] = None + + /** + * Points to a partially-constructed SparkContext if some thread is in the SparkContext + * constructor, or `None` if no SparkContext is being constructed. + * + * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK + */ + private var contextBeingConstructed: Option[SparkContext] = None + + /** + * Called to ensure that no other SparkContext is running in this JVM. + * + * Throws an exception if a running context is detected and logs a warning if another thread is + * constructing a SparkContext. This warning is necessary because the current locking scheme + * prevents us from reliably distinguishing between cases where another context is being + * constructed and cases where another constructor threw an exception. 
+ */ + private def assertNoOtherContextIsRunning( + sc: SparkContext, + allowMultipleContexts: Boolean): Unit = { + SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + contextBeingConstructed.foreach { otherContext => + if (otherContext ne sc) { // checks for reference equality + // Since otherContext might point to a partially-constructed context, guard against + // its creationSite field being null: + val otherContextCreationSite = + Option(otherContext.creationSite).map(_.longForm).getOrElse("unknown location") + val warnMsg = "Another SparkContext is being constructed (or threw an exception in its" + + " constructor). This may indicate an error, since only one SparkContext may be" + + " running in this JVM (see SPARK-2243)." + + s" The other SparkContext was created at:\n$otherContextCreationSite" + logWarning(warnMsg) + } + + activeContext.foreach { ctx => + val errMsg = "Only one SparkContext may be running in this JVM (see SPARK-2243)." + + " To ignore this error, set spark.driver.allowMultipleContexts = true. " + + s"The currently running SparkContext was created at:\n${ctx.creationSite.longForm}" + val exception = new SparkException(errMsg) + if (allowMultipleContexts) { + logWarning("Multiple running SparkContexts detected in the same JVM!", exception) + } else { + throw exception + } + } + } + } + } + + /** + * Called at the beginning of the SparkContext constructor to ensure that no SparkContext is + * running. Throws an exception if a running context is detected and logs a warning if another + * thread is constructing a SparkContext. This warning is necessary because the current locking + * scheme prevents us from reliably distinguishing between cases where another context is being + * constructed and cases where another constructor threw an exception. + */ + private[spark] def markPartiallyConstructed( + sc: SparkContext, + allowMultipleContexts: Boolean): Unit = { + SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + assertNoOtherContextIsRunning(sc, allowMultipleContexts) + contextBeingConstructed = Some(sc) + } + } + + /** + * Called at the end of the SparkContext constructor to ensure that no other SparkContext has + * raced with this constructor and started. + */ + private[spark] def setActiveContext( + sc: SparkContext, + allowMultipleContexts: Boolean): Unit = { + SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + assertNoOtherContextIsRunning(sc, allowMultipleContexts) + contextBeingConstructed = None + activeContext = Some(sc) + } + } + + /** + * Clears the active SparkContext metadata. This is called by `SparkContext#stop()`. It's + * also called in unit tests to prevent a flood of warnings from test suites that don't / can't + * properly clean up their SparkContexts. + */ + private[spark] def clearActiveContext(): Unit = { + SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized { + activeContext = None + } + } + private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description" private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id" @@ -1340,47 +1622,76 @@ object SparkContext extends Logging { private[spark] val SPARK_UNKNOWN_USER = "" - implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { + private[spark] val DRIVER_IDENTIFIER = "" + + // The following deprecated objects have already been copied to `object AccumulatorParam` to + // make the compiler find them automatically. They are duplicate codes only for backward + // compatibility, please update `object AccumulatorParam` accordingly if you plan to modify the + // following ones. 
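Before the deprecated AccumulatorParam aliases that follow, a quick illustration of the one-context-per-JVM guard set up above. This is a sketch only; the variable names are illustrative, and it assumes the configuration key named in the error message (spark.driver.allowMultipleContexts) is the one read by the constructor:

    // Default behaviour: a second SparkContext in the same JVM fails fast with a
    // SparkException that points at the call site of the context already running.
    val first = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("first"))

    // Opting out (e.g. in test harnesses) downgrades the error to a logged warning.
    val second = new SparkContext(new SparkConf()
      .setMaster("local[2]")
      .setAppName("second")
      .set("spark.driver.allowMultipleContexts", "true"))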
+ + @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + + "backward compatibility.", "1.2.0") + object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 def zero(initialValue: Double) = 0.0 } - implicit object IntAccumulatorParam extends AccumulatorParam[Int] { + @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + + "backward compatibility.", "1.2.0") + object IntAccumulatorParam extends AccumulatorParam[Int] { def addInPlace(t1: Int, t2: Int): Int = t1 + t2 def zero(initialValue: Int) = 0 } - implicit object LongAccumulatorParam extends AccumulatorParam[Long] { + @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + + "backward compatibility.", "1.2.0") + object LongAccumulatorParam extends AccumulatorParam[Long] { def addInPlace(t1: Long, t2: Long) = t1 + t2 def zero(initialValue: Long) = 0L } - implicit object FloatAccumulatorParam extends AccumulatorParam[Float] { + @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + + "backward compatibility.", "1.2.0") + object FloatAccumulatorParam extends AccumulatorParam[Float] { def addInPlace(t1: Float, t2: Float) = t1 + t2 def zero(initialValue: Float) = 0f } - // TODO: Add AccumulatorParams for other types, e.g. lists and strings + // The following deprecated functions have already been moved to `object RDD` to + // make the compiler find them automatically. They are still kept here for backward compatibility + // and just call the corresponding functions in `object RDD`. - implicit def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)]) + @deprecated("Replaced by implicit functions in the RDD companion object. This is " + + "kept here only for backward compatibility.", "1.2.0") + def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)]) (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = { - new PairRDDFunctions(rdd) + RDD.rddToPairRDDFunctions(rdd) } - implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd) + @deprecated("Replaced by implicit functions in the RDD companion object. This is " + + "kept here only for backward compatibility.", "1.2.0") + def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = RDD.rddToAsyncRDDActions(rdd) - implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( + @deprecated("Replaced by implicit functions in the RDD companion object. This is " + + "kept here only for backward compatibility.", "1.2.0") + def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( rdd: RDD[(K, V)]) = - new SequenceFileRDDFunctions(rdd) + RDD.rddToSequenceFileRDDFunctions(rdd) - implicit def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag]( + @deprecated("Replaced by implicit functions in the RDD companion object. This is " + + "kept here only for backward compatibility.", "1.2.0") + def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag]( rdd: RDD[(K, V)]) = - new OrderedRDDFunctions[K, V, (K, V)](rdd) + RDD.rddToOrderedRDDFunctions(rdd) - implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd) + @deprecated("Replaced by implicit functions in the RDD companion object. 
This is " + + "kept here only for backward compatibility.", "1.2.0") + def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = RDD.doubleRDDToDoubleRDDFunctions(rdd) - implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = - new DoubleRDDFunctions(rdd.map(x => num.toDouble(x))) + @deprecated("Replaced by implicit functions in the RDD companion object. This is " + + "kept here only for backward compatibility.", "1.2.0") + def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = + RDD.numericRDDToDoubleRDDFunctions(rdd) // Implicit conversions to common Writable types, for saveAsSequenceFile @@ -1406,37 +1717,49 @@ object SparkContext extends Logging { arr.map(x => anyToWritable(x)).toArray) } - // Helper objects for converting common types to Writable - private def simpleWritableConverter[T, W <: Writable: ClassTag](convert: W => T) - : WritableConverter[T] = { - val wClass = classTag[W].runtimeClass.asInstanceOf[Class[W]] - new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W])) - } + // The following deprecated functions have already been moved to `object WritableConverter` to + // make the compiler find them automatically. They are still kept here for backward compatibility + // and just call the corresponding functions in `object WritableConverter`. - implicit def intWritableConverter(): WritableConverter[Int] = - simpleWritableConverter[Int, IntWritable](_.get) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def intWritableConverter(): WritableConverter[Int] = + WritableConverter.intWritableConverter() - implicit def longWritableConverter(): WritableConverter[Long] = - simpleWritableConverter[Long, LongWritable](_.get) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def longWritableConverter(): WritableConverter[Long] = + WritableConverter.longWritableConverter() - implicit def doubleWritableConverter(): WritableConverter[Double] = - simpleWritableConverter[Double, DoubleWritable](_.get) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def doubleWritableConverter(): WritableConverter[Double] = + WritableConverter.doubleWritableConverter() - implicit def floatWritableConverter(): WritableConverter[Float] = - simpleWritableConverter[Float, FloatWritable](_.get) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def floatWritableConverter(): WritableConverter[Float] = + WritableConverter.floatWritableConverter() - implicit def booleanWritableConverter(): WritableConverter[Boolean] = - simpleWritableConverter[Boolean, BooleanWritable](_.get) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def booleanWritableConverter(): WritableConverter[Boolean] = + WritableConverter.booleanWritableConverter() - implicit def bytesWritableConverter(): WritableConverter[Array[Byte]] = { - simpleWritableConverter[Array[Byte], BytesWritable](_.getBytes) - } + @deprecated("Replaced by implicit functions in WritableConverter. 
This is kept here only for " + + "backward compatibility.", "1.2.0") + def bytesWritableConverter(): WritableConverter[Array[Byte]] = + WritableConverter.bytesWritableConverter() - implicit def stringWritableConverter(): WritableConverter[String] = - simpleWritableConverter[String, Text](_.toString) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def stringWritableConverter(): WritableConverter[String] = + WritableConverter.stringWritableConverter() - implicit def writableWritableConverter[T <: Writable]() = - new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T]) + @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + + "backward compatibility.", "1.2.0") + def writableWritableConverter[T <: Writable]() = + WritableConverter.writableWritableConverter() /** * Find the JAR from which a given class was loaded, to make it easy for users to pass @@ -1492,8 +1815,13 @@ object SparkContext extends Logging { res } - /** Creates a task scheduler based on a given master URL. Extracted for testing. */ - private def createTaskScheduler(sc: SparkContext, master: String): TaskScheduler = { + /** + * Create a task scheduler based on a given master URL. + * Return a 2-tuple of the scheduler backend and the task scheduler. + */ + private def createTaskScheduler( + sc: SparkContext, + master: String): (SchedulerBackend, TaskScheduler) = { // Regular expression used for local[N] and local[*] master formats val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r // Regular expression for local[N, maxRetries], used in tests with failing tasks @@ -1515,16 +1843,19 @@ object SparkContext extends Logging { val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) val backend = new LocalBackend(scheduler, 1) scheduler.initialize(backend) - scheduler + (backend, scheduler) case LOCAL_N_REGEX(threads) => def localCpuCount = Runtime.getRuntime.availableProcessors() // local[*] estimates the number of cores on the machine; local[N] uses exactly N threads. val threadCount = if (threads == "*") localCpuCount else threads.toInt + if (threadCount <= 0) { + throw new SparkException(s"Asked to run locally with $threadCount threads") + } val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) val backend = new LocalBackend(scheduler, threadCount) scheduler.initialize(backend) - scheduler + (backend, scheduler) case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => def localCpuCount = Runtime.getRuntime.availableProcessors() @@ -1534,14 +1865,14 @@ object SparkContext extends Logging { val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true) val backend = new LocalBackend(scheduler, threadCount) scheduler.initialize(backend) - scheduler + (backend, scheduler) case SPARK_REGEX(sparkUrl) => val scheduler = new TaskSchedulerImpl(sc) val masterUrls = sparkUrl.split(",").map("spark://" + _) val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls) scheduler.initialize(backend) - scheduler + (backend, scheduler) case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang. 
@@ -1561,7 +1892,7 @@ object SparkContext extends Logging { backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => { localCluster.stop() } - scheduler + (backend, scheduler) case "yarn-standalone" | "yarn-cluster" => if (master == "yarn-standalone") { @@ -1590,7 +1921,7 @@ object SparkContext extends Logging { } } scheduler.initialize(backend) - scheduler + (backend, scheduler) case "yarn-client" => val scheduler = try { @@ -1617,7 +1948,7 @@ object SparkContext extends Logging { } scheduler.initialize(backend) - scheduler + (backend, scheduler) case mesosUrl @ MESOS_REGEX(_) => MesosNativeLibrary.load() @@ -1630,13 +1961,13 @@ object SparkContext extends Logging { new MesosSchedulerBackend(scheduler, sc, url) } scheduler.initialize(backend) - scheduler + (backend, scheduler) case SIMR_REGEX(simrUrl) => val scheduler = new TaskSchedulerImpl(sc) val backend = new SimrSchedulerBackend(scheduler, sc, simrUrl) scheduler.initialize(backend) - scheduler + (backend, scheduler) case _ => throw new SparkException("Could not parse Master URL: '" + master + "'") @@ -1655,3 +1986,46 @@ private[spark] class WritableConverter[T]( val writableClass: ClassTag[T] => Class[_ <: Writable], val convert: Writable => T) extends Serializable + +object WritableConverter { + + // Helper objects for converting common types to Writable + private[spark] def simpleWritableConverter[T, W <: Writable: ClassTag](convert: W => T) + : WritableConverter[T] = { + val wClass = classTag[W].runtimeClass.asInstanceOf[Class[W]] + new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W])) + } + + // The following implicit functions were in SparkContext before 1.2 and users had to + // `import SparkContext._` to enable them. Now we move them here to make the compiler find + // them automatically. However, we still keep the old functions in SparkContext for backward + // compatibility and forward to the following functions directly. 
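In practice the move means the wildcard import is no longer needed to bring a converter into scope: the implicit parameters of sequenceFile are resolved from object WritableConverter automatically, with the old SparkContext members kept only as deprecated forwarders. A small sketch, assuming an existing SparkContext sc and an illustrative path; the implicit definitions themselves follow below:

    // No `import SparkContext._` required: the Int and String converters are found
    // in object WritableConverter via the implicit parameters of sequenceFile.
    val pairs = sc.sequenceFile[Int, String]("hdfs://a-hdfs-path/data.seq")
    pairs.take(5).foreach(println)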
+ + implicit def intWritableConverter(): WritableConverter[Int] = + simpleWritableConverter[Int, IntWritable](_.get) + + implicit def longWritableConverter(): WritableConverter[Long] = + simpleWritableConverter[Long, LongWritable](_.get) + + implicit def doubleWritableConverter(): WritableConverter[Double] = + simpleWritableConverter[Double, DoubleWritable](_.get) + + implicit def floatWritableConverter(): WritableConverter[Float] = + simpleWritableConverter[Float, FloatWritable](_.get) + + implicit def booleanWritableConverter(): WritableConverter[Boolean] = + simpleWritableConverter[Boolean, BooleanWritable](_.get) + + implicit def bytesWritableConverter(): WritableConverter[Array[Byte]] = { + simpleWritableConverter[Array[Byte], BytesWritable](bw => + // getBytes method returns array which is longer then data to be returned + Arrays.copyOfRange(bw.getBytes, 0, bw.getLength) + ) + } + + implicit def stringWritableConverter(): WritableConverter[String] = + simpleWritableConverter[String, Text](_.toString) + + implicit def writableWritableConverter[T <: Writable]() = + new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T]) +} diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 72cac42cd2b2b..e464b32e61dd6 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -32,6 +32,7 @@ import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.nio.NioBlockTransferService import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer @@ -43,9 +44,8 @@ import org.apache.spark.util.{AkkaUtils, Utils} * :: DeveloperApi :: * Holds all the runtime environment objects for a running Spark instance (either master or worker), * including the serializer, Akka actor system, block manager, map output tracker, etc. Currently - * Spark code finds the SparkEnv through a thread-local variable, so each thread that accesses these - * objects needs to have the right SparkEnv set. You can get the current environment with - * SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set. + * Spark code finds the SparkEnv through a global variable, so all the threads can access the same + * SparkEnv. It can be accessed by SparkEnv.get (e.g. after creating a SparkContext). * * NOTE: This is not intended for external use. This is exposed for Shark and may be made private * in a future release. 
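Since SparkEnv is now held in a single global reference rather than a thread-local, every thread observes the same instance once it has been set. A minimal sketch of that guarantee (still internal API, per the note above):

    import org.apache.spark.SparkEnv

    // A helper thread sees exactly the instance the driver thread sees.
    val envOnDriver = SparkEnv.get
    val t = new Thread(new Runnable {
      override def run(): Unit = assert(SparkEnv.get eq envOnDriver)
    })
    t.start()
    t.join()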
@@ -69,6 +69,7 @@ class SparkEnv ( val shuffleMemoryManager: ShuffleMemoryManager, val conf: SparkConf) extends Logging { + private[spark] var isStopped = false private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() // A general, soft-reference map for metadata needed during HadoopRDD split computation @@ -76,6 +77,7 @@ class SparkEnv ( private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]() private[spark] def stop() { + isStopped = true pythonWorkers.foreach { case(key, worker) => worker.stop() } Option(httpFileServer).foreach(_.stop()) mapOutputTracker.stop() @@ -119,40 +121,73 @@ class SparkEnv ( } object SparkEnv extends Logging { - private val env = new ThreadLocal[SparkEnv] - @volatile private var lastSetSparkEnv : SparkEnv = _ + @volatile private var env: SparkEnv = _ private[spark] val driverActorSystemName = "sparkDriver" private[spark] val executorActorSystemName = "sparkExecutor" def set(e: SparkEnv) { - lastSetSparkEnv = e - env.set(e) + env = e } /** - * Returns the ThreadLocal SparkEnv, if non-null. Else returns the SparkEnv - * previously set in any thread. + * Returns the SparkEnv. */ def get: SparkEnv = { - Option(env.get()).getOrElse(lastSetSparkEnv) + env } /** * Returns the ThreadLocal SparkEnv. */ + @deprecated("Use SparkEnv.get instead", "1.2") def getThreadLocal: SparkEnv = { - env.get() + env } - private[spark] def create( + /** + * Create a SparkEnv for the driver. + */ + private[spark] def createDriverEnv( + conf: SparkConf, + isLocal: Boolean, + listenerBus: LiveListenerBus): SparkEnv = { + assert(conf.contains("spark.driver.host"), "spark.driver.host is not set on the driver!") + assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!") + val hostname = conf.get("spark.driver.host") + val port = conf.get("spark.driver.port").toInt + create(conf, SparkContext.DRIVER_IDENTIFIER, hostname, port, true, isLocal, listenerBus) + } + + /** + * Create a SparkEnv for an executor. + * In coarse-grained mode, the executor provides an actor system that is already instantiated. + */ + private[spark] def createExecutorEnv( + conf: SparkConf, + executorId: String, + hostname: String, + port: Int, + numCores: Int, + isLocal: Boolean, + actorSystem: ActorSystem = null): SparkEnv = { + create(conf, executorId, hostname, port, false, isLocal, defaultActorSystem = actorSystem, + numUsableCores = numCores) + } + + /** + * Helper method to create a SparkEnv for a driver or an executor. + */ + private def create( conf: SparkConf, executorId: String, hostname: String, port: Int, isDriver: Boolean, isLocal: Boolean, - listenerBus: LiveListenerBus = null): SparkEnv = { + listenerBus: LiveListenerBus = null, + defaultActorSystem: ActorSystem = null, + numUsableCores: Int = 0): SparkEnv = { // Listener bus is only used on the driver if (isDriver) { @@ -160,9 +195,16 @@ object SparkEnv extends Logging { } val securityManager = new SecurityManager(conf) - val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - actorSystemName, hostname, port, conf, securityManager) + + // If an existing actor system is already provided, use it. + // This is the case when an executor is launched in coarse-grained mode. 
+ val (actorSystem, boundPort) = + Option(defaultActorSystem) match { + case Some(as) => (as, port) + case None => + val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName + AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager) + } // Figure out which port Akka actually bound to in case the original port is 0 or occupied. // This is so that we tell the executors the correct port to connect to. @@ -234,14 +276,22 @@ object SparkEnv extends Logging { val shuffleMemoryManager = new ShuffleMemoryManager(conf) - val blockTransferService = new NioBlockTransferService(conf, securityManager) + val blockTransferService = + conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match { + case "netty" => + new NettyBlockTransferService(conf, securityManager, numUsableCores) + case "nio" => + new NioBlockTransferService(conf, securityManager) + } val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver) + // NB: blockManager is not valid until initialize() is called later. val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, - serializer, conf, mapOutputTracker, shuffleManager, blockTransferService) + serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, + numUsableCores) val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index 376e69cd997d5..40237596570de 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.mapred._ import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.fs.Path +import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD /** diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala deleted file mode 100644 index 65003b6ac6a0a..0000000000000 --- a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark - -import java.io.IOException -import javax.security.auth.callback.Callback -import javax.security.auth.callback.CallbackHandler -import javax.security.auth.callback.NameCallback -import javax.security.auth.callback.PasswordCallback -import javax.security.auth.callback.UnsupportedCallbackException -import javax.security.sasl.RealmCallback -import javax.security.sasl.RealmChoiceCallback -import javax.security.sasl.Sasl -import javax.security.sasl.SaslClient -import javax.security.sasl.SaslException - -import scala.collection.JavaConversions.mapAsJavaMap - -/** - * Implements SASL Client logic for Spark - */ -private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logging { - - /** - * Used to respond to server's counterpart, SaslServer with SASL tokens - * represented as byte arrays. - * - * The authentication mechanism used here is DIGEST-MD5. This could be changed to be - * configurable in the future. - */ - private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST), - null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, - new SparkSaslClientCallbackHandler(securityMgr)) - - /** - * Used to initiate SASL handshake with server. - * @return response to challenge if needed - */ - def firstToken(): Array[Byte] = { - synchronized { - val saslToken: Array[Byte] = - if (saslClient != null && saslClient.hasInitialResponse()) { - logDebug("has initial response") - saslClient.evaluateChallenge(new Array[Byte](0)) - } else { - new Array[Byte](0) - } - saslToken - } - } - - /** - * Determines whether the authentication exchange has completed. - * @return true is complete, otherwise false - */ - def isComplete(): Boolean = { - synchronized { - if (saslClient != null) saslClient.isComplete() else false - } - } - - /** - * Respond to server's SASL token. - * @param saslTokenMessage contains server's SASL token - * @return client's response SASL token - */ - def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = { - synchronized { - if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0) - } - } - - /** - * Disposes of any system resources or security-sensitive information the - * SaslClient might be using. - */ - def dispose() { - synchronized { - if (saslClient != null) { - try { - saslClient.dispose() - } catch { - case e: SaslException => // ignored - } finally { - saslClient = null - } - } - } - } - - /** - * Implementation of javax.security.auth.callback.CallbackHandler - * that works with share secrets. - */ - private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends - CallbackHandler { - - private val userName: String = - SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes("utf-8")) - private val secretKey = securityMgr.getSecretKey() - private val userPassword: Array[Char] = SparkSaslServer.encodePassword( - if (secretKey != null) secretKey.getBytes("utf-8") else "".getBytes("utf-8")) - - /** - * Implementation used to respond to SASL request from the server. - * - * @param callbacks objects that indicate what credential information the - * server's SaslServer requires from the client. 
- */ - override def handle(callbacks: Array[Callback]) { - logDebug("in the sasl client callback handler") - callbacks foreach { - case nc: NameCallback => { - logDebug("handle: SASL client callback: setting username: " + userName) - nc.setName(userName) - } - case pc: PasswordCallback => { - logDebug("handle: SASL client callback: setting userPassword") - pc.setPassword(userPassword) - } - case rc: RealmCallback => { - logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText()) - rc.setText(rc.getDefaultText()) - } - case cb: RealmChoiceCallback => {} - case cb: Callback => throw - new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback") - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala deleted file mode 100644 index f6b0a9132aca4..0000000000000 --- a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala +++ /dev/null @@ -1,174 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import javax.security.auth.callback.Callback -import javax.security.auth.callback.CallbackHandler -import javax.security.auth.callback.NameCallback -import javax.security.auth.callback.PasswordCallback -import javax.security.auth.callback.UnsupportedCallbackException -import javax.security.sasl.AuthorizeCallback -import javax.security.sasl.RealmCallback -import javax.security.sasl.Sasl -import javax.security.sasl.SaslException -import javax.security.sasl.SaslServer -import scala.collection.JavaConversions.mapAsJavaMap -import org.apache.commons.net.util.Base64 - -/** - * Encapsulates SASL server logic - */ -private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging { - - /** - * Actual SASL work done by this object from javax.security.sasl. - */ - private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null, - SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, - new SparkSaslDigestCallbackHandler(securityMgr)) - - /** - * Determines whether the authentication exchange has completed. - * @return true is complete, otherwise false - */ - def isComplete(): Boolean = { - synchronized { - if (saslServer != null) saslServer.isComplete() else false - } - } - - /** - * Used to respond to server SASL tokens. - * @param token Server's SASL token - * @return response to send back to the server. - */ - def response(token: Array[Byte]): Array[Byte] = { - synchronized { - if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0) - } - } - - /** - * Disposes of any system resources or security-sensitive information the - * SaslServer might be using. 
- */ - def dispose() { - synchronized { - if (saslServer != null) { - try { - saslServer.dispose() - } catch { - case e: SaslException => // ignore - } finally { - saslServer = null - } - } - } - } - - /** - * Implementation of javax.security.auth.callback.CallbackHandler - * for SASL DIGEST-MD5 mechanism - */ - private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager) - extends CallbackHandler { - - private val userName: String = - SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes("utf-8")) - - override def handle(callbacks: Array[Callback]) { - logDebug("In the sasl server callback handler") - callbacks foreach { - case nc: NameCallback => { - logDebug("handle: SASL server callback: setting username") - nc.setName(userName) - } - case pc: PasswordCallback => { - logDebug("handle: SASL server callback: setting userPassword") - val password: Array[Char] = - SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes("utf-8")) - pc.setPassword(password) - } - case rc: RealmCallback => { - logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText()) - rc.setText(rc.getDefaultText()) - } - case ac: AuthorizeCallback => { - val authid = ac.getAuthenticationID() - val authzid = ac.getAuthorizationID() - if (authid.equals(authzid)) { - logDebug("set auth to true") - ac.setAuthorized(true) - } else { - logDebug("set auth to false") - ac.setAuthorized(false) - } - if (ac.isAuthorized()) { - logDebug("sasl server is authorized") - ac.setAuthorizedID(authzid) - } - } - case cb: Callback => throw - new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback") - } - } - } -} - -private[spark] object SparkSaslServer { - - /** - * This is passed as the server name when creating the sasl client/server. - * This could be changed to be configurable in the future. - */ - val SASL_DEFAULT_REALM = "default" - - /** - * The authentication mechanism used here is DIGEST-MD5. This could be changed to be - * configurable in the future. - */ - val DIGEST = "DIGEST-MD5" - - /** - * The quality of protection is just "auth". This means that we are doing - * authentication only, we are not supporting integrity or privacy protection of the - * communication channel after authentication. This could be changed to be configurable - * in the future. - */ - val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH ->"true") - - /** - * Encode a byte[] identifier as a Base64-encoded string. - * - * @param identifier identifier to encode - * @return Base64-encoded string - */ - def encodeIdentifier(identifier: Array[Byte]): String = { - new String(Base64.encodeBase64(identifier), "utf-8") - } - - /** - * Encode a password as a base64-encoded char[] array. - * @param password as a byte array. - * @return password as a char array. - */ - def encodePassword(password: Array[Byte]): Array[Char] = { - new String(Base64.encodeBase64(password), "utf-8").toCharArray() - } -} - diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala new file mode 100644 index 0000000000000..edbdda8a0bcb6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * Low-level status reporting APIs for monitoring job and stage progress. + * + * These APIs intentionally provide very weak consistency semantics; consumers of these APIs should + * be prepared to handle empty / missing information. For example, a job's stage ids may be known + * but the status API may not have any information about the details of those stages, so + * `getStageInfo` could potentially return `None` for a valid stage id. + * + * To limit memory usage, these APIs only provide information on recent jobs / stages. These APIs + * will provide information for the last `spark.ui.retainedStages` stages and + * `spark.ui.retainedJobs` jobs. + * + * NOTE: this class's constructor should be considered private and may be subject to change. + */ +class SparkStatusTracker private[spark] (sc: SparkContext) { + + private val jobProgressListener = sc.jobProgressListener + + /** + * Return a list of all known jobs in a particular job group. If `jobGroup` is `null`, then + * returns all known jobs that are not associated with a job group. + * + * The returned list may contain running, failed, and completed jobs, and may vary across + * invocations of this method. This method does not guarantee the order of the elements in + * its result. + */ + def getJobIdsForGroup(jobGroup: String): Array[Int] = { + jobProgressListener.synchronized { + val jobData = jobProgressListener.jobIdToData.valuesIterator + jobData.filter(_.jobGroup.orNull == jobGroup).map(_.jobId).toArray + } + } + + /** + * Returns an array containing the ids of all active stages. + * + * This method does not guarantee the order of the elements in its result. + */ + def getActiveStageIds(): Array[Int] = { + jobProgressListener.synchronized { + jobProgressListener.activeStages.values.map(_.stageId).toArray + } + } + + /** + * Returns an array containing the ids of all active jobs. + * + * This method does not guarantee the order of the elements in its result. + */ + def getActiveJobIds(): Array[Int] = { + jobProgressListener.synchronized { + jobProgressListener.activeJobs.values.map(_.jobId).toArray + } + } + + /** + * Returns job information, or `None` if the job info could not be found or was garbage collected. + */ + def getJobInfo(jobId: Int): Option[SparkJobInfo] = { + jobProgressListener.synchronized { + jobProgressListener.jobIdToData.get(jobId).map { data => + new SparkJobInfoImpl(jobId, data.stageIds.toArray, data.status) + } + } + } + + /** + * Returns stage information, or `None` if the stage info could not be found or was + * garbage collected. 
+ */ + def getStageInfo(stageId: Int): Option[SparkStageInfo] = { + jobProgressListener.synchronized { + for ( + info <- jobProgressListener.stageIdToInfo.get(stageId); + data <- jobProgressListener.stageIdToData.get((stageId, info.attemptId)) + ) yield { + new SparkStageInfoImpl( + stageId, + info.attemptId, + info.submissionTime.getOrElse(0), + info.name, + info.numTasks, + data.numActiveTasks, + data.numCompleteTasks, + data.numFailedTasks) + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala new file mode 100644 index 0000000000000..e5c7c8d0db578 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +private class SparkJobInfoImpl ( + val jobId: Int, + val stageIds: Array[Int], + val status: JobExecutionStatus) + extends SparkJobInfo + +private class SparkStageInfoImpl( + val stageId: Int, + val currentAttemptId: Int, + val submissionTime: Long, + val name: String, + val numTasks: Int, + val numActiveTasks: Int, + val numCompletedTasks: Int, + val numFailedTasks: Int) + extends SparkStageInfo diff --git a/core/src/main/scala/org/apache/spark/TaskContextHelper.scala b/core/src/main/scala/org/apache/spark/TaskContextHelper.scala new file mode 100644 index 0000000000000..4636c4600a01a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/TaskContextHelper.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * This class exists to restrict the visibility of TaskContext setters. 
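Stepping back to the SparkStatusTracker introduced above: it is designed for lightweight polling, so callers must tolerate missing data, since every lookup can legitimately come back empty or None. A sketch of typical polling code; it assumes the context exposes the tracker through a statusTracker accessor, which is not shown in this excerpt:

    val tracker = sc.statusTracker   // assumed accessor; the tracker's constructor is private[spark]
    for (jobId <- tracker.getActiveJobIds(); job <- tracker.getJobInfo(jobId)) {
      println(s"job $jobId is ${job.status}")
      for (stageId <- job.stageIds; stage <- tracker.getStageInfo(stageId)) {
        println(s"  stage $stageId (${stage.name}): " +
          s"${stage.numCompletedTasks}/${stage.numTasks} tasks complete")
      }
    }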
+ */ +private [spark] object TaskContextHelper { + + def setTaskContext(tc: TaskContext): Unit = TaskContext.setTaskContext(tc) + + def unset(): Unit = TaskContext.unset() + +} diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala new file mode 100644 index 0000000000000..afd2b85d33a77 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} + +import scala.collection.mutable.ArrayBuffer + +private[spark] class TaskContextImpl(val stageId: Int, + val partitionId: Int, + val attemptId: Long, + val runningLocally: Boolean = false, + val taskMetrics: TaskMetrics = TaskMetrics.empty) + extends TaskContext + with Logging { + + // List of callback functions to execute when the task completes. + @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] + + // Whether the corresponding task has been killed. + @volatile private var interrupted: Boolean = false + + // Whether the task has completed. + @volatile private var completed: Boolean = false + + override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { + onCompleteCallbacks += listener + this + } + + override def addTaskCompletionListener(f: TaskContext => Unit): this.type = { + onCompleteCallbacks += new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = f(context) + } + this + } + + @deprecated("use addTaskCompletionListener", "1.1.0") + override def addOnCompleteCallback(f: () => Unit) { + onCompleteCallbacks += new TaskCompletionListener { + override def onTaskCompletion(context: TaskContext): Unit = f() + } + } + + /** Marks the task as completed and triggers the listeners. */ + private[spark] def markTaskCompleted(): Unit = { + completed = true + val errorMsgs = new ArrayBuffer[String](2) + // Process complete callbacks in the reverse order of registration + onCompleteCallbacks.reverse.foreach { listener => + try { + listener.onTaskCompletion(this) + } catch { + case e: Throwable => + errorMsgs += e.getMessage + logError("Error in TaskCompletionListener", e) + } + } + if (errorMsgs.nonEmpty) { + throw new TaskCompletionListenerException(errorMsgs) + } + } + + /** Marks the task for interruption, i.e. cancellation. 
*/ + private[spark] def markInterrupted(): Unit = { + interrupted = true + } + + override def isCompleted: Boolean = completed + + override def isRunningLocally: Boolean = runningLocally + + override def isInterrupted: Boolean = interrupted +} + diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 8f0c5e78416c2..af5fd8e0ac00c 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -69,11 +69,13 @@ case class FetchFailed( bmAddress: BlockManagerId, // Note that bmAddress can be null shuffleId: Int, mapId: Int, - reduceId: Int) + reduceId: Int, + message: String) extends TaskFailedReason { override def toErrorString: String = { val bmAddressString = if (bmAddress == null) "null" else bmAddress.toString - s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId)" + s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId, " + + s"message=\n$message\n)" } } @@ -81,15 +83,48 @@ case class FetchFailed( * :: DeveloperApi :: * Task failed due to a runtime exception. This is the most common failure case and also captures * user program exceptions. + * + * `stackTrace` contains the stack trace of the exception itself. It still exists for backward + * compatibility. It's better to use `this(e: Throwable, metrics: Option[TaskMetrics])` to + * create `ExceptionFailure` as it will handle the backward compatibility properly. + * + * `fullStackTrace` is a better representation of the stack trace because it contains the whole + * stack trace including the exception and its causes */ @DeveloperApi case class ExceptionFailure( className: String, description: String, stackTrace: Array[StackTraceElement], + fullStackTrace: String, metrics: Option[TaskMetrics]) extends TaskFailedReason { - override def toErrorString: String = Utils.exceptionString(className, description, stackTrace) + + private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) { + this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics) + } + + override def toErrorString: String = + if (fullStackTrace == null) { + // fullStackTrace is added in 1.2.0 + // If fullStackTrace is null, use the old error string for backward compatibility + exceptionString(className, description, stackTrace) + } else { + fullStackTrace + } + + /** + * Return a nice string representation of the exception, including the stack trace. + * Note: It does not include the exception's causes, and is only used for backward compatibility. + */ + private def exceptionString( + className: String, + description: String, + stackTrace: Array[StackTraceElement]): String = { + val desc = if (description == null) "" else description + val st = if (stackTrace == null) "" else stackTrace.map(" " + _).mkString("\n") + s"$className: $desc\n$st" + } } /** @@ -117,8 +152,8 @@ case object TaskKilled extends TaskFailedReason { * the task crashed the JVM. 
*/ @DeveloperApi -case object ExecutorLostFailure extends TaskFailedReason { - override def toErrorString: String = "ExecutorLostFailure (executor lost)" +case class ExecutorLostFailure(execId: String) extends TaskFailedReason { + override def toErrorString: String = s"ExecutorLostFailure (executor ${execId} lost)" } /** diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index 8ca731038e528..34078142f5385 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -23,8 +23,10 @@ import java.util.jar.{JarEntry, JarOutputStream} import scala.collection.JavaConversions._ +import com.google.common.io.{ByteStreams, Files} import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} -import com.google.common.io.Files + +import org.apache.spark.util.Utils /** * Utilities for tests. Included in main codebase since it's used by multiple @@ -42,8 +44,7 @@ private[spark] object TestUtils { * in order to avoid interference between tests. */ def createJarWithClasses(classNames: Seq[String], value: String = ""): URL = { - val tempDir = Files.createTempDir() - tempDir.deleteOnExit() + val tempDir = Utils.createTempDir() val files = for (name <- classNames) yield createCompiledClass(name, tempDir, value) val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis())) createJar(files, jarFile) @@ -63,12 +64,7 @@ private[spark] object TestUtils { jarStream.putNextEntry(jarEntry) val in = new FileInputStream(file) - val buffer = new Array[Byte](10240) - var nRead = 0 - while (nRead <= 0) { - nRead = in.read(buffer, 0, buffer.length) - jarStream.write(buffer, 0, nRead) - } + ByteStreams.copy(in, jarStream) in.close() } jarStream.close() diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index a6123bd108c11..8e8f7f6c4fda2 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -114,7 +114,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja * Return an RDD with the elements from `this` that are not in `other`. * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. + * RDD will be <= us. */ def subtract(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.subtract(other)) @@ -233,11 +233,11 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja * to the left except for the last which is closed * e.g. for the array * [1,10,20,50] the buckets are [1,10) [10,20) [20,50] - * e.g 1<=x<10 , 10<=x<20, 20<=x<50 + * e.g 1<=x<10 , 10<=x<20, 20<=x<50 * And on the input of 1 and 50 we would have a histogram of 1,0,0 * * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched - * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets + * from an O(log n) insertion to O(1) per element. (where n = # buckets) if you set evenBuckets * to true. * buckets must be sorted and not contain any duplicates. 
* buckets array must be at least two elements diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 0846225e4f992..7af3538262fd6 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -32,12 +32,13 @@ import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} import org.apache.spark.{HashPartitioner, Partitioner} import org.apache.spark.Partitioner._ -import org.apache.spark.SparkContext.rddToPairRDDFunctions import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, PairFunction} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.{OrderedRDDFunctions, RDD} +import org.apache.spark.rdd.RDD.rddToPairRDDFunctions import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -265,10 +266,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * before sending results to a reducer, similarly to a "combiner" in MapReduce. */ def reduceByKeyLocally(func: JFunction2[V, V, V]): java.util.Map[K, V] = - mapAsJavaMap(rdd.reduceByKeyLocally(func)) + mapAsSerializableJavaMap(rdd.reduceByKeyLocally(func)) /** Count the number of elements for each key, and return the result to the master as a Map. */ - def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey()) + def countByKey(): java.util.Map[K, Long] = mapAsSerializableJavaMap(rdd.countByKey()) /** * :: Experimental :: @@ -277,7 +278,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) */ @Experimental def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] = - rdd.countByKeyApprox(timeout).map(mapAsJavaMap) + rdd.countByKeyApprox(timeout).map(mapAsSerializableJavaMap) /** * :: Experimental :: @@ -287,7 +288,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) @Experimental def countByKeyApprox(timeout: Long, confidence: Double = 0.95) : PartialResult[java.util.Map[K, BoundedDouble]] = - rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap) + rdd.countByKeyApprox(timeout, confidence).map(mapAsSerializableJavaMap) /** * Aggregate the values of each key, using given combine functions and a neutral "zero value". @@ -391,7 +392,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Return an RDD with the elements from `this` that are not in `other`. * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. + * RDD will be <= us. */ def subtract(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] = fromRDD(rdd.subtract(other)) @@ -412,7 +413,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Return an RDD with the pairs from `this` whose keys are not in `other`. * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. + * RDD will be <= us. */ def subtractByKey[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, V] = { implicit val ctag: ClassTag[W] = fakeClassTag @@ -614,7 +615,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) /** * Return the key-value pairs in this RDD to the master as a Map. 
*/ - def collectAsMap(): java.util.Map[K, V] = mapAsJavaMap(rdd.collectAsMap()) + def collectAsMap(): java.util.Map[K, V] = mapAsSerializableJavaMap(rdd.collectAsMap()) + /** * Pass each value in the key-value pair RDD through a map function without changing the keys; diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 545bc0e9e99ed..5a8e5bb1f721a 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -21,15 +21,18 @@ import java.util.{Comparator, List => JList, Iterator => JIterator} import java.lang.{Iterable => JIterable, Long => JLong} import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import com.google.common.base.Optional import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.spark.{FutureAction, Partition, SparkContext, TaskContext} +import org.apache.spark._ +import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.RDD @@ -293,8 +296,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Applies a function f to all elements of this RDD. */ def foreach(f: VoidFunction[T]) { - val cleanF = rdd.context.clean((x: T) => f.call(x)) - rdd.foreach(cleanF) + rdd.foreach(x => f.call(x)) } /** @@ -390,7 +392,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * combine step happens locally on the master, equivalent to running a single reduce task. */ def countByValue(): java.util.Map[T, java.lang.Long] = - mapAsJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2))))) + mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2))))) /** * (Experimental) Approximate version of countByValue(). @@ -399,13 +401,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { timeout: Long, confidence: Double ): PartialResult[java.util.Map[T, BoundedDouble]] = - rdd.countByValueApprox(timeout, confidence).map(mapAsJavaMap) + rdd.countByValueApprox(timeout, confidence).map(mapAsSerializableJavaMap) /** * (Experimental) Approximate version of countByValue(). */ def countByValueApprox(timeout: Long): PartialResult[java.util.Map[T, BoundedDouble]] = - rdd.countByValueApprox(timeout).map(mapAsJavaMap) + rdd.countByValueApprox(timeout).map(mapAsSerializableJavaMap) /** * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so @@ -491,9 +493,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } /** - * Returns the top K elements from this RDD as defined by + * Returns the top k (largest) elements from this RDD as defined by * the specified Comparator[T]. 
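Several of the hunks above route Java-facing results through `mapAsSerializableJavaMap`; a quick sketch of what that buys (SPARK-3926): the returned `java.util.Map` now survives Java serialization. The `JavaSparkContext` named `jsc` and the sample data are assumptions:

{{{
import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import org.apache.spark.api.java.JavaSparkContext

def countByKeySurvivesSerialization(jsc: JavaSparkContext): Unit = {
  val pairs = jsc.parallelizePairs(java.util.Arrays.asList(("a", 1), ("a", 2), ("b", 3)))
  val counts = pairs.countByKey()      // backed by SerializableMapWrapper after this change
  val oos = new ObjectOutputStream(new ByteArrayOutputStream())
  oos.writeObject(counts)              // used to fail: the plain JavaConversions wrapper
  oos.close()                          // is not java.io.Serializable
}
}}}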
- * @param num the number of top elements to return + * @param num k, the number of top elements to return * @param comp the comparator that defines the order * @return an array of top elements */ @@ -505,9 +507,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } /** - * Returns the top K elements from this RDD using the + * Returns the top k (largest) elements from this RDD using the * natural ordering for T. - * @param num the number of top elements to return + * @param num k, the number of top elements to return * @return an array of top elements */ def top(num: Int): JList[T] = { @@ -516,9 +518,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } /** - * Returns the first K elements from this RDD as defined by + * Returns the first k (smallest) elements from this RDD as defined by * the specified Comparator[T] and maintains the order. - * @param num the number of top elements to return + * @param num k, the number of elements to return * @param comp the comparator that defines the order * @return an array of top elements */ @@ -550,9 +552,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } /** - * Returns the first K elements from this RDD using the + * Returns the first k (smallest) elements from this RDD using the * natural ordering for T while maintain the order. - * @param num the number of top elements to return + * @param num k, the number of top elements to return * @return an array of top elements */ def takeOrdered(num: Int): JList[T] = { @@ -575,16 +577,44 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def name(): String = rdd.name /** - * :: Experimental :: - * The asynchronous version of the foreach action. - * - * @param f the function to apply to all the elements of the RDD - * @return a FutureAction for the action + * The asynchronous version of `count`, which returns a + * future for counting the number of elements in this RDD. */ - @Experimental - def foreachAsync(f: VoidFunction[T]): FutureAction[Unit] = { - import org.apache.spark.SparkContext._ - rdd.foreachAsync(x => f.call(x)) + def countAsync(): JavaFutureAction[JLong] = { + new JavaFutureActionWrapper[Long, JLong](rdd.countAsync(), JLong.valueOf) + } + + /** + * The asynchronous version of `collect`, which returns a future for + * retrieving an array containing all of the elements in this RDD. + */ + def collectAsync(): JavaFutureAction[JList[T]] = { + new JavaFutureActionWrapper(rdd.collectAsync(), (x: Seq[T]) => x.asJava) + } + + /** + * The asynchronous version of the `take` action, which returns a + * future for retrieving the first `num` elements of this RDD. + */ + def takeAsync(num: Int): JavaFutureAction[JList[T]] = { + new JavaFutureActionWrapper(rdd.takeAsync(num), (x: Seq[T]) => x.asJava) } + /** + * The asynchronous version of the `foreach` action, which + * applies a function f to all the elements of this RDD. + */ + def foreachAsync(f: VoidFunction[T]): JavaFutureAction[Void] = { + new JavaFutureActionWrapper[Unit, Void](rdd.foreachAsync(x => f.call(x)), + { x => null.asInstanceOf[Void] }) + } + + /** + * The asynchronous version of the `foreachPartition` action, which + * applies a function f to each partition of this RDD. 
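A short sketch of the clarified `top` / `takeOrdered` semantics, shown with the Scala RDD API for brevity (a SparkContext `sc` is assumed): `top(k)` returns the k largest elements, `takeOrdered(k)` the k smallest, and a custom ordering redefines what "largest" means.

{{{
val nums = sc.parallelize(Seq(5, 1, 4, 2, 3))
nums.top(2)                          // Array(5, 4)  -- k largest, in descending order
nums.takeOrdered(2)                  // Array(1, 2)  -- k smallest, in ascending order
nums.top(2)(Ordering.Int.reverse)    // Array(1, 2)  -- "largest" under the reversed ordering
}}}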
+ */ + def foreachPartitionAsync(f: VoidFunction[java.util.Iterator[T]]): JavaFutureAction[Void] = { + new JavaFutureActionWrapper[Unit, Void](rdd.foreachPartitionAsync(x => f.call(x)), + { x => null.asInstanceOf[Void] }) + } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 791d853a015a1..97f5c9f257e09 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -28,11 +28,13 @@ import scala.reflect.ClassTag import com.google.common.base.Optional import org.apache.hadoop.conf.Configuration +import org.apache.spark.input.PortableDataStream import org.apache.hadoop.mapred.{InputFormat, JobConf} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark._ -import org.apache.spark.SparkContext.{DoubleAccumulatorParam, IntAccumulatorParam} +import org.apache.spark.AccumulatorParam._ +import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD} @@ -40,6 +42,9 @@ import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD} /** * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns * [[org.apache.spark.api.java.JavaRDD]]s and works with Java collections instead of Scala ones. + * + * Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before + * creating a new one. This limitation may eventually be removed; see SPARK-2243 for more details. */ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround with Closeable { @@ -103,6 +108,8 @@ class JavaSparkContext(val sc: SparkContext) private[spark] val env = sc.env + def statusTracker = new JavaSparkStatusTracker(sc) + def isLocal: java.lang.Boolean = sc.isLocal def sparkUser: String = sc.sparkUser @@ -183,6 +190,8 @@ class JavaSparkContext(val sc: SparkContext) def textFile(path: String, minPartitions: Int): JavaRDD[String] = sc.textFile(path, minPartitions) + + /** * Read a directory of text files from HDFS, a local file system (available on all nodes), or any * Hadoop-supported file system URI. Each file is read as a single record and returned in a @@ -196,7 +205,10 @@ class JavaSparkContext(val sc: SparkContext) * hdfs://a-hdfs-path/part-nnnnn * }}} * - * Do `JavaPairRDD rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")`, + * Do + * {{{ + * JavaPairRDD rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path") + * }}} * *
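The new async actions above return `JavaFutureAction`s, which implement `java.util.concurrent.Future` (an assumption, since that interface is not shown in this hunk), so callers can poll, block with a timeout, or cancel the underlying job. A sketch assuming a `JavaSparkContext` named `jsc`:

{{{
import java.util.concurrent.TimeUnit
import org.apache.spark.api.java.JavaSparkContext

def asyncDemo(jsc: JavaSparkContext): Unit = {
  val rdd = jsc.parallelize(java.util.Arrays.asList(1, 2, 3, 4, 5))
  val count = rdd.countAsync()               // JavaFutureAction[java.lang.Long]
  val firstTwo = rdd.takeAsync(2)            // JavaFutureAction[java.util.List[...]]
  println(count.get(30, TimeUnit.SECONDS))   // 5
  println(firstTwo.get())                    // [1, 2]
}
}}}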

then `rdd` contains * {{{ @@ -223,6 +235,84 @@ class JavaSparkContext(val sc: SparkContext) def wholeTextFiles(path: String): JavaPairRDD[String, String] = new JavaPairRDD(sc.wholeTextFiles(path)) + /** + * Read a directory of binary files from HDFS, a local file system (available on all nodes), + * or any Hadoop-supported file system URI as a byte array. Each file is read as a single + * record and returned in a key-value pair, where the key is the path of each file, + * the value is the content of each file. + * + * For example, if you have the following files: + * {{{ + * hdfs://a-hdfs-path/part-00000 + * hdfs://a-hdfs-path/part-00001 + * ... + * hdfs://a-hdfs-path/part-nnnnn + * }}} + * + * Do + * `JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * + * then `rdd` contains + * {{{ + * (a-hdfs-path/part-00000, its content) + * (a-hdfs-path/part-00001, its content) + * ... + * (a-hdfs-path/part-nnnnn, its content) + * }}} + * + * @note Small files are preferred; very large files but may cause bad performance. + * + * @param minPartitions A suggestion value of the minimal splitting number for input data. + */ + def binaryFiles(path: String, minPartitions: Int): JavaPairRDD[String, PortableDataStream] = + new JavaPairRDD(sc.binaryFiles(path, minPartitions)) + + /** + * :: Experimental :: + * + * Read a directory of binary files from HDFS, a local file system (available on all nodes), + * or any Hadoop-supported file system URI as a byte array. Each file is read as a single + * record and returned in a key-value pair, where the key is the path of each file, + * the value is the content of each file. + * + * For example, if you have the following files: + * {{{ + * hdfs://a-hdfs-path/part-00000 + * hdfs://a-hdfs-path/part-00001 + * ... + * hdfs://a-hdfs-path/part-nnnnn + * }}} + * + * Do + * `JavaPairRDD rdd = sparkContext.dataStreamFiles("hdfs://a-hdfs-path")`, + * + * then `rdd` contains + * {{{ + * (a-hdfs-path/part-00000, its content) + * (a-hdfs-path/part-00001, its content) + * ... + * (a-hdfs-path/part-nnnnn, its content) + * }}} + * + * @note Small files are preferred; very large files but may cause bad performance. + */ + @Experimental + def binaryFiles(path: String): JavaPairRDD[String, PortableDataStream] = + new JavaPairRDD(sc.binaryFiles(path, defaultMinPartitions)) + + /** + * :: Experimental :: + * + * Load data from a flat binary file, assuming the length of each record is constant. + * + * @param path Directory to the input data files + * @return An RDD of data with values, represented as byte arrays + */ + @Experimental + def binaryRecords(path: String, recordLength: Int): JavaRDD[Array[Byte]] = { + new JavaRDD(sc.binaryRecords(path, recordLength)) + } + /** Get an RDD for a Hadoop SequenceFile with given key and value types. * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala new file mode 100644 index 0000000000000..3300cad9efbab --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
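A sketch of the new binary-input entry points against the underlying Scala SparkContext (`sc` and the paths are hypothetical): `binaryFiles` keys each whole file by its path and hands back a lazy `PortableDataStream`, while `binaryRecords` slices a file into fixed-length records.

{{{
val blobs = sc.binaryFiles("hdfs://a-hdfs-path")       // RDD[(String, PortableDataStream)]
val sizes = blobs.map { case (file, stream) =>         // a file is materialized only when needed
  (file, stream.toArray().length)
}
val records = sc.binaryRecords("hdfs://a-hdfs-path/fixed-width.bin", recordLength = 16)
println(records.count())                               // number of 16-byte records
}}}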
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java + +import org.apache.spark.{SparkStageInfo, SparkJobInfo, SparkContext} + +/** + * Low-level status reporting APIs for monitoring job and stage progress. + * + * These APIs intentionally provide very weak consistency semantics; consumers of these APIs should + * be prepared to handle empty / missing information. For example, a job's stage ids may be known + * but the status API may not have any information about the details of those stages, so + * `getStageInfo` could potentially return `null` for a valid stage id. + * + * To limit memory usage, these APIs only provide information on recent jobs / stages. These APIs + * will provide information for the last `spark.ui.retainedStages` stages and + * `spark.ui.retainedJobs` jobs. + * + * NOTE: this class's constructor should be considered private and may be subject to change. + */ +class JavaSparkStatusTracker private[spark] (sc: SparkContext) { + + /** + * Return a list of all known jobs in a particular job group. If `jobGroup` is `null`, then + * returns all known jobs that are not associated with a job group. + * + * The returned list may contain running, failed, and completed jobs, and may vary across + * invocations of this method. This method does not guarantee the order of the elements in + * its result. + */ + def getJobIdsForGroup(jobGroup: String): Array[Int] = sc.statusTracker.getJobIdsForGroup(jobGroup) + + /** + * Returns an array containing the ids of all active stages. + * + * This method does not guarantee the order of the elements in its result. + */ + def getActiveStageIds(): Array[Int] = sc.statusTracker.getActiveStageIds() + + /** + * Returns an array containing the ids of all active jobs. + * + * This method does not guarantee the order of the elements in its result. + */ + def getActiveJobIds(): Array[Int] = sc.statusTracker.getActiveJobIds() + + /** + * Returns job information, or `null` if the job info could not be found or was garbage collected. + */ + def getJobInfo(jobId: Int): SparkJobInfo = sc.statusTracker.getJobInfo(jobId).orNull + + /** + * Returns stage information, or `null` if the stage info could not be found or was + * garbage collected. 
+ */ + def getStageInfo(stageId: Int): SparkStageInfo = sc.statusTracker.getStageInfo(stageId).orNull +} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala index 22810cb1c662d..b52d0a5028e84 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala @@ -19,10 +19,20 @@ package org.apache.spark.api.java import com.google.common.base.Optional +import scala.collection.convert.Wrappers.MapWrapper + private[spark] object JavaUtils { def optionToOptional[T](option: Option[T]): Optional[T] = option match { case Some(value) => Optional.of(value) case None => Optional.absent() } + + // Workaround for SPARK-3926 / SI-8911 + def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]) = + new SerializableMapWrapper(underlying) + + class SerializableMapWrapper[A, B](underlying: collection.Map[A, B]) + extends MapWrapper(underlying) with java.io.Serializable + } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index 49dc95f349eac..5ba66178e2b78 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -61,8 +61,7 @@ private[python] object Converter extends Logging { * Other objects are passed through without conversion. */ private[python] class WritableToJavaConverter( - conf: Broadcast[SerializableWritable[Configuration]], - batchSize: Int) extends Converter[Any, Any] { + conf: Broadcast[SerializableWritable[Configuration]]) extends Converter[Any, Any] { /** * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or @@ -94,8 +93,7 @@ private[python] class WritableToJavaConverter( map.put(convertWritable(k), convertWritable(v)) } map - case w: Writable => - if (batchSize > 1) WritableUtils.clone(w, conf.value.value) else w + case w: Writable => WritableUtils.clone(w, conf.value.value) case other => other } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 924141475383d..e0bc00e1eb249 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -19,16 +19,15 @@ package org.apache.spark.api.python import java.io._ import java.net._ -import java.nio.charset.Charset -import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} +import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections} + +import org.apache.spark.input.PortableDataStream import scala.collection.JavaConversions._ import scala.collection.mutable import scala.language.existentials -import scala.reflect.ClassTag -import scala.util.{Try, Success, Failure} -import net.razorvine.pickle.{Pickler, Unpickler} +import com.google.common.base.Charsets.UTF_8 import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.compress.CompressionCodec @@ -42,22 +41,22 @@ import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils private[spark] class PythonRDD( - parent: RDD[_], + @transient parent: RDD[_], command: Array[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], preservePartitoning: Boolean, pythonExec: String, - broadcastVars: JList[Broadcast[Array[Byte]]], + 
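Given the weak-consistency contract of the `JavaSparkStatusTracker` introduced above, every lookup has to be treated as best-effort. A usage sketch, assuming a `JavaSparkContext` `jsc` with jobs in flight; the `SparkJobInfo` / `SparkStageInfo` accessors used below are assumptions, since they are not shown in this hunk:

{{{
val tracker = jsc.statusTracker
for (jobId <- tracker.getActiveJobIds()) {
  val jobInfo = tracker.getJobInfo(jobId)
  if (jobInfo != null) {                         // info may already have been evicted
    println(s"job $jobId: ${jobInfo.status()}")
    for (stageId <- jobInfo.stageIds()) {
      val stageInfo = tracker.getStageInfo(stageId)
      if (stageInfo != null) {                   // stage ids can be known before their details
        println(s"  stage $stageId: ${stageInfo.numCompletedTasks()}/${stageInfo.numTasks()} tasks")
      }
    }
  }
}
}}}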
broadcastVars: JList[Broadcast[PythonBroadcast]], accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent) { val bufferSize = conf.getInt("spark.buffer.size", 65536) val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true) - override def getPartitions = parent.partitions + override def getPartitions = firstParent.partitions - override val partitioner = if (preservePartitoning) parent.partitioner else None + override val partitioner = if (preservePartitoning) firstParent.partitioner else None override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis @@ -76,6 +75,7 @@ private[spark] class PythonRDD( var complete_cleanly = false context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() + writerThread.join() if (reuse_worker && complete_cleanly) { env.releasePythonWorker(pythonExec, envVars.toMap, worker) } else { @@ -134,7 +134,7 @@ private[spark] class PythonRDD( val exLength = stream.readInt() val obj = new Array[Byte](exLength) stream.readFully(obj) - throw new PythonException(new String(obj, "utf-8"), + throw new PythonException(new String(obj, UTF_8), writerThread.exception.getOrElse(null)) case SpecialLengths.END_OF_DATA_SECTION => // We've finished the data section of the output, but we can still @@ -146,7 +146,9 @@ private[spark] class PythonRDD( stream.readFully(update) accumulator += Collections.singletonList(update) } - complete_cleanly = true + if (stream.readInt() == SpecialLengths.END_OF_STREAM) { + complete_cleanly = true + } null } } catch { @@ -155,6 +157,10 @@ private[spark] class PythonRDD( logDebug("Exception thrown after task interruption", e) throw new TaskKilledException + case e: Exception if env.isStopped => + logDebug("Exception thrown after context is stopped", e) + null // exit silently + case e: Exception if writerThread.exception.isDefined => logError("Python worker exited unexpectedly (crashed)", e) logError("This may have been caused by a prior exception:", writerThread.exception.get) @@ -196,7 +202,6 @@ private[spark] class PythonRDD( override def run(): Unit = Utils.logUncaughtExceptions { try { - SparkEnv.set(env) val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) // Partition index @@ -225,8 +230,7 @@ private[spark] class PythonRDD( if (!oldBids.contains(broadcast.id)) { // send new broadcast dataOut.writeLong(broadcast.id) - dataOut.writeInt(broadcast.value.length) - dataOut.write(broadcast.value) + PythonRDD.writeUTF(broadcast.value.path, dataOut) oldBids.add(broadcast.id) } } @@ -235,8 +239,9 @@ private[spark] class PythonRDD( dataOut.writeInt(command.length) dataOut.write(command) // Data values - PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut) + PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + dataOut.writeInt(SpecialLengths.END_OF_STREAM) dataOut.flush() } catch { case e: Exception if context.isCompleted || context.isInterrupted => @@ -248,6 +253,11 @@ private[spark] class PythonRDD( // will kill the whole executor (see org.apache.spark.executor.Executor). 
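The `END_OF_STREAM` handshake added above is what lets the JVM side decide whether a Python worker finished cleanly and can be reused. A self-contained sketch of the framing with plain JVM streams; the constants mirror `SpecialLengths`, everything else is illustrative:

{{{
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

val END_OF_DATA_SECTION = -1
val END_OF_STREAM = -4

val buf = new ByteArrayOutputStream()
val out = new DataOutputStream(buf)
val payload = "hello".getBytes("UTF-8")
out.writeInt(payload.length); out.write(payload)    // an ordinary length-prefixed record
out.writeInt(END_OF_DATA_SECTION)                    // no more data records
out.writeInt(END_OF_STREAM)                          // writer side shut down cleanly
out.flush()

val in = new DataInputStream(new ByteArrayInputStream(buf.toByteArray))
var len = in.readInt()
while (len >= 0) { val rec = new Array[Byte](len); in.readFully(rec); len = in.readInt() }
val completeCleanly = len == END_OF_DATA_SECTION && in.readInt() == END_OF_STREAM
// Only a worker that produced the trailing END_OF_STREAM marker is put back in the pool.
}}}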
_exception = e worker.shutdownOutput() + } finally { + // Release memory used by this thread for shuffles + env.shuffleMemoryManager.releaseMemoryForThisThread() + // Release memory used by this thread for unrolling blocks + env.blockManager.memoryStore.releaseUnrollMemoryForThisThread() } } } @@ -303,10 +313,10 @@ private object SpecialLengths { val END_OF_DATA_SECTION = -1 val PYTHON_EXCEPTION_THROWN = -2 val TIMING_DATA = -3 + val END_OF_STREAM = -4 } private[spark] object PythonRDD extends Logging { - val UTF8 = Charset.forName("UTF-8") // remember the broadcasts sent to each worker private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() @@ -357,16 +367,8 @@ private[spark] object PythonRDD extends Logging { } } - def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = { - val file = new DataInputStream(new FileInputStream(filename)) - try { - val length = file.readInt() - val obj = new Array[Byte](length) - file.readFully(obj) - sc.broadcast(obj) - } finally { - file.close() - } + def readBroadcastFromFile(sc: JavaSparkContext, path: String): Broadcast[PythonBroadcast] = { + sc.broadcast(new PythonBroadcast(path)) } def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) { @@ -386,22 +388,33 @@ private[spark] object PythonRDD extends Logging { newIter.asInstanceOf[Iterator[String]].foreach { str => writeUTF(str, dataOut) } - case pair: Tuple2[_, _] => - pair._1 match { - case bytePair: Array[Byte] => - newIter.asInstanceOf[Iterator[Tuple2[Array[Byte], Array[Byte]]]].foreach { pair => - dataOut.writeInt(pair._1.length) - dataOut.write(pair._1) - dataOut.writeInt(pair._2.length) - dataOut.write(pair._2) - } - case stringPair: String => - newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair => - writeUTF(pair._1, dataOut) - writeUTF(pair._2, dataOut) - } - case other => - throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass) + case stream: PortableDataStream => + newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream => + val bytes = stream.toArray() + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + } + case (key: String, stream: PortableDataStream) => + newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach { + case (key, stream) => + writeUTF(key, dataOut) + val bytes = stream.toArray() + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + } + case (key: String, value: String) => + newIter.asInstanceOf[Iterator[(String, String)]].foreach { + case (key, value) => + writeUTF(key, dataOut) + writeUTF(value, dataOut) + } + case (key: Array[Byte], value: Array[Byte]) => + newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach { + case (key, value) => + dataOut.writeInt(key.length) + dataOut.write(key) + dataOut.writeInt(value.length) + dataOut.write(value) } case other => throw new SparkException("Unexpected element type " + first.getClass) @@ -431,7 +444,7 @@ private[spark] object PythonRDD extends Logging { val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration())) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -457,7 +470,7 @@ private[spark] object PythonRDD extends Logging { Some(path), inputFormatClass, keyClass, valueClass, 
mergedConf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -483,7 +496,7 @@ private[spark] object PythonRDD extends Logging { None, inputFormatClass, keyClass, valueClass, conf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -526,7 +539,7 @@ private[spark] object PythonRDD extends Logging { Some(path), inputFormatClass, keyClass, valueClass, mergedConf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -552,7 +565,7 @@ private[spark] object PythonRDD extends Logging { None, inputFormatClass, keyClass, valueClass, conf) val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf)) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, - new WritableToJavaConverter(confBroadcasted, batchSize)) + new WritableToJavaConverter(confBroadcasted)) JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize)) } @@ -574,7 +587,7 @@ private[spark] object PythonRDD extends Logging { } def writeUTF(str: String, dataOut: DataOutputStream) { - val bytes = str.getBytes(UTF8) + val bytes = str.getBytes(UTF_8) dataOut.writeInt(bytes.length) dataOut.write(bytes) } @@ -735,107 +748,11 @@ private[spark] object PythonRDD extends Logging { converted.saveAsHadoopDataset(new JobConf(conf)) } } - - - /** - * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions). - */ - @deprecated("PySpark does not use it anymore", "1.1") - def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { - pyRDD.rdd.mapPartitions { iter => - val unpickle = new Unpickler - iter.flatMap { row => - unpickle.loads(row) match { - // in case of objects are pickled in batch mode - case objs: JArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap) - // not in batch mode - case obj: JMap[String @unchecked, _] => Seq(obj.toMap) - } - } - } - } - - /** - * Convert an RDD of serialized Python tuple to Array (no recursive conversions). - * It is only used by pyspark.sql. 
- */ - def pythonToJavaArray(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Array[_]] = { - - def toArray(obj: Any): Array[_] = { - obj match { - case objs: JArrayList[_] => - objs.toArray - case obj if obj.getClass.isArray => - obj.asInstanceOf[Array[_]].toArray - } - } - - pyRDD.rdd.mapPartitions { iter => - val unpickle = new Unpickler - iter.flatMap { row => - val obj = unpickle.loads(row) - if (batched) { - obj.asInstanceOf[JArrayList[_]].map(toArray) - } else { - Seq(toArray(obj)) - } - } - }.toJavaRDD() - } - - private class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] { - private val pickle = new Pickler() - private var batch = 1 - private val buffer = new mutable.ArrayBuffer[Any] - - override def hasNext(): Boolean = iter.hasNext - - override def next(): Array[Byte] = { - while (iter.hasNext && buffer.length < batch) { - buffer += iter.next() - } - val bytes = pickle.dumps(buffer.toArray) - val size = bytes.length - // let 1M < size < 10M - if (size < 1024 * 1024) { - batch *= 2 - } else if (size > 1024 * 1024 * 10 && batch > 1) { - batch /= 2 - } - buffer.clear() - bytes - } - } - - /** - * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by - * PySpark. - */ - def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = { - jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) } - } - - /** - * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark. - */ - def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { - pyRDD.rdd.mapPartitions { iter => - val unpickle = new Unpickler - iter.flatMap { row => - val obj = unpickle.loads(row) - if (batched) { - obj.asInstanceOf[JArrayList[_]] - } else { - Seq(obj) - } - } - }.toJavaRDD() - } } private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] { - override def call(arr: Array[Byte]) : String = new String(arr, PythonRDD.UTF8) + override def call(arr: Array[Byte]) : String = new String(arr, UTF_8) } /** @@ -890,3 +807,49 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: } } } + +/** + * An Wrapper for Python Broadcast, which is written into disk by Python. It also will + * write the data into disk after deserialization, then Python can read it from disks. + */ +private[spark] class PythonBroadcast(@transient var path: String) extends Serializable { + + /** + * Read data from disks, then copy it to `out` + */ + private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { + val in = new FileInputStream(new File(path)) + try { + Utils.copyStream(in, out) + } finally { + in.close() + } + } + + /** + * Write data into disk, using randomly generated name. + */ + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { + val dir = new File(Utils.getLocalDir(SparkEnv.get.conf)) + val file = File.createTempFile("broadcast", "", dir) + path = file.getAbsolutePath + val out = new FileOutputStream(file) + try { + Utils.copyStream(in, out) + } finally { + out.close() + } + } + + /** + * Delete the file once the object is GCed. 
+ */ + override def finalize() { + if (!path.isEmpty) { + val file = new File(path) + if (file.exists()) { + file.delete() + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 71bdf0fe1b917..e314408c067e9 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -108,10 +108,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1))) // Create and start the worker - val pb = new ProcessBuilder(Seq(pythonExec, "-u", "-m", "pyspark.worker")) + val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.worker")) val workerEnv = pb.environment() workerEnv.putAll(envVars) workerEnv.put("PYTHONPATH", pythonPath) + // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: + workerEnv.put("PYTHONUNBUFFERED", "YES") val worker = pb.start() // Redirect worker stdout and stderr @@ -149,10 +151,12 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String try { // Create and start the daemon - val pb = new ProcessBuilder(Seq(pythonExec, "-u", "-m", "pyspark.daemon")) + val pb = new ProcessBuilder(Seq(pythonExec, "-m", "pyspark.daemon")) val workerEnv = pb.environment() workerEnv.putAll(envVars) workerEnv.put("PYTHONPATH", pythonPath) + // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: + workerEnv.put("PYTHONUNBUFFERED", "YES") daemon = pb.start() val in = new DataInputStream(daemon.getInputStream) diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index 7903457b17e13..a4153aaa926f8 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -18,8 +18,13 @@ package org.apache.spark.api.python import java.nio.ByteOrder +import java.util.{ArrayList => JArrayList} + +import org.apache.spark.api.java.JavaRDD import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.util.Failure import scala.util.Try @@ -29,7 +34,7 @@ import org.apache.spark.{Logging, SparkException} import org.apache.spark.rdd.RDD /** Utilities for serialization / deserialization between Python and Java, using Pickle. */ -private[python] object SerDeUtil extends Logging { +private[spark] object SerDeUtil extends Logging { // Unpickle array.array generated by Python 2.6 class ArrayConstructor extends net.razorvine.pickle.objects.ArrayConstructor { // /* Description of types */ @@ -76,8 +81,84 @@ private[python] object SerDeUtil extends Logging { } } + private var initialized = false + // This should be called before trying to unpickle array.array from Python + // In cluster mode, this should be put in closure def initialize() = { - Unpickler.registerConstructor("array", "array", new ArrayConstructor()) + synchronized{ + if (!initialized) { + Unpickler.registerConstructor("array", "array", new ArrayConstructor()) + initialized = true + } + } + } + initialize() + + + /** + * Convert an RDD of Java objects to Array (no recursive conversions). + * It is only used by pyspark.sql. 
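The worker/daemon launch change above drops `-u` in favour of an environment variable, since an IPython executable does not accept `-u`. A minimal sketch of the launch pattern; the module name follows the patch, nothing else is prescriptive:

{{{
val pythonExec = sys.env.getOrElse("PYSPARK_PYTHON", "python")
val pb = new ProcessBuilder(java.util.Arrays.asList(pythonExec, "-m", "pyspark.daemon"))
pb.environment().put("PYTHONUNBUFFERED", "YES")   // equivalent to passing -u to CPython
val daemon = pb.start()
}}}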
+ */ + def toJavaArray(jrdd: JavaRDD[Any]): JavaRDD[Array[_]] = { + jrdd.rdd.map { + case objs: JArrayList[_] => + objs.toArray + case obj if obj.getClass.isArray => + obj.asInstanceOf[Array[_]].toArray + }.toJavaRDD() + } + + /** + * Choose batch size based on size of objects + */ + private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] { + private val pickle = new Pickler() + private var batch = 1 + private val buffer = new mutable.ArrayBuffer[Any] + + override def hasNext: Boolean = iter.hasNext + + override def next(): Array[Byte] = { + while (iter.hasNext && buffer.length < batch) { + buffer += iter.next() + } + val bytes = pickle.dumps(buffer.toArray) + val size = bytes.length + // let 1M < size < 10M + if (size < 1024 * 1024) { + batch *= 2 + } else if (size > 1024 * 1024 * 10 && batch > 1) { + batch /= 2 + } + buffer.clear() + bytes + } + } + + /** + * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by + * PySpark. + */ + private[spark] def javaToPython(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = { + jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) } + } + + /** + * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark. + */ + def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { + pyRDD.rdd.mapPartitions { iter => + initialize() + val unpickle = new Unpickler + iter.flatMap { row => + val obj = unpickle.loads(row) + if (batched) { + obj.asInstanceOf[JArrayList[_]].asScala + } else { + Seq(obj) + } + } + }.toJavaRDD() } private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = { @@ -119,17 +200,18 @@ private[python] object SerDeUtil extends Logging { */ def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = { val (keyFailed, valueFailed) = checkPickle(rdd.first()) + rdd.mapPartitions { iter => - val pickle = new Pickler val cleaned = iter.map { case (k, v) => val key = if (keyFailed) k.toString else k val value = if (valueFailed) v.toString else v Array[Any](key, value) } - if (batchSize > 1) { - cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched))) + if (batchSize == 0) { + new AutoBatchedPickler(cleaned) } else { - cleaned.map(pickle.dumps(_)) + val pickle = new Pickler + cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched))) } } } @@ -137,35 +219,22 @@ private[python] object SerDeUtil extends Logging { /** * Convert an RDD of serialized Python tuple (K, V) to RDD[(K, V)]. 
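The `AutoBatchedPickler` above (now reachable from `pairRDDToPython` when `batchSize == 0`) sizes batches adaptively rather than pickling a fixed number of elements. Its growth/shrink rule, restated as a small pure function for clarity:

{{{
// Grow while a pickled batch stays under 1 MB, shrink when one exceeds 10 MB, else keep it.
def nextBatchSize(currentBatch: Int, lastBatchBytes: Int): Int = {
  if (lastBatchBytes < 1024 * 1024) currentBatch * 2
  else if (lastBatchBytes > 1024 * 1024 * 10 && currentBatch > 1) currentBatch / 2
  else currentBatch
}

assert(nextBatchSize(4, 200 * 1024) == 8)             // small batches double
assert(nextBatchSize(4, 20 * 1024 * 1024) == 2)       // oversized batches halve
}}}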
*/ - def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batchSerialized: Boolean): RDD[(K, V)] = { + def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batched: Boolean): RDD[(K, V)] = { def isPair(obj: Any): Boolean = { - Option(obj.getClass.getComponentType).map(!_.isPrimitive).getOrElse(false) && + Option(obj.getClass.getComponentType).exists(!_.isPrimitive) && obj.asInstanceOf[Array[_]].length == 2 } - pyRDD.mapPartitions { iter => - val unpickle = new Unpickler - val unpickled = - if (batchSerialized) { - iter.flatMap { batch => - unpickle.loads(batch) match { - case objs: java.util.List[_] => collectionAsScalaIterable(objs) - case other => throw new SparkException( - s"Unexpected type ${other.getClass.getName} for batch serialized Python RDD") - } - } - } else { - iter.map(unpickle.loads(_)) - } - unpickled.map { - case obj if isPair(obj) => - // we only accept (K, V) - val arr = obj.asInstanceOf[Array[_]] - (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V]) - case other => throw new SparkException( - s"RDD element of type ${other.getClass.getName} cannot be used") - } + + val rdd = pythonToJava(pyRDD, batched).rdd + rdd.first match { + case obj if isPair(obj) => + // we only accept (K, V) + case other => throw new SparkException( + s"RDD element of type ${other.getClass.getName} cannot be used") + } + rdd.map { obj => + val arr = obj.asInstanceOf[Array[_]] + (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V]) } } - } - diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index d11db978b842e..c0cbd28a845be 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -18,7 +18,8 @@ package org.apache.spark.api.python import java.io.{DataOutput, DataInput} -import java.nio.charset.Charset + +import com.google.common.base.Charsets.UTF_8 import org.apache.hadoop.io._ import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat @@ -136,7 +137,7 @@ object WriteInputFormatTestDataGenerator { sc.parallelize(intKeys).saveAsSequenceFile(intPath) sc.parallelize(intKeys.map{ case (k, v) => (k.toDouble, v) }).saveAsSequenceFile(doublePath) sc.parallelize(intKeys.map{ case (k, v) => (k.toString, v) }).saveAsSequenceFile(textPath) - sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes(Charset.forName("UTF-8"))) } + sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes(UTF_8)) } ).saveAsSequenceFile(bytesPath) val bools = Seq((1, true), (2, true), (2, false), (3, true), (2, false), (1, false)) sc.parallelize(bools).saveAsSequenceFile(boolPath) @@ -175,11 +176,11 @@ object WriteInputFormatTestDataGenerator { // Create test data for arbitrary custom writable TestWritable val testClass = Seq( - ("1", TestWritable("test1", 123, 54.0)), - ("2", TestWritable("test2", 456, 8762.3)), - ("1", TestWritable("test3", 123, 423.1)), - ("3", TestWritable("test56", 456, 423.5)), - ("2", TestWritable("test2", 123, 5435.2)) + ("1", TestWritable("test1", 1, 1.0)), + ("2", TestWritable("test2", 2, 2.3)), + ("3", TestWritable("test3", 3, 3.1)), + ("5", TestWritable("test56", 5, 5.5)), + ("4", TestWritable("test4", 4, 4.2)) ) val rdd = sc.parallelize(testClass, numSlices = 2).map{ case (k, v) => (new Text(k), v) } rdd.saveAsNewAPIHadoopFile(classPath, diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala 
b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index 15fd30e65761d..a5ea478f231d7 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -20,6 +20,8 @@ package org.apache.spark.broadcast import java.io.Serializable import org.apache.spark.SparkException +import org.apache.spark.Logging +import org.apache.spark.util.Utils import scala.reflect.ClassTag @@ -37,7 +39,7 @@ import scala.reflect.ClassTag * * {{{ * scala> val broadcastVar = sc.broadcast(Array(1, 2, 3)) - * broadcastVar: spark.Broadcast[Array[Int]] = spark.Broadcast(b5c40191-a864-4c7d-b9bf-d87e1a4e787c) + * broadcastVar: org.apache.spark.broadcast.Broadcast[Array[Int]] = Broadcast(0) * * scala> broadcastVar.value * res0: Array[Int] = Array(1, 2, 3) @@ -52,7 +54,7 @@ import scala.reflect.ClassTag * @param id A unique identifier for the broadcast variable. * @tparam T Type of the data contained in the broadcast variable. */ -abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable { +abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable with Logging { /** * Flag signifying whether the broadcast variable is valid @@ -60,6 +62,8 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable { */ @volatile private var _isValid = true + private var _destroySite = "" + /** Get the broadcasted value. */ def value: T = { assertValid() @@ -84,13 +88,26 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable { doUnpersist(blocking) } + + /** + * Destroy all data and metadata related to this broadcast variable. Use this with caution; + * once a broadcast variable has been destroyed, it cannot be used again. + * This method blocks until destroy has completed + */ + def destroy() { + destroy(blocking = true) + } + /** * Destroy all data and metadata related to this broadcast variable. Use this with caution; * once a broadcast variable has been destroyed, it cannot be used again. + * @param blocking Whether to block until destroy has completed */ private[spark] def destroy(blocking: Boolean) { assertValid() _isValid = false + _destroySite = Utils.getCallSite().shortForm + logInfo("Destroying %s (from %s)".format(toString, _destroySite)) doDestroy(blocking) } @@ -124,7 +141,8 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable { /** Check if this broadcast is valid. If not valid, exception is thrown. */ protected def assertValid() { if (!_isValid) { - throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString)) + throw new SparkException( + "Attempted to use %s after it was destroyed (%s) ".format(toString, _destroySite)) } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 942dc7d7eac87..31f0a462f84d8 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -72,13 +72,13 @@ private[spark] class HttpBroadcast[T: ClassTag]( } /** Used by the JVM when serializing this object. */ - private def writeObject(out: ObjectOutputStream) { + private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { assertValid() out.defaultWriteObject() } /** Used by the JVM when deserializing this object. 
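A short sketch of the new public, blocking `destroy()` and the improved failure message (a SparkContext `sc` is assumed; the quoted message text is illustrative):

{{{
val bc = sc.broadcast(Array(1, 2, 3))
println(bc.value.sum)   // 6
bc.destroy()            // blocks until the broadcast's data and metadata are removed
try {
  bc.value
} catch {
  // "Attempted to use Broadcast(0) after it was destroyed (<call site>)"
  case e: org.apache.spark.SparkException => println(e.getMessage)
}
}}}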
*/ - private def readObject(in: ObjectInputStream) { + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { in.defaultReadObject() HttpBroadcast.synchronized { SparkEnv.get.blockManager.getSingle(blockId) match { @@ -163,18 +163,23 @@ private[broadcast] object HttpBroadcast extends Logging { private def write(id: Long, value: Any) { val file = getFile(id) - val out: OutputStream = { - if (compress) { - compressionCodec.compressedOutputStream(new FileOutputStream(file)) - } else { - new BufferedOutputStream(new FileOutputStream(file), bufferSize) + val fileOutputStream = new FileOutputStream(file) + try { + val out: OutputStream = { + if (compress) { + compressionCodec.compressedOutputStream(fileOutputStream) + } else { + new BufferedOutputStream(fileOutputStream, bufferSize) + } } + val ser = SparkEnv.get.serializer.newInstance() + val serOut = ser.serializeStream(out) + serOut.writeObject(value) + serOut.close() + files += file + } finally { + fileOutputStream.close() } - val ser = SparkEnv.get.serializer.newInstance() - val serOut = ser.serializeStream(out) - serOut.writeObject(value) - serOut.close() - files += file } private def read[T: ClassTag](id: Long): T = { @@ -186,10 +191,12 @@ private[broadcast] object HttpBroadcast extends Logging { logDebug("broadcast security enabled") val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager) uc = newuri.toURL.openConnection() + uc.setConnectTimeout(httpReadTimeout) uc.setAllowUserInteraction(false) } else { logDebug("broadcast not using security") uc = new URL(url).openConnection() + uc.setConnectTimeout(httpReadTimeout) } val in = { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 42d58682a1e23..94142d33369c7 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -26,8 +26,9 @@ import scala.util.Random import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} import org.apache.spark.io.CompressionCodec +import org.apache.spark.serializer.Serializer import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} -import org.apache.spark.util.ByteBufferInputStream +import org.apache.spark.util.{ByteBufferInputStream, Utils} import org.apache.spark.util.io.ByteArrayChunkOutputStream /** @@ -46,53 +47,66 @@ import org.apache.spark.util.io.ByteArrayChunkOutputStream * This prevents the driver from being the bottleneck in sending out multiple copies of the * broadcast data (one per executor) as done by the [[org.apache.spark.broadcast.HttpBroadcast]]. * + * When initialized, TorrentBroadcast objects read SparkEnv.get.conf. + * * @param obj object to broadcast - * @param isLocal whether Spark is running in local mode (single JVM process). * @param id A unique identifier for the broadcast variable. */ -private[spark] class TorrentBroadcast[T: ClassTag]( - obj : T, - @transient private val isLocal: Boolean, - id: Long) +private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) extends Broadcast[T](id) with Logging with Serializable { /** - * Value of the broadcast object. On driver, this is set directly by the constructor. - * On executors, this is reconstructed by [[readObject]], which builds this value by reading - * blocks from the driver and/or other executors. + * Value of the broadcast object on executors. 
This is reconstructed by [[readBroadcastBlock]], + * which builds this value by reading blocks from the driver and/or other executors. + * + * On the driver, if the value is required, it is read lazily from the block manager. */ - @transient private var _value: T = obj + @transient private lazy val _value: T = readBroadcastBlock() + + /** The compression codec to use, or None if compression is disabled */ + @transient private var compressionCodec: Option[CompressionCodec] = _ + /** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */ + @transient private var blockSize: Int = _ + + private def setConf(conf: SparkConf) { + compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) { + Some(CompressionCodec.createCodec(conf)) + } else { + None + } + blockSize = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 + } + setConf(SparkEnv.get.conf) private val broadcastId = BroadcastBlockId(id) /** Total number of blocks this broadcast variable contains. */ - private val numBlocks: Int = writeBlocks() + private val numBlocks: Int = writeBlocks(obj) - override protected def getValue() = _value + override protected def getValue() = { + _value + } /** * Divide the object into multiple blocks and put those blocks in the block manager. - * + * @param value the object to divide * @return number of blocks this broadcast variable is divided into */ - private def writeBlocks(): Int = { - // For local mode, just put the object in the BlockManager so we can find it later. - SparkEnv.get.blockManager.putSingle( - broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false) - - if (!isLocal) { - val blocks = TorrentBroadcast.blockifyObject(_value) - blocks.zipWithIndex.foreach { case (block, i) => - SparkEnv.get.blockManager.putBytes( - BroadcastBlockId(id, "piece" + i), - block, - StorageLevel.MEMORY_AND_DISK_SER, - tellMaster = true) - } - blocks.length - } else { - 0 + private def writeBlocks(value: T): Int = { + // Store a copy of the broadcast variable in the driver so that tasks run on the driver + // do not create a duplicate copy of the broadcast variable's value. + SparkEnv.get.blockManager.putSingle(broadcastId, value, StorageLevel.MEMORY_AND_DISK, + tellMaster = false) + val blocks = + TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec) + blocks.zipWithIndex.foreach { case (block, i) => + SparkEnv.get.blockManager.putBytes( + BroadcastBlockId(id, "piece" + i), + block, + StorageLevel.MEMORY_AND_DISK_SER, + tellMaster = true) } + blocks.length } /** Fetch torrent blocks from the driver and/or other executors. */ @@ -104,29 +118,24 @@ private[spark] class TorrentBroadcast[T: ClassTag]( for (pid <- Random.shuffle(Seq.range(0, numBlocks))) { val pieceId = BroadcastBlockId(id, "piece" + pid) - - // First try getLocalBytes because there is a chance that previous attempts to fetch the + logDebug(s"Reading piece $pieceId of $broadcastId") + // First try getLocalBytes because there is a chance that previous attempts to fetch the // broadcast blocks have already fetched some of the blocks. In that case, some blocks // would be available locally (on this executor). - var blockOpt = bm.getLocalBytes(pieceId) - if (!blockOpt.isDefined) { - blockOpt = bm.getRemoteBytes(pieceId) - blockOpt match { - case Some(block) => - // If we found the block from remote executors/driver's BlockManager, put the block - // in this executor's BlockManager. 
- SparkEnv.get.blockManager.putBytes( - pieceId, - block, - StorageLevel.MEMORY_AND_DISK_SER, - tellMaster = true) - - case None => - throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) - } + def getLocal: Option[ByteBuffer] = bm.getLocalBytes(pieceId) + def getRemote: Option[ByteBuffer] = bm.getRemoteBytes(pieceId).map { block => + // If we found the block from remote executors/driver's BlockManager, put the block + // in this executor's BlockManager. + SparkEnv.get.blockManager.putBytes( + pieceId, + block, + StorageLevel.MEMORY_AND_DISK_SER, + tellMaster = true) + block } - // If we get here, the option is defined. - blocks(pid) = blockOpt.get + val block: ByteBuffer = getLocal.orElse(getRemote).getOrElse( + throw new SparkException(s"Failed to get $pieceId of $broadcastId")) + blocks(pid) = block } blocks } @@ -147,75 +156,62 @@ private[spark] class TorrentBroadcast[T: ClassTag]( } /** Used by the JVM when serializing this object. */ - private def writeObject(out: ObjectOutputStream) { + private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { assertValid() out.defaultWriteObject() } - /** Used by the JVM when deserializing this object. */ - private def readObject(in: ObjectInputStream) { - in.defaultReadObject() + private def readBroadcastBlock(): T = Utils.tryOrIOException { TorrentBroadcast.synchronized { + setConf(SparkEnv.get.conf) SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match { case Some(x) => - _value = x.asInstanceOf[T] + x.asInstanceOf[T] case None => logInfo("Started reading broadcast variable " + id) - val start = System.nanoTime() + val startTimeMs = System.currentTimeMillis() val blocks = readBlocks() - val time = (System.nanoTime() - start) / 1e9 - logInfo("Reading broadcast variable " + id + " took " + time + " s") + logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs)) - _value = TorrentBroadcast.unBlockifyObject[T](blocks) + val obj = TorrentBroadcast.unBlockifyObject[T]( + blocks, SparkEnv.get.serializer, compressionCodec) // Store the merged copy in BlockManager so other tasks on this executor don't // need to re-fetch it. SparkEnv.get.blockManager.putSingle( - broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + broadcastId, obj, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + obj } } } + } private object TorrentBroadcast extends Logging { - /** Size of each block. Default value is 4MB. 
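The rewritten block fetch above boils down to a "local first, then remote, otherwise fail" lookup built from lazily evaluated `Option`s. A stripped-down sketch of the same pattern; all names here are illustrative:

{{{
def fetchBlock[T](local: => Option[T], remote: => Option[T], what: String): T =
  local.orElse(remote).getOrElse(throw new RuntimeException(s"Failed to get $what"))

val localCache = Map("piece0" -> Array[Byte](1, 2, 3))
fetchBlock(localCache.get("piece0"), None, "piece0 of broadcast_0")   // hits the local copy;
                                                                      // the remote branch never runs
}}}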
*/ - private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 - private var initialized = false - private var conf: SparkConf = null - private var compress: Boolean = false - private var compressionCodec: CompressionCodec = null - - def initialize(_isDriver: Boolean, conf: SparkConf) { - TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests - synchronized { - if (!initialized) { - compress = conf.getBoolean("spark.broadcast.compress", true) - compressionCodec = CompressionCodec.createCodec(conf) - initialized = true - } - } - } - def stop() { - initialized = false - } - - def blockifyObject[T: ClassTag](obj: T): Array[ByteBuffer] = { - val bos = new ByteArrayChunkOutputStream(BLOCK_SIZE) - val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos - val ser = SparkEnv.get.serializer.newInstance() + def blockifyObject[T: ClassTag]( + obj: T, + blockSize: Int, + serializer: Serializer, + compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = { + val bos = new ByteArrayChunkOutputStream(blockSize) + val out: OutputStream = compressionCodec.map(c => c.compressedOutputStream(bos)).getOrElse(bos) + val ser = serializer.newInstance() val serOut = ser.serializeStream(out) serOut.writeObject[T](obj).close() bos.toArrays.map(ByteBuffer.wrap) } - def unBlockifyObject[T: ClassTag](blocks: Array[ByteBuffer]): T = { + def unBlockifyObject[T: ClassTag]( + blocks: Array[ByteBuffer], + serializer: Serializer, + compressionCodec: Option[CompressionCodec]): T = { + require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks") val is = new SequenceInputStream( asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block)))) - val in: InputStream = if (compress) compressionCodec.compressedInputStream(is) else is - - val ser = SparkEnv.get.serializer.newInstance() + val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is) + val ser = serializer.newInstance() val serIn = ser.deserializeStream(in) val obj = serIn.readObject[T]() serIn.close() @@ -227,6 +223,7 @@ private object TorrentBroadcast extends Logging { * If removeFromDriver is true, also remove these persisted blocks on the driver. 
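`blockifyObject` / `unBlockifyObject` now take the serializer and the optional codec as explicit parameters instead of reading global state, which makes a round trip easy to exercise. A sketch under the assumption that the code lives in `org.apache.spark.broadcast` (the companion object is package-private); compression is disabled here:

{{{
package org.apache.spark.broadcast

import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer

object BlockifyRoundTrip {
  def main(args: Array[String]): Unit = {
    val ser = new JavaSerializer(new SparkConf())
    val blocks = TorrentBroadcast.blockifyObject(Seq(1, 2, 3), 4 * 1024 * 1024, ser, None)
    val restored = TorrentBroadcast.unBlockifyObject[Seq[Int]](blocks, ser, None)
    assert(restored == Seq(1, 2, 3))
  }
}
}}}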
*/ def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = { + logDebug(s"Unpersisting TorrentBroadcast $id") SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index ad0f701d7a98f..fb024c12094f2 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -28,14 +28,13 @@ import org.apache.spark.{SecurityManager, SparkConf} */ class TorrentBroadcastFactory extends BroadcastFactory { - override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { - TorrentBroadcast.initialize(isDriver, conf) - } + override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { } - override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = - new TorrentBroadcast[T](value_, isLocal, id) + override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = { + new TorrentBroadcast[T](value_, id) + } - override def stop() { TorrentBroadcast.stop() } + override def stop() { } /** * Remove all persisted state associated with the torrent broadcast with the given ID. diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 065ddda50e65e..f2687ce6b42b4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -130,7 +130,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") System.exit(-1) - case AssociationErrorEvent(cause, _, remoteAddress, _) => + case AssociationErrorEvent(cause, _, remoteAddress, _, _) => println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") println(s"Cause was: $cause") System.exit(-1) diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 39150deab863c..2e1e52906ceeb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -17,6 +17,8 @@ package org.apache.spark.deploy +import java.net.{URI, URISyntaxException} + import scala.collection.mutable.ListBuffer import org.apache.log4j.Level @@ -73,7 +75,8 @@ private[spark] class ClientArguments(args: Array[String]) { if (!ClientArguments.isValidJarUrl(_jarUrl)) { println(s"Jar url '${_jarUrl}' is not in valid format.") - println(s"Must be a jar file path in URL format (e.g. hdfs://XX.jar, file://XX.jar)") + println(s"Must be a jar file path in URL format " + + "(e.g. 
hdfs://host:port/XX.jar, file:///XX.jar)") printUsageAndExit(-1) } @@ -114,5 +117,12 @@ private[spark] class ClientArguments(args: Array[String]) { } object ClientArguments { - def isValidJarUrl(s: String): Boolean = s.matches("(.+):(.+)jar") + def isValidJarUrl(s: String): Boolean = { + try { + val uri = new URI(s) + uri.getScheme != null && uri.getPath != null && uri.getPath.endsWith(".jar") + } catch { + case _: URISyntaxException => false + } + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index a7368f9f3dfbe..c46f84de8444a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -71,6 +71,8 @@ private[deploy] object DeployMessages { case class RegisterWorkerFailed(message: String) extends DeployMessage + case class ReconnectWorker(masterUrl: String) extends DeployMessage + case class KillExecutor(masterUrl: String, appId: String, execId: Int) extends DeployMessage case class LaunchExecutor( @@ -90,6 +92,8 @@ private[deploy] object DeployMessages { case object WorkDirCleanup // Sent to Worker actor periodically for cleaning up app folders + case object ReregisterWithMaster // used when a worker attempts to reconnect to a master + // AppClient to Master case class RegisterApplication(appDescription: ApplicationDescription) diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index 79b4d7ea41a33..039c8719e2867 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -34,7 +34,8 @@ object PythonRunner { val pythonFile = args(0) val pyFiles = args(1) val otherArgs = args.slice(2, args.length) - val pythonExec = sys.env.get("PYSPARK_PYTHON").getOrElse("python") // TODO: get this from conf + val pythonExec = + sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", sys.env.getOrElse("PYSPARK_PYTHON", "python")) // Format python file paths before adding them to the PYTHONPATH val formattedPythonFile = formatPath(pythonFile) @@ -57,6 +58,7 @@ object PythonRunner { val builder = new ProcessBuilder(Seq(pythonExec, formattedPythonFile) ++ otherArgs) val env = builder.environment() env.put("PYTHONPATH", pythonPath) + // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize @@ -85,8 +87,8 @@ object PythonRunner { // Strip the URI scheme from the path formattedPath = new URI(formattedPath).getScheme match { - case Utils.windowsDrive(d) if windows => formattedPath case null => formattedPath + case Utils.windowsDrive(d) if windows => formattedPath case _ => new URI(formattedPath).getPath } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index fe0ad9ebbca12..60ee115e393ce 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -17,15 +17,19 @@ package org.apache.spark.deploy +import java.lang.reflect.Method import java.security.PrivilegedExceptionAction import 
org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.FileSystem.Statistics import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.{Logging, SparkContext, SparkConf, SparkException} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.util.Utils import scala.collection.JavaConversions._ @@ -121,6 +125,64 @@ class SparkHadoopUtil extends Logging { UserGroupInformation.loginUserFromKeytab(principalName, keytabFilename) } + /** + * Returns a function that can be called to find Hadoop FileSystem bytes read. If + * getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will + * return the bytes read on r since t. Reflection is required because thread-level FileSystem + * statistics are only available as of Hadoop 2.5 (see HADOOP-10688). + * Returns None if the required method can't be found. + */ + private[spark] def getFSBytesReadOnThreadCallback(path: Path, conf: Configuration) + : Option[() => Long] = { + try { + val threadStats = getFileSystemThreadStatistics(path, conf) + val getBytesReadMethod = getFileSystemThreadStatisticsMethod("getBytesRead") + val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum + val baselineBytesRead = f() + Some(() => f() - baselineBytesRead) + } catch { + case e: NoSuchMethodException => { + logDebug("Couldn't find method for retrieving thread-level FileSystem input data", e) + None + } + } + } + + /** + * Returns a function that can be called to find Hadoop FileSystem bytes written. If + * getFSBytesWrittenOnThreadCallback is called from thread r at time t, the returned callback will + * return the bytes written on r since t. Reflection is required because thread-level FileSystem + * statistics are only available as of Hadoop 2.5 (see HADOOP-10688). + * Returns None if the required method can't be found. 
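Editor's note: the two callbacks above share one pattern: reflectively probe for a method that only exists on newer Hadoop versions, record a baseline, and hand back a closure that reports the growth since that baseline, returning None when the method is missing. A generic sketch of that pattern follows; the target object and method name are arbitrary stand-ins, not Hadoop's StatisticsData.

import java.lang.reflect.Method

object DeltaCallbackSketch {
  // Look up `methodName` on `target` reflectively; if present, snapshot its current
  // value and return a callback reporting the increase since the snapshot.
  def deltaCallback(target: AnyRef, methodName: String): Option[() => Long] = {
    try {
      val method: Method = target.getClass.getMethod(methodName)
      val current = () => method.invoke(target).asInstanceOf[Long]
      val baseline = current()
      Some(() => current() - baseline)
    } catch {
      case _: NoSuchMethodException => None  // older version without this accessor
    }
  }
}

For instance, deltaCallback(stats, "getBytesRead") would behave like the callback above for any stats object exposing a no-argument getBytesRead that returns a long.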
+ */ + private[spark] def getFSBytesWrittenOnThreadCallback(path: Path, conf: Configuration) + : Option[() => Long] = { + try { + val threadStats = getFileSystemThreadStatistics(path, conf) + val getBytesWrittenMethod = getFileSystemThreadStatisticsMethod("getBytesWritten") + val f = () => threadStats.map(getBytesWrittenMethod.invoke(_).asInstanceOf[Long]).sum + val baselineBytesWritten = f() + Some(() => f() - baselineBytesWritten) + } catch { + case e: NoSuchMethodException => { + logDebug("Couldn't find method for retrieving thread-level FileSystem output data", e) + None + } + } + } + + private def getFileSystemThreadStatistics(path: Path, conf: Configuration): Seq[AnyRef] = { + val qualifiedPath = path.getFileSystem(conf).makeQualified(path) + val scheme = qualifiedPath.toUri().getScheme() + val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme)) + stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) + } + + private def getFileSystemThreadStatisticsMethod(methodName: String): Method = { + val statisticsDataClass = + Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData") + statisticsDataClass.getDeclaredMethod(methodName) + } } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index f97bf67fa5a3b..00f291823e984 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -158,8 +158,9 @@ object SparkSubmit { args.files = mergeFileLists(args.files, args.primaryResource) } args.files = mergeFileLists(args.files, args.pyFiles) - // Format python file paths properly before adding them to the PYTHONPATH - sysProps("spark.submit.pyFiles") = PythonRunner.formatPaths(args.pyFiles).mkString(",") + if (args.pyFiles != null) { + sysProps("spark.submit.pyFiles") = args.pyFiles + } } // Special flag to avoid deprecation warnings at the client @@ -273,15 +274,32 @@ object SparkSubmit { } } - // Properties given with --conf are superceded by other options, but take precedence over - // properties in the defaults file. + // Load any properties specified through --conf and the default properties file for ((k, v) <- args.sparkProperties) { sysProps.getOrElseUpdate(k, v) } - // Read from default spark properties, if any - for ((k, v) <- args.defaultSparkProperties) { - sysProps.getOrElseUpdate(k, v) + // Resolve paths in certain spark properties + val pathConfigs = Seq( + "spark.jars", + "spark.files", + "spark.yarn.jar", + "spark.yarn.dist.files", + "spark.yarn.dist.archives") + pathConfigs.foreach { config => + // Replace old URIs with resolved URIs, if they exist + sysProps.get(config).foreach { oldValue => + sysProps(config) = Utils.resolveURIs(oldValue) + } + } + + // Resolve and format python file paths properly before adding them to the PYTHONPATH. + // The resolving part is redundant in the case of --py-files, but necessary if the user + // explicitly sets `spark.submit.pyFiles` in his/her default properties file. 
+ sysProps.get("spark.submit.pyFiles").foreach { pyFiles => + val resolvedPyFiles = Utils.resolveURIs(pyFiles) + val formattedPyFiles = PythonRunner.formatPaths(resolvedPyFiles).mkString(",") + sysProps("spark.submit.pyFiles") = formattedPyFiles } (childArgs, childClasspath, sysProps, childMainClass) @@ -322,11 +340,16 @@ object SparkSubmit { e.printStackTrace(printStream) if (childMainClass.contains("thriftserver")) { println(s"Failed to load main class $childMainClass.") - println("You need to build Spark with -Phive.") + println("You need to build Spark with -Phive and -Phive-thriftserver.") } System.exit(CLASS_NOT_FOUND_EXIT_STATUS) } + // SPARK-4170 + if (classOf[scala.App].isAssignableFrom(mainClass)) { + printWarning("Subclasses of scala.App may not work correctly. Use a main() method instead.") + } + val mainMethod = mainClass.getMethod("main", new Array[String](0).getClass) if (!Modifier.isStatic(mainMethod.getModifiers)) { throw new IllegalStateException("The main method in the given main class must be static") diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 57b251ff47714..f0e9ee67f6a67 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -17,14 +17,10 @@ package org.apache.spark.deploy -import java.io.{File, FileInputStream, IOException} -import java.util.Properties import java.util.jar.JarFile -import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} -import org.apache.spark.SparkException import org.apache.spark.util.Utils /** @@ -63,9 +59,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St val defaultProperties = new HashMap[String, String]() if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile") Option(propertiesFile).foreach { filename => - val file = new File(filename) - SparkSubmitArguments.getPropertiesFromFile(file).foreach { case (k, v) => - if (k.startsWith("spark")) { + Utils.getPropertiesFromFile(filename).foreach { case (k, v) => + if (k.startsWith("spark.")) { defaultProperties(k) = v if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v") } else { @@ -76,51 +71,54 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St defaultProperties } - // Respect SPARK_*_MEMORY for cluster mode - driverMemory = sys.env.get("SPARK_DRIVER_MEMORY").orNull - executorMemory = sys.env.get("SPARK_EXECUTOR_MEMORY").orNull - + // Set parameters from command line arguments parseOpts(args.toList) - mergeSparkProperties() + // Populate `sparkProperties` map from properties file + mergeDefaultSparkProperties() + // Use `sparkProperties` map along with env vars to fill in any missing parameters + loadEnvironmentArguments() + checkRequiredArguments() /** - * Fill in any undefined values based on the default properties file or options passed in through - * the '--conf' flag. + * Merge values from the default properties file with those specified through --conf. + * When this is called, `sparkProperties` is already filled with configs from the latter. 
*/ - private def mergeSparkProperties(): Unit = { + private def mergeDefaultSparkProperties(): Unit = { // Use common defaults file, if not specified by user - if (propertiesFile == null) { - val sep = File.separator - val sparkHomeConfig = env.get("SPARK_HOME").map(sparkHome => s"${sparkHome}${sep}conf") - val confDir = env.get("SPARK_CONF_DIR").orElse(sparkHomeConfig) - - confDir.foreach { sparkConfDir => - val defaultPath = s"${sparkConfDir}${sep}spark-defaults.conf" - val file = new File(defaultPath) - if (file.exists()) { - propertiesFile = file.getAbsolutePath - } + propertiesFile = Option(propertiesFile).getOrElse(Utils.getDefaultPropertiesFile(env)) + // Honor --conf before the defaults file + defaultSparkProperties.foreach { case (k, v) => + if (!sparkProperties.contains(k)) { + sparkProperties(k) = v } } + } - val properties = HashMap[String, String]() - properties.putAll(defaultSparkProperties) - properties.putAll(sparkProperties) - - // Use properties file as fallback for values which have a direct analog to - // arguments in this script. - master = Option(master).orElse(properties.get("spark.master")).orNull - executorMemory = Option(executorMemory).orElse(properties.get("spark.executor.memory")).orNull - executorCores = Option(executorCores).orElse(properties.get("spark.executor.cores")).orNull + /** + * Load arguments from environment variables, Spark properties etc. + */ + private def loadEnvironmentArguments(): Unit = { + master = Option(master) + .orElse(sparkProperties.get("spark.master")) + .orElse(env.get("MASTER")) + .orNull + driverMemory = Option(driverMemory) + .orElse(sparkProperties.get("spark.driver.memory")) + .orElse(env.get("SPARK_DRIVER_MEMORY")) + .orNull + executorMemory = Option(executorMemory) + .orElse(sparkProperties.get("spark.executor.memory")) + .orElse(env.get("SPARK_EXECUTOR_MEMORY")) + .orNull + executorCores = Option(executorCores) + .orElse(sparkProperties.get("spark.executor.cores")) + .orNull totalExecutorCores = Option(totalExecutorCores) - .orElse(properties.get("spark.cores.max")) + .orElse(sparkProperties.get("spark.cores.max")) .orNull - name = Option(name).orElse(properties.get("spark.app.name")).orNull - jars = Option(jars).orElse(properties.get("spark.jars")).orNull - - // This supports env vars in older versions of Spark - master = Option(master).orElse(env.get("MASTER")).orNull + name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull + jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull // Try to set main class from JAR if no --class argument is given @@ -147,7 +145,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } /** Ensure that required fields exists. Call this only once all defaults are loaded. 
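Editor's note: loadEnvironmentArguments above replaces scattered lookups with one explicit resolution order: a command-line value wins, then a Spark property, then a legacy environment variable. The same chain in isolation; the key names are only examples taken from the hunk above.

object FallbackChainSketch {
  // Returns the first defined value in the order: CLI flag, Spark property, env var.
  def resolve(
      cliValue: Option[String],
      sparkProperties: Map[String, String],
      env: Map[String, String],
      propertyKey: String,
      envKey: String): Option[String] = {
    cliValue
      .orElse(sparkProperties.get(propertyKey))
      .orElse(env.get(envKey))
  }

  def main(args: Array[String]): Unit = {
    val props = Map("spark.driver.memory" -> "2g")
    // No --driver-memory flag given, so the Spark property wins over SPARK_DRIVER_MEMORY.
    println(resolve(None, props, sys.env, "spark.driver.memory", "SPARK_DRIVER_MEMORY"))
  }
}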
*/ - private def checkRequiredArguments() = { + private def checkRequiredArguments(): Unit = { if (args.length == 0) { printUsageAndExit(-1) } @@ -182,7 +180,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } } - override def toString = { + override def toString = { s"""Parsed arguments: | master $master | deployMode $deployMode @@ -190,7 +188,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | executorCores $executorCores | totalExecutorCores $totalExecutorCores | propertiesFile $propertiesFile - | extraSparkProperties $sparkProperties | driverMemory $driverMemory | driverCores $driverCores | driverExtraClassPath $driverExtraClassPath @@ -209,8 +206,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | jars $jars | verbose $verbose | - |Default properties from $propertiesFile: - |${defaultSparkProperties.mkString(" ", "\n ", "\n")} + |Spark properties used, including those specified through + | --conf and those from the properties file $propertiesFile: + |${sparkProperties.mkString(" ", "\n ", "\n")} """.stripMargin } @@ -343,7 +341,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } } - private def printUsageAndExit(exitCode: Int, unknownParam: Any = null) { + private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = { val outStream = SparkSubmit.printStream if (unknownParam != null) { outStream.println("Unknown/unsupported param " + unknownParam) @@ -397,23 +395,3 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St SparkSubmit.exitFn() } } - -object SparkSubmitArguments { - /** Load properties present in the given file. */ - def getPropertiesFromFile(file: File): Seq[(String, String)] = { - require(file.exists(), s"Properties file $file does not exist") - require(file.isFile(), s"Properties file $file is not a normal file") - val inputStream = new FileInputStream(file) - try { - val properties = new Properties() - properties.load(inputStream) - properties.stringPropertyNames().toSeq.map(k => (k, properties(k).trim)) - } catch { - case e: IOException => - val message = s"Failed when loading Spark properties file $file" - throw new SparkException(message, e) - } finally { - inputStream.close() - } - } -} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala index a64170a47bc1c..d2687faad62b1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala @@ -68,7 +68,7 @@ private[spark] object SparkSubmitDriverBootstrapper { assume(bootstrapDriver != null, "SPARK_SUBMIT_BOOTSTRAP_DRIVER must be set") // Parse the properties file for the equivalent spark.driver.* configs - val properties = SparkSubmitArguments.getPropertiesFromFile(new File(propertiesFile)).toMap + val properties = Utils.getPropertiesFromFile(propertiesFile) val confDriverMemory = properties.get("spark.driver.memory") val confLibraryPath = properties.get("spark.driver.extraLibraryPath") val confClasspath = properties.get("spark.driver.extraClassPath") @@ -82,17 +82,8 @@ private[spark] object SparkSubmitDriverBootstrapper { .orElse(confDriverMemory) .getOrElse(defaultDriverMemory) - val newLibraryPath = - if (submitLibraryPath.isDefined) { - // SPARK_SUBMIT_LIBRARY_PATH is already captured in 
JAVA_OPTS - "" - } else { - confLibraryPath.map("-Djava.library.path=" + _).getOrElse("") - } - val newClasspath = if (submitClasspath.isDefined) { - // SPARK_SUBMIT_CLASSPATH is already captured in CLASSPATH classpath } else { classpath + confClasspath.map(sys.props("path.separator") + _).getOrElse("") @@ -114,7 +105,6 @@ private[spark] object SparkSubmitDriverBootstrapper { val command: Seq[String] = Seq(runner) ++ Seq("-cp", newClasspath) ++ - Seq(newLibraryPath) ++ filteredJavaOpts ++ Seq(s"-Xms$newDriverMemory", s"-Xmx$newDriverMemory") ++ Seq("org.apache.spark.deploy.SparkSubmit") ++ @@ -130,8 +120,25 @@ private[spark] object SparkSubmitDriverBootstrapper { // Start the driver JVM val filteredCommand = command.filter(_.nonEmpty) val builder = new ProcessBuilder(filteredCommand) + val env = builder.environment() + + if (submitLibraryPath.isEmpty && confLibraryPath.nonEmpty) { + val libraryPaths = confLibraryPath ++ sys.env.get(Utils.libraryPathEnvName) + env.put(Utils.libraryPathEnvName, libraryPaths.mkString(sys.props("path.separator"))) + } + val process = builder.start() + // If we kill an app while it's running, its sub-process should be killed too. + Runtime.getRuntime().addShutdownHook(new Thread() { + override def run() = { + if (process != null) { + process.destroy() + process.waitFor() + } + } + }) + // Redirect stdout and stderr from the child JVM val stdoutThread = new RedirectThread(process.getInputStream, System.out, "redirect stdout") val stderrThread = new RedirectThread(process.getErrorStream, System.err, "redirect stderr") @@ -142,14 +149,15 @@ private[spark] object SparkSubmitDriverBootstrapper { // subprocess there already reads directly from our stdin, so we should avoid spawning a // thread that contends with the subprocess in reading from System.in. val isWindows = Utils.isWindows - val isPySparkShell = sys.env.contains("PYSPARK_SHELL") + val isSubprocess = sys.env.contains("IS_SUBPROCESS") if (!isWindows) { val stdinThread = new RedirectThread(System.in, process.getOutputStream, "redirect stdin") stdinThread.start() - // For the PySpark shell, Spark submit itself runs as a python subprocess, and so this JVM - // should terminate on broken pipe, which signals that the parent process has exited. In - // Windows, the termination logic for the PySpark shell is handled in java_gateway.py - if (isPySparkShell) { + // Spark submit (JVM) may run as a subprocess, and so this JVM should terminate on + // broken pipe, signaling that the parent process has exited. This is the case if the + // application is launched directly from python, as in the PySpark shell. 
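Editor's note: one behavioural addition in this bootstrapper hunk is the JVM shutdown hook that destroys the launched driver process, so killing the wrapper also kills its child. The pattern in isolation is sketched below; the sleep command is only a placeholder child process and assumes a Unix-like environment.

object ChildCleanupSketch {
  def main(args: Array[String]): Unit = {
    // Placeholder long-running child; the real code launches the driver JVM.
    val process = new ProcessBuilder("sleep", "60").start()

    // If this JVM is terminated, take the child down with it.
    Runtime.getRuntime.addShutdownHook(new Thread() {
      override def run(): Unit = {
        process.destroy()
        process.waitFor()
      }
    })

    process.waitFor()
  }
}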
In Windows, + // the termination logic is handled in java_gateway.py + if (isSubprocess) { stdinThread.join() process.destroy() } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 32790053a6be8..98a93d1fcb2a3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -154,7 +154,7 @@ private[spark] class AppClient( logWarning(s"Connection to $address failed; waiting for master to reconnect...") markDisconnected() - case AssociationErrorEvent(cause, _, address, _) if isPossibleMaster(address) => + case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) => logWarning(s"Could not connect to $address: $cause") case StopAppClient => diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 481f6c93c6a8d..82a54dbfb5330 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -29,22 +29,27 @@ import org.apache.spark.scheduler._ import org.apache.spark.ui.SparkUI import org.apache.spark.util.Utils +/** + * A class that provides application history from event logs stored in the file system. + * This provider checks for new finished applications in the background periodically and + * renders the history application UI by parsing the associated event logs. + */ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHistoryProvider with Logging { + import FsHistoryProvider._ + private val NOT_STARTED = "" // Interval between each check for event log updates private val UPDATE_INTERVAL_MS = conf.getInt("spark.history.fs.updateInterval", conf.getInt("spark.history.updateInterval", 10)) * 1000 - private val logDir = conf.get("spark.history.fs.logDirectory", null) - private val resolvedLogDir = Option(logDir) - .map { d => Utils.resolveURI(d) } - .getOrElse { throw new IllegalArgumentException("Logging directory must be specified.") } + private val logDir = conf.getOption("spark.history.fs.logDirectory") + .map { d => Utils.resolveURI(d).toString } + .getOrElse(DEFAULT_LOG_DIR) - private val fs = Utils.getHadoopFileSystem(resolvedLogDir, - SparkHadoopUtil.get.newConfiguration(conf)) + private val fs = Utils.getHadoopFileSystem(logDir, SparkHadoopUtil.get.newConfiguration(conf)) // A timestamp of when the disk was last accessed to check for log updates private var lastLogCheckTimeMs = -1L @@ -87,14 +92,17 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis private def initialize() { // Validate the log directory. - val path = new Path(resolvedLogDir) + val path = new Path(logDir) if (!fs.exists(path)) { - throw new IllegalArgumentException( - "Logging directory specified does not exist: %s".format(resolvedLogDir)) + var msg = s"Log directory specified does not exist: $logDir." + if (logDir == DEFAULT_LOG_DIR) { + msg += " Did you configure the correct one through spark.fs.history.logDirectory?" 
+ } + throw new IllegalArgumentException(msg) } if (!fs.getFileStatus(path).isDir) { throw new IllegalArgumentException( - "Logging directory specified is not a directory: %s".format(resolvedLogDir)) + "Logging directory specified is not a directory: %s".format(logDir)) } checkForLogs() @@ -112,7 +120,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis val ui = { val conf = this.conf.clone() val appSecManager = new SecurityManager(conf) - new SparkUI(conf, appSecManager, replayBus, appId, + SparkUI.createHistoryUI(conf, replayBus, appSecManager, appId, s"${HistoryServer.UI_PATH_PREFIX}/$appId") // Do not call ui.bind() to avoid creating a new server for each application } @@ -134,8 +142,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis } } - override def getConfig(): Map[String, String] = - Map("Event Log Location" -> resolvedLogDir.toString) + override def getConfig(): Map[String, String] = Map("Event log directory" -> logDir.toString) /** * Builds the application list based on the current contents of the log directory. @@ -146,7 +153,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis lastLogCheckTimeMs = getMonotonicTimeMs() logDebug("Checking for logs. Time is now %d.".format(lastLogCheckTimeMs)) try { - val logStatus = fs.listStatus(new Path(resolvedLogDir)) + val logStatus = fs.listStatus(new Path(logDir)) val logDirs = if (logStatus != null) logStatus.filter(_.isDir).toSeq else Seq[FileStatus]() // Load all new logs from the log directory. Only directories that have a modification time @@ -244,6 +251,10 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis } +private object FsHistoryProvider { + val DEFAULT_LOG_DIR = "file:/tmp/spark-events" +} + private class FsApplicationHistoryInfo( val logDir: String, id: String, diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index d25c29113d6da..5fdc350cd8512 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -58,7 +58,13 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { ++ appTable } else { -

<h4>No Completed Applications Found</h4>
+      <h4>No completed applications found!</h4> ++
+        <p>Did you specify the correct logging directory?
+          Please verify your setting of <span style="font-style:italic">
+          spark.history.fs.logDirectory</span> and whether you have the permissions to
+          access it.<br/> It is also possible that your application did not run to
+          completion or did not stop the SparkContext.
+        </p>
} } @@ -84,11 +90,11 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { {info.id} {info.name} - {startTime} - {endTime} - {duration} + {startTime} + {endTime} + {duration} {info.sparkUser} - {lastUpdated} + {lastUpdated} } } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala index 25fc76c23e0fb..b1270ade9f750 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -17,26 +17,33 @@ package org.apache.spark.deploy.history -import org.apache.spark.SparkConf +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.util.Utils /** * Command-line parser for the master. */ -private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]) { - private var logDir: String = null +private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]) extends Logging { + private var propertiesFile: String = null parse(args.toList) private def parse(args: List[String]): Unit = { args match { case ("--dir" | "-d") :: value :: tail => - logDir = value + logWarning("Setting log directory through the command line is deprecated as of " + + "Spark 1.1.0. Please set this through spark.history.fs.logDirectory instead.") conf.set("spark.history.fs.logDirectory", value) + System.setProperty("spark.history.fs.logDirectory", value) parse(tail) case ("--help" | "-h") :: tail => printUsageAndExit(0) + case ("--properties-file") :: value :: tail => + propertiesFile = value + parse(tail) + case Nil => case _ => @@ -44,10 +51,17 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String] } } + // This mutates the SparkConf, so all accesses to it must be made after this line + Utils.loadDefaultSparkProperties(conf, propertiesFile) + private def printUsageAndExit(exitCode: Int) { System.err.println( """ - |Usage: HistoryServer + |Usage: HistoryServer [options] + | + |Options: + | --properties-file FILE Path to a custom Spark properties file. + | Default is conf/spark-defaults.conf. | |Configuration options can be set by setting the corresponding JVM system property. |History Server options are always available; additional options depend on the provider. 
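Editor's note: both this class and the SparkSubmit changes earlier in the patch lean on a shared Utils helper for reading a spark-defaults style file. The patch does not show that helper's body, so the following is only a sketch of what such a loader has to do: read a java.util.Properties file and keep the spark.* keys.

import java.io.{File, FileInputStream, InputStreamReader}
import java.nio.charset.StandardCharsets
import java.util.Properties

import scala.collection.JavaConverters._

object PropertiesFileSketch {
  def loadSparkProperties(path: String): Map[String, String] = {
    val file = new File(path)
    require(file.isFile, s"Properties file $path is not a normal file")
    val reader = new InputStreamReader(new FileInputStream(file), StandardCharsets.UTF_8)
    try {
      val props = new Properties()
      props.load(reader)
      props.stringPropertyNames().asScala
        .map(k => k -> props.getProperty(k).trim)
        .filter { case (k, _) => k.startsWith("spark.") }  // ignore anything outside spark.*
        .toMap
    } finally {
      reader.close()
    }
  }
}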
@@ -64,9 +78,10 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String] | (default 50) |FsHistoryProvider options: | - | spark.history.fs.logDirectory Directory where app logs are stored (required) - | spark.history.fs.updateInterval How often to reload log data from storage (in seconds, - | default 10) + | spark.history.fs.logDirectory Directory where app logs are stored + | (default: file:/tmp/spark-events) + | spark.history.fs.updateInterval How often to reload log data from storage + | (in seconds, default: 10) |""".stripMargin) System.exit(exitCode) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index c3ca43f8d0734..ad7d81747c377 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -24,7 +24,9 @@ import scala.collection.mutable.ArrayBuffer import akka.actor.ActorRef +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.ApplicationDescription +import org.apache.spark.util.Utils private[spark] class ApplicationInfo( val startTime: Long, @@ -46,7 +48,7 @@ private[spark] class ApplicationInfo( init() - private def readObject(in: java.io.ObjectInputStream): Unit = { + private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException { in.defaultReadObject() init() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala index 80b570a44af18..9d3d7938c6ccb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala @@ -19,7 +19,9 @@ package org.apache.spark.deploy.master import java.util.Date +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.DriverDescription +import org.apache.spark.util.Utils private[spark] class DriverInfo( val startTime: Long, @@ -36,7 +38,7 @@ private[spark] class DriverInfo( init() - private def readObject(in: java.io.ObjectInputStream): Unit = { + private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException { in.defaultReadObject() init() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index aa85aa060d9c1..36a2e2c6a6349 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -19,10 +19,13 @@ package org.apache.spark.deploy.master import java.io._ +import scala.reflect.ClassTag + import akka.serialization.Serialization import org.apache.spark.Logging + /** * Stores data in a single on-disk directory with one file per application and worker. * Files are deleted when applications and workers are removed. 
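Editor's note: the next hunk rewrites this engine around three name-based primitives: persist one file per name, delete it on unpersist, and read everything matching a prefix during recovery. A self-contained sketch of that contract is below, with Java serialization standing in for the Akka Serialization extension the real class uses.

import java.io._

import scala.reflect.ClassTag

class FilePerNameStoreSketch(dir: String) {
  new File(dir).mkdirs()

  // One file per name, named exactly after the key.
  def persist(name: String, obj: Serializable): Unit = {
    val out = new ObjectOutputStream(new FileOutputStream(new File(dir, name)))
    try out.writeObject(obj) finally out.close()
  }

  def unpersist(name: String): Unit = {
    new File(dir, name).delete()
  }

  // Recovery path: load every stored object whose name starts with the prefix.
  def read[T: ClassTag](prefix: String): Seq[T] = {
    val files = Option(new File(dir).listFiles()).getOrElse(Array.empty[File])
    files.toSeq.filter(_.getName.startsWith(prefix)).map { f =>
      val in = new ObjectInputStream(new FileInputStream(f))
      try in.readObject().asInstanceOf[T] finally in.close()
    }
  }
}

With such a store, addApplication reduces to persist("app_" + app.id, app) and recovery to read[ApplicationInfo]("app_"), which is how the shared PersistenceEngine trait later in this patch wires its helpers.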
@@ -37,64 +40,43 @@ private[spark] class FileSystemPersistenceEngine( new File(dir).mkdir() - override def addApplication(app: ApplicationInfo) { - val appFile = new File(dir + File.separator + "app_" + app.id) - serializeIntoFile(appFile, app) - } - - override def removeApplication(app: ApplicationInfo) { - new File(dir + File.separator + "app_" + app.id).delete() - } - - override def addDriver(driver: DriverInfo) { - val driverFile = new File(dir + File.separator + "driver_" + driver.id) - serializeIntoFile(driverFile, driver) - } - - override def removeDriver(driver: DriverInfo) { - new File(dir + File.separator + "driver_" + driver.id).delete() - } - - override def addWorker(worker: WorkerInfo) { - val workerFile = new File(dir + File.separator + "worker_" + worker.id) - serializeIntoFile(workerFile, worker) + override def persist(name: String, obj: Object): Unit = { + serializeIntoFile(new File(dir + File.separator + name), obj) } - override def removeWorker(worker: WorkerInfo) { - new File(dir + File.separator + "worker_" + worker.id).delete() + override def unpersist(name: String): Unit = { + new File(dir + File.separator + name).delete() } - override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { - val sortedFiles = new File(dir).listFiles().sortBy(_.getName) - val appFiles = sortedFiles.filter(_.getName.startsWith("app_")) - val apps = appFiles.map(deserializeFromFile[ApplicationInfo]) - val driverFiles = sortedFiles.filter(_.getName.startsWith("driver_")) - val drivers = driverFiles.map(deserializeFromFile[DriverInfo]) - val workerFiles = sortedFiles.filter(_.getName.startsWith("worker_")) - val workers = workerFiles.map(deserializeFromFile[WorkerInfo]) - (apps, drivers, workers) + override def read[T: ClassTag](prefix: String) = { + val files = new File(dir).listFiles().filter(_.getName.startsWith(prefix)) + files.map(deserializeFromFile[T]) } private def serializeIntoFile(file: File, value: AnyRef) { val created = file.createNewFile() if (!created) { throw new IllegalStateException("Could not create file: " + file) } - val serializer = serialization.findSerializerFor(value) val serialized = serializer.toBinary(value) - val out = new FileOutputStream(file) - out.write(serialized) - out.close() + try { + out.write(serialized) + } finally { + out.close() + } } - def deserializeFromFile[T](file: File)(implicit m: Manifest[T]): T = { + private def deserializeFromFile[T](file: File)(implicit m: ClassTag[T]): T = { val fileData = new Array[Byte](file.length().asInstanceOf[Int]) val dis = new DataInputStream(new FileInputStream(file)) - dis.readFully(fileData) - dis.close() - + try { + dis.readFully(fileData) + } finally { + dis.close() + } val clazz = m.runtimeClass.asInstanceOf[Class[T]] val serializer = serialization.serializerFor(clazz) serializer.fromBinary(fileData).asInstanceOf[T] } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala index 4433a2ec29be6..cf77c86d760cf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala @@ -17,30 +17,27 @@ package org.apache.spark.deploy.master -import akka.actor.{Actor, ActorRef} - -import org.apache.spark.deploy.master.MasterMessages.ElectedLeader +import org.apache.spark.annotation.DeveloperApi /** - * A LeaderElectionAgent keeps track of whether the current Master is 
the leader, meaning it - * is the only Master serving requests. - * In addition to the API provided, the LeaderElectionAgent will use of the following messages - * to inform the Master of leader changes: - * [[org.apache.spark.deploy.master.MasterMessages.ElectedLeader ElectedLeader]] - * [[org.apache.spark.deploy.master.MasterMessages.RevokedLeadership RevokedLeadership]] + * :: DeveloperApi :: + * + * A LeaderElectionAgent tracks current master and is a common interface for all election Agents. */ -private[spark] trait LeaderElectionAgent extends Actor { - // TODO: LeaderElectionAgent does not necessary to be an Actor anymore, need refactoring. - val masterActor: ActorRef +@DeveloperApi +trait LeaderElectionAgent { + val masterActor: LeaderElectable + def stop() {} // to avoid noops in implementations. } -/** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */ -private[spark] class MonarchyLeaderAgent(val masterActor: ActorRef) extends LeaderElectionAgent { - override def preStart() { - masterActor ! ElectedLeader - } +@DeveloperApi +trait LeaderElectable { + def electedLeader() + def revokedLeadership() +} - override def receive = { - case _ => - } +/** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */ +private[spark] class MonarchyLeaderAgent(val masterActor: LeaderElectable) + extends LeaderElectionAgent { + masterActor.electedLeader() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index f98b531316a3d..7b32c505def9b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -30,6 +30,7 @@ import scala.util.Random import akka.actor._ import akka.pattern.ask import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import akka.serialization.Serialization import akka.serialization.SerializationExtension import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} @@ -50,7 +51,7 @@ private[spark] class Master( port: Int, webUiPort: Int, val securityMgr: SecurityManager) - extends Actor with ActorLogReceive with Logging { + extends Actor with ActorLogReceive with Logging with LeaderElectable { import context.dispatcher // to use Akka's scheduler.schedule() @@ -61,7 +62,6 @@ private[spark] class Master( val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200) val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15) - val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") val RECOVERY_MODE = conf.get("spark.deploy.recoveryMode", "NONE") val workers = new HashSet[WorkerInfo] @@ -103,7 +103,7 @@ private[spark] class Master( var persistenceEngine: PersistenceEngine = _ - var leaderElectionAgent: ActorRef = _ + var leaderElectionAgent: LeaderElectionAgent = _ private var recoveryCompletionTask: Cancellable = _ @@ -130,23 +130,27 @@ private[spark] class Master( masterMetricsSystem.start() applicationMetricsSystem.start() - persistenceEngine = RECOVERY_MODE match { + val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match { case "ZOOKEEPER" => logInfo("Persisting recovery state to ZooKeeper") - new ZooKeeperPersistenceEngine(SerializationExtension(context.system), conf) + val zkFactory = + new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(context.system)) 
+ (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this)) case "FILESYSTEM" => - logInfo("Persisting recovery state to directory: " + RECOVERY_DIR) - new FileSystemPersistenceEngine(RECOVERY_DIR, SerializationExtension(context.system)) + val fsFactory = + new FileSystemRecoveryModeFactory(conf, SerializationExtension(context.system)) + (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this)) + case "CUSTOM" => + val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory")) + val factory = clazz.getConstructor(conf.getClass, Serialization.getClass) + .newInstance(conf, SerializationExtension(context.system)) + .asInstanceOf[StandaloneRecoveryModeFactory] + (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this)) case _ => - new BlackHolePersistenceEngine() + (new BlackHolePersistenceEngine(), new MonarchyLeaderAgent(this)) } - - leaderElectionAgent = RECOVERY_MODE match { - case "ZOOKEEPER" => - context.actorOf(Props(classOf[ZooKeeperLeaderElectionAgent], self, masterUrl, conf)) - case _ => - context.actorOf(Props(classOf[MonarchyLeaderAgent], self)) - } + persistenceEngine = persistenceEngine_ + leaderElectionAgent = leaderElectionAgent_ } override def preRestart(reason: Throwable, message: Option[Any]) { @@ -165,7 +169,15 @@ private[spark] class Master( masterMetricsSystem.stop() applicationMetricsSystem.stop() persistenceEngine.close() - context.stop(leaderElectionAgent) + leaderElectionAgent.stop() + } + + override def electedLeader() { + self ! ElectedLeader + } + + override def revokedLeadership() { + self ! RevokedLeadership } override def receiveWithLogging = { @@ -341,7 +353,14 @@ private[spark] class Master( case Some(workerInfo) => workerInfo.lastHeartbeat = System.currentTimeMillis() case None => - logWarning("Got heartbeat from unregistered worker " + workerId) + if (workers.map(_.id).contains(workerId)) { + logWarning(s"Got heartbeat from unregistered worker $workerId." + + " Asking it to re-register.") + sender ! ReconnectWorker(masterUrl) + } else { + logWarning(s"Got heartbeat from unregistered worker $workerId." 
+ + " This worker was never registered, so ignoring the heartbeat.") + } } } @@ -714,8 +733,8 @@ private[spark] class Master( try { val replayBus = new ReplayListenerBus(eventLogPaths, fileSystem, compressionCodec) - val ui = new SparkUI(new SparkConf, replayBus, appName + " (completed)", - HistoryServer.UI_PATH_PREFIX + s"/${app.id}") + val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf), + appName + " (completed)", HistoryServer.UI_PATH_PREFIX + s"/${app.id}") replayBus.replay() appIdToUI(app.id) = ui webUi.attachSparkUI(ui) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index 4b0dbbe543d3f..e34bee7854292 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -27,6 +27,7 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { var host = Utils.localHostName() var port = 7077 var webUiPort = 8080 + var propertiesFile: String = null // Check for settings in environment variables if (System.getenv("SPARK_MASTER_HOST") != null) { @@ -38,12 +39,16 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { if (System.getenv("SPARK_MASTER_WEBUI_PORT") != null) { webUiPort = System.getenv("SPARK_MASTER_WEBUI_PORT").toInt } + + parse(args.toList) + + // This mutates the SparkConf, so all accesses to it must be made after this line + propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile) + if (conf.contains("spark.master.ui.port")) { webUiPort = conf.get("spark.master.ui.port").toInt } - parse(args.toList) - def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => Utils.checkHost(value, "ip no longer supported, please use hostname " + value) @@ -63,7 +68,11 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { webUiPort = value parse(tail) - case ("--help" | "-h") :: tail => + case ("--properties-file") :: value :: tail => + propertiesFile = value + parse(tail) + + case ("--help") :: tail => printUsageAndExit(0) case Nil => {} @@ -83,7 +92,9 @@ private[spark] class MasterArguments(args: Array[String], conf: SparkConf) { " -i HOST, --ip HOST Hostname to listen on (deprecated, please use --host or -h) \n" + " -h HOST, --host HOST Hostname to listen on\n" + " -p PORT, --port PORT Port to listen on (default: 7077)\n" + - " --webui-port PORT Port for web UI (default: 8080)") + " --webui-port PORT Port for web UI (default: 8080)\n" + + " --properties-file FILE Path to a custom Spark properties file.\n" + + " Default is conf/spark-defaults.conf.") System.exit(exitCode) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala index e3640ea4f7e64..2e0e1e7036ac8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala @@ -17,6 +17,10 @@ package org.apache.spark.deploy.master +import org.apache.spark.annotation.DeveloperApi + +import scala.reflect.ClassTag + /** * Allows Master to persist any state that is necessary in order to recover from a failure. 
* The following semantics are required: @@ -25,36 +29,70 @@ package org.apache.spark.deploy.master * Given these two requirements, we will have all apps and workers persisted, but * we might not have yet deleted apps or workers that finished (so their liveness must be verified * during recovery). + * + * The implementation of this trait defines how name-object pairs are stored or retrieved. */ -private[spark] trait PersistenceEngine { - def addApplication(app: ApplicationInfo) +@DeveloperApi +trait PersistenceEngine { - def removeApplication(app: ApplicationInfo) + /** + * Defines how the object is serialized and persisted. Implementation will + * depend on the store used. + */ + def persist(name: String, obj: Object) - def addWorker(worker: WorkerInfo) + /** + * Defines how the object referred by its name is removed from the store. + */ + def unpersist(name: String) - def removeWorker(worker: WorkerInfo) + /** + * Gives all objects, matching a prefix. This defines how objects are + * read/deserialized back. + */ + def read[T: ClassTag](prefix: String): Seq[T] - def addDriver(driver: DriverInfo) + final def addApplication(app: ApplicationInfo): Unit = { + persist("app_" + app.id, app) + } - def removeDriver(driver: DriverInfo) + final def removeApplication(app: ApplicationInfo): Unit = { + unpersist("app_" + app.id) + } + + final def addWorker(worker: WorkerInfo): Unit = { + persist("worker_" + worker.id, worker) + } + + final def removeWorker(worker: WorkerInfo): Unit = { + unpersist("worker_" + worker.id) + } + + final def addDriver(driver: DriverInfo): Unit = { + persist("driver_" + driver.id, driver) + } + + final def removeDriver(driver: DriverInfo): Unit = { + unpersist("driver_" + driver.id) + } /** * Returns the persisted data sorted by their respective ids (which implies that they're * sorted by time of creation). */ - def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) + final def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { + (read[ApplicationInfo]("app_"), read[DriverInfo]("driver_"), read[WorkerInfo]("worker_")) + } def close() {} } private[spark] class BlackHolePersistenceEngine extends PersistenceEngine { - override def addApplication(app: ApplicationInfo) {} - override def removeApplication(app: ApplicationInfo) {} - override def addWorker(worker: WorkerInfo) {} - override def removeWorker(worker: WorkerInfo) {} - override def addDriver(driver: DriverInfo) {} - override def removeDriver(driver: DriverInfo) {} - - override def readPersistedData() = (Nil, Nil, Nil) + + override def persist(name: String, obj: Object): Unit = {} + + override def unpersist(name: String): Unit = {} + + override def read[T: ClassTag](name: String): Seq[T] = Nil + } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala new file mode 100644 index 0000000000000..1096eb0368357 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryModeFactory.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master + +import akka.serialization.Serialization + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.annotation.DeveloperApi + +/** + * ::DeveloperApi:: + * + * Implementation of this class can be plugged in as recovery mode alternative for Spark's + * Standalone mode. + * + */ +@DeveloperApi +abstract class StandaloneRecoveryModeFactory(conf: SparkConf, serializer: Serialization) { + + /** + * PersistenceEngine defines how the persistent data(Information about worker, driver etc..) + * is handled for recovery. + * + */ + def createPersistenceEngine(): PersistenceEngine + + /** + * Create an instance of LeaderAgent that decides who gets elected as master. + */ + def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent +} + +/** + * LeaderAgent in this case is a no-op. Since leader is forever leader as the actual + * recovery is made by restoring from filesystem. + */ +private[spark] class FileSystemRecoveryModeFactory(conf: SparkConf, serializer: Serialization) + extends StandaloneRecoveryModeFactory(conf, serializer) with Logging { + val RECOVERY_DIR = conf.get("spark.deploy.recoveryDirectory", "") + + def createPersistenceEngine() = { + logInfo("Persisting recovery state to directory: " + RECOVERY_DIR) + new FileSystemPersistenceEngine(RECOVERY_DIR, serializer) + } + + def createLeaderElectionAgent(master: LeaderElectable) = new MonarchyLeaderAgent(master) +} + +private[spark] class ZooKeeperRecoveryModeFactory(conf: SparkConf, serializer: Serialization) + extends StandaloneRecoveryModeFactory(conf, serializer) { + def createPersistenceEngine() = new ZooKeeperPersistenceEngine(conf, serializer) + + def createLeaderElectionAgent(master: LeaderElectable) = + new ZooKeeperLeaderElectionAgent(master, conf) +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index c5fa9cf7d7c2d..473ddc23ff0f3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import akka.actor.ActorRef +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.Utils private[spark] class WorkerInfo( @@ -50,7 +51,7 @@ private[spark] class WorkerInfo( def coresFree: Int = cores - coresUsed def memoryFree: Int = memory - memoryUsed - private def readObject(in: java.io.ObjectInputStream) : Unit = { + private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException { in.defaultReadObject() init() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index 285f9b014e291..8eaa0ad948519 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -24,9 +24,8 @@ import 
org.apache.spark.deploy.master.MasterMessages._ import org.apache.curator.framework.CuratorFramework import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch} -private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef, - masterUrl: String, conf: SparkConf) - extends LeaderElectionAgent with LeaderLatchListener with Logging { +private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: LeaderElectable, + conf: SparkConf) extends LeaderLatchListener with LeaderElectionAgent with Logging { val WORKING_DIR = conf.get("spark.deploy.zookeeper.dir", "/spark") + "/leader_election" @@ -34,30 +33,21 @@ private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef, private var leaderLatch: LeaderLatch = _ private var status = LeadershipStatus.NOT_LEADER - override def preStart() { + start() + def start() { logInfo("Starting ZooKeeper LeaderElection agent") zk = SparkCuratorUtil.newClient(conf) leaderLatch = new LeaderLatch(zk, WORKING_DIR) leaderLatch.addListener(this) - leaderLatch.start() } - override def preRestart(reason: scala.Throwable, message: scala.Option[scala.Any]) { - logError("LeaderElectionAgent failed...", reason) - super.preRestart(reason, message) - } - - override def postStop() { + override def stop() { leaderLatch.close() zk.close() } - override def receive = { - case _ => - } - override def isLeader() { synchronized { // could have lost leadership by now. @@ -85,10 +75,10 @@ private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef, def updateLeadershipStatus(isLeader: Boolean) { if (isLeader && status == LeadershipStatus.NOT_LEADER) { status = LeadershipStatus.LEADER - masterActor ! ElectedLeader + masterActor.electedLeader() } else if (!isLeader && status == LeadershipStatus.LEADER) { status = LeadershipStatus.NOT_LEADER - masterActor ! 
RevokedLeadership + masterActor.revokedLeadership() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 834dfedee52ce..e11ac031fb9c6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -17,15 +17,18 @@ package org.apache.spark.deploy.master +import akka.serialization.Serialization + import scala.collection.JavaConversions._ +import scala.reflect.ClassTag -import akka.serialization.Serialization import org.apache.curator.framework.CuratorFramework import org.apache.zookeeper.CreateMode import org.apache.spark.{Logging, SparkConf} -class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf) + +private[spark] class ZooKeeperPersistenceEngine(conf: SparkConf, val serialization: Serialization) extends PersistenceEngine with Logging { @@ -34,52 +37,31 @@ class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf) SparkCuratorUtil.mkdir(zk, WORKING_DIR) - override def addApplication(app: ApplicationInfo) { - serializeIntoFile(WORKING_DIR + "/app_" + app.id, app) - } - override def removeApplication(app: ApplicationInfo) { - zk.delete().forPath(WORKING_DIR + "/app_" + app.id) + override def persist(name: String, obj: Object): Unit = { + serializeIntoFile(WORKING_DIR + "/" + name, obj) } - override def addDriver(driver: DriverInfo) { - serializeIntoFile(WORKING_DIR + "/driver_" + driver.id, driver) + override def unpersist(name: String): Unit = { + zk.delete().forPath(WORKING_DIR + "/" + name) } - override def removeDriver(driver: DriverInfo) { - zk.delete().forPath(WORKING_DIR + "/driver_" + driver.id) - } - - override def addWorker(worker: WorkerInfo) { - serializeIntoFile(WORKING_DIR + "/worker_" + worker.id, worker) - } - - override def removeWorker(worker: WorkerInfo) { - zk.delete().forPath(WORKING_DIR + "/worker_" + worker.id) + override def read[T: ClassTag](prefix: String) = { + val file = zk.getChildren.forPath(WORKING_DIR).filter(_.startsWith(prefix)) + file.map(deserializeFromFile[T]).flatten } override def close() { zk.close() } - override def readPersistedData(): (Seq[ApplicationInfo], Seq[DriverInfo], Seq[WorkerInfo]) = { - val sortedFiles = zk.getChildren().forPath(WORKING_DIR).toList.sorted - val appFiles = sortedFiles.filter(_.startsWith("app_")) - val apps = appFiles.map(deserializeFromFile[ApplicationInfo]).flatten - val driverFiles = sortedFiles.filter(_.startsWith("driver_")) - val drivers = driverFiles.map(deserializeFromFile[DriverInfo]).flatten - val workerFiles = sortedFiles.filter(_.startsWith("worker_")) - val workers = workerFiles.map(deserializeFromFile[WorkerInfo]).flatten - (apps, drivers, workers) - } - private def serializeIntoFile(path: String, value: AnyRef) { val serializer = serialization.findSerializerFor(value) val serialized = serializer.toBinary(value) zk.create().withMode(CreateMode.PERSISTENT).forPath(path, serialized) } - def deserializeFromFile[T](filename: String)(implicit m: Manifest[T]): Option[T] = { + def deserializeFromFile[T](filename: String)(implicit m: ClassTag[T]): Option[T] = { val fileData = zk.getData().forPath(WORKING_DIR + "/" + filename) val clazz = m.runtimeClass.asInstanceOf[Class[T]] val serializer = serialization.serializerFor(clazz) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala 
b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index 2e9be2a180c68..28e9662db5da9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -20,6 +20,8 @@ package org.apache.spark.deploy.worker import java.io.{File, FileOutputStream, InputStream, IOException} import java.lang.System._ +import scala.collection.Map + import org.apache.spark.Logging import org.apache.spark.deploy.Command import org.apache.spark.util.Utils @@ -29,7 +31,29 @@ import org.apache.spark.util.Utils */ private[spark] object CommandUtils extends Logging { - def buildCommandSeq(command: Command, memory: Int, sparkHome: String): Seq[String] = { + + /** + * Build a ProcessBuilder based on the given parameters. + * The `env` argument is exposed for testing. + */ + def buildProcessBuilder( + command: Command, + memory: Int, + sparkHome: String, + substituteArguments: String => String, + classPaths: Seq[String] = Seq[String](), + env: Map[String, String] = sys.env): ProcessBuilder = { + val localCommand = buildLocalCommand(command, substituteArguments, classPaths, env) + val commandSeq = buildCommandSeq(localCommand, memory, sparkHome) + val builder = new ProcessBuilder(commandSeq: _*) + val environment = builder.environment() + for ((key, value) <- localCommand.environment) { + environment.put(key, value) + } + builder + } + + private def buildCommandSeq(command: Command, memory: Int, sparkHome: String): Seq[String] = { val runner = sys.env.get("JAVA_HOME").map(_ + "/bin/java").getOrElse("java") // SPARK-698: do not call the run.cmd script, as process.destroy() @@ -38,11 +62,41 @@ object CommandUtils extends Logging { command.arguments } + /** + * Build a command based on the given one, taking into account the local environment + * of where this command is expected to run, substitute any placeholders, and append + * any extra class paths. + */ + private def buildLocalCommand( + command: Command, + substituteArguments: String => String, + classPath: Seq[String] = Seq[String](), + env: Map[String, String]): Command = { + val libraryPathName = Utils.libraryPathEnvName + val libraryPathEntries = command.libraryPathEntries + val cmdLibraryPath = command.environment.get(libraryPathName) + + val newEnvironment = if (libraryPathEntries.nonEmpty && libraryPathName.nonEmpty) { + val libraryPaths = libraryPathEntries ++ cmdLibraryPath ++ env.get(libraryPathName) + command.environment + ((libraryPathName, libraryPaths.mkString(File.pathSeparator))) + } else { + command.environment + } + + Command( + command.mainClass, + command.arguments.map(substituteArguments), + newEnvironment, + command.classPathEntries ++ classPath, + Seq[String](), // library path already captured in environment variable + command.javaOpts) + } + /** * Attention: this must always be aligned with the environment variables in the run scripts and * the way the JAVA_OPTS are assembled there. 
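The buildLocalCommand hunk above stops emitting -Djava.library.path and instead folds the command's libraryPathEntries into the platform's library-path environment variable (Utils.libraryPathEnvName), merged with whatever the submitter's command environment and the worker's own environment already set. A minimal sketch of that merge, assuming a Linux-style LD_LIBRARY_PATH name and placeholder paths; it is not the worker code itself:

    import java.io.File

    object LibraryPathMergeSketch {
      def main(args: Array[String]): Unit = {
        // Name of the library path environment variable; LD_LIBRARY_PATH is the Linux case.
        val libraryPathName = "LD_LIBRARY_PATH"

        // Entries requested by the Command, anything the submitter already put in the
        // command's environment, and anything set in the worker's local environment.
        val commandEntries = Seq("/opt/native/lib")   // placeholder
        val fromCommandEnv = Some("/opt/app/lib")     // placeholder
        val fromLocalEnv   = sys.env.get(libraryPathName)

        // Same merge as buildLocalCommand: join with the platform separator and export it
        // as an environment variable instead of a -Djava.library.path JVM option.
        val merged = (commandEntries ++ fromCommandEnv ++ fromLocalEnv).mkString(File.pathSeparator)
        println(s"$libraryPathName=$merged")
      }
    }

Keeping the library path in the child-process environment, rather than baking it into the JVM options, lets the same Command be adapted to the environment of whichever machine actually launches the process.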
*/ - def buildJavaOpts(command: Command, memory: Int, sparkHome: String): Seq[String] = { + private def buildJavaOpts(command: Command, memory: Int, sparkHome: String): Seq[String] = { val memoryOpts = Seq(s"-Xms${memory}M", s"-Xmx${memory}M") // Exists for backwards compatibility with older Spark versions @@ -53,14 +107,6 @@ object CommandUtils extends Logging { logWarning("Set SPARK_LOCAL_DIRS for node-specific storage locations.") } - val libraryOpts = - if (command.libraryPathEntries.size > 0) { - val joined = command.libraryPathEntries.mkString(File.pathSeparator) - Seq(s"-Djava.library.path=$joined") - } else { - Seq() - } - // Figure out our classpath with the external compute-classpath script val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh" val classPath = Utils.executeAndGetOutput( @@ -71,7 +117,7 @@ object CommandUtils extends Logging { val javaVersion = System.getProperty("java.version") val permGenOpt = if (!javaVersion.startsWith("1.8")) Some("-XX:MaxPermSize=128m") else None Seq("-cp", userClassPath.filterNot(_.isEmpty).mkString(File.pathSeparator)) ++ - permGenOpt ++ libraryOpts ++ workerLocalOpts ++ command.javaOpts ++ memoryOpts + permGenOpt ++ workerLocalOpts ++ command.javaOpts ++ memoryOpts } /** Spawn a thread that will redirect a given stream to a file */ diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 9f9911762505a..28cab36c7b9e2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -23,7 +23,7 @@ import scala.collection.JavaConversions._ import scala.collection.Map import akka.actor.ActorRef -import com.google.common.base.Charsets +import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileUtil, Path} @@ -76,17 +76,9 @@ private[spark] class DriverRunner( // Make sure user application jar is on the classpath // TODO: If we add ability to submit multiple jars they should also be added here - val classPath = driverDesc.command.classPathEntries ++ Seq(s"$localJarFilename") - val newCommand = Command( - driverDesc.command.mainClass, - driverDesc.command.arguments.map(substituteVariables), - driverDesc.command.environment, - classPath, - driverDesc.command.libraryPathEntries, - driverDesc.command.javaOpts) - val command = CommandUtils.buildCommandSeq(newCommand, driverDesc.mem, - sparkHome.getAbsolutePath) - launchDriver(command, driverDesc.command.environment, driverDir, driverDesc.supervise) + val builder = CommandUtils.buildProcessBuilder(driverDesc.command, driverDesc.mem, + sparkHome.getAbsolutePath, substituteVariables, Seq(localJarFilename)) + launchDriver(builder, driverDir, driverDesc.supervise) } catch { case e: Exception => finalException = Some(e) @@ -165,11 +157,8 @@ private[spark] class DriverRunner( localJarFilename } - private def launchDriver(command: Seq[String], envVars: Map[String, String], baseDir: File, - supervise: Boolean) { - val builder = new ProcessBuilder(command: _*).directory(baseDir) - envVars.map{ case(k,v) => builder.environment().put(k, v) } - + private def launchDriver(builder: ProcessBuilder, baseDir: File, supervise: Boolean) { + builder.directory(baseDir) def initialize(process: Process) = { // Redirect stdout and stderr to files val stdout = new File(baseDir, "stdout") @@ -177,8 +166,8 
@@ private[spark] class DriverRunner( val stderr = new File(baseDir, "stderr") val header = "Launch Command: %s\n%s\n\n".format( - command.mkString("\"", "\" \"", "\""), "=" * 40) - Files.append(header, stderr, Charsets.UTF_8) + builder.command.mkString("\"", "\" \"", "\""), "=" * 40) + Files.append(header, stderr, UTF_8) CommandUtils.redirectStream(process.getErrorStream, stderr) } runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 71650cd773bcf..8ba6a01bbcb97 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -19,8 +19,10 @@ package org.apache.spark.deploy.worker import java.io._ +import scala.collection.JavaConversions._ + import akka.actor.ActorRef -import com.google.common.base.Charsets +import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.spark.{SparkConf, Logging} @@ -111,36 +113,25 @@ private[spark] class ExecutorRunner( case "{{EXECUTOR_ID}}" => execId.toString case "{{HOSTNAME}}" => host case "{{CORES}}" => cores.toString + case "{{APP_ID}}" => appId case other => other } - def getCommandSeq = { - val command = Command( - appDesc.command.mainClass, - appDesc.command.arguments.map(substituteVariables) ++ Seq(appId), - appDesc.command.environment, - appDesc.command.classPathEntries, - appDesc.command.libraryPathEntries, - appDesc.command.javaOpts) - CommandUtils.buildCommandSeq(command, memory, sparkHome.getAbsolutePath) - } - /** * Download and run the executor described in our ApplicationDescription */ def fetchAndRunExecutor() { try { // Launch the process - val command = getCommandSeq + val builder = CommandUtils.buildProcessBuilder(appDesc.command, memory, + sparkHome.getAbsolutePath, substituteVariables) + val command = builder.command() logInfo("Launch command: " + command.mkString("\"", "\" \"", "\"")) - val builder = new ProcessBuilder(command: _*).directory(executorDir) - val env = builder.environment() - for ((key, value) <- appDesc.command.environment) { - env.put(key, value) - } + + builder.directory(executorDir) // In case we are running this from within the Spark Shell, avoid creating a "scala" // parent process for the executor command - env.put("SPARK_LAUNCH_WITH_SCALA", "0") + builder.environment.put("SPARK_LAUNCH_WITH_SCALA", "0") process = builder.start() val header = "Spark Executor Command: %s\n%s\n\n".format( command.mkString("\"", "\" \"", "\""), "=" * 40) @@ -150,7 +141,7 @@ private[spark] class ExecutorRunner( stdoutAppender = FileAppender(process.getInputStream, stdout, conf) val stderr = new File(executorDir, "stderr") - Files.write(header, stderr, Charsets.UTF_8) + Files.write(header, stderr, UTF_8) stderrAppender = FileAppender(process.getErrorStream, stderr, conf) state = ExecutorState.RUNNING diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala new file mode 100644 index 0000000000000..b9798963bab0a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.worker + +import org.apache.spark.{Logging, SparkConf, SecurityManager} +import org.apache.spark.network.TransportContext +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.sasl.SaslRpcHandler +import org.apache.spark.network.server.TransportServer +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler + +/** + * Provides a server from which Executors can read shuffle files (rather than reading directly from + * each other), to provide uninterrupted access to the files in the face of executors being turned + * off or killed. + * + * Optionally requires SASL authentication in order to read. See [[SecurityManager]]. + */ +private[worker] +class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: SecurityManager) + extends Logging { + + private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false) + private val port = sparkConf.getInt("spark.shuffle.service.port", 7337) + private val useSasl: Boolean = securityManager.isAuthenticationEnabled() + + private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0) + private val blockHandler = new ExternalShuffleBlockHandler(transportConf) + private val transportContext: TransportContext = { + val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler + new TransportContext(transportConf, handler) + } + + private var server: TransportServer = _ + + /** Starts the external shuffle service if the user has configured us to. 
*/ + def startIfEnabled() { + if (enabled) { + require(server == null, "Shuffle server already started") + logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl") + server = transportContext.createServer(port) + } + } + + def stop() { + if (enabled && server != null) { + server.close() + server = null + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 3b13f43a1868c..eb11163538b20 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -20,16 +20,16 @@ package org.apache.spark.deploy.worker import java.io.File import java.io.IOException import java.text.SimpleDateFormat -import java.util.Date +import java.util.{UUID, Date} import scala.collection.JavaConversions._ import scala.collection.mutable.HashMap import scala.concurrent.duration._ import scala.language.postfixOps +import scala.util.Random import akka.actor._ import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} -import org.apache.commons.io.FileUtils import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} @@ -65,8 +65,22 @@ private[spark] class Worker( // Send a heartbeat every (heartbeat timeout) / 4 milliseconds val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 - val REGISTRATION_TIMEOUT = 20.seconds - val REGISTRATION_RETRIES = 3 + // Model retries to connect to the master, after Hadoop's model. + // The first six attempts to reconnect are in shorter intervals (between 5 and 15 seconds) + // Afterwards, the next 10 attempts are between 30 and 90 seconds. + // A bit of randomness is introduced so that not all of the workers attempt to reconnect at + // the same time. + val INITIAL_REGISTRATION_RETRIES = 6 + val TOTAL_REGISTRATION_RETRIES = INITIAL_REGISTRATION_RETRIES + 10 + val FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND = 0.500 + val REGISTRATION_RETRY_FUZZ_MULTIPLIER = { + val randomNumberGenerator = new Random(UUID.randomUUID.getMostSignificantBits) + randomNumberGenerator.nextDouble + FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND + } + val INITIAL_REGISTRATION_RETRY_INTERVAL = (math.round(10 * + REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds + val PROLONGED_REGISTRATION_RETRY_INTERVAL = (math.round(60 + * REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false) // How often worker will clean up old app folders @@ -96,6 +110,9 @@ private[spark] class Worker( val drivers = new HashMap[String, DriverRunner] val finishedDrivers = new HashMap[String, DriverRunner] + // The shuffle service is not actually started unless configured. 
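The StandaloneWorkerShuffleService added above is opt-in: it binds a TransportServer only when spark.shuffle.service.enabled is true, on spark.shuffle.service.port (7337 by default), and wraps the block handler in a SaslRpcHandler when authentication is on. A hedged configuration sketch using just the two properties the service reads; the application name is a placeholder, and the enabled flag must be set explicitly since it defaults to false:

    import org.apache.spark.SparkConf

    object ShuffleServiceConfSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .setAppName("shuffle-service-demo")             // hypothetical app name
          .set("spark.shuffle.service.enabled", "true")   // without this, startIfEnabled() is a no-op
          .set("spark.shuffle.service.port", "7337")      // same value as the default port
        println(conf.toDebugString)
      }
    }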
+ val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr) + val publicAddress = { val envVar = System.getenv("SPARK_PUBLIC_DNS") if (envVar != null) envVar else host @@ -104,6 +121,7 @@ private[spark] class Worker( var coresUsed = 0 var memoryUsed = 0 + var connectionAttemptCount = 0 val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) val workerSource = new WorkerSource(this) @@ -138,6 +156,7 @@ private[spark] class Worker( logInfo("Spark home: " + sparkHome) createWorkDir() context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + shuffleService.startIfEnabled() webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() registerWithMaster() @@ -157,9 +176,12 @@ private[spark] class Worker( throw new SparkException("Invalid spark URL: " + x) } connected = true + // Cancel any outstanding re-registration attempts because we found a new master + registrationRetryTimer.foreach(_.cancel()) + registrationRetryTimer = None } - def tryRegisterAllMasters() { + private def tryRegisterAllMasters() { for (masterUrl <- masterUrls) { logInfo("Connecting to master " + masterUrl + "...") val actor = context.actorSelection(Master.toAkkaUrl(masterUrl)) @@ -167,26 +189,80 @@ private[spark] class Worker( } } - def registerWithMaster() { - tryRegisterAllMasters() - var retries = 0 - registrationRetryTimer = Some { - context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) { - Utils.tryOrExit { - retries += 1 - if (registered) { - registrationRetryTimer.foreach(_.cancel()) - } else if (retries >= REGISTRATION_RETRIES) { - logError("All masters are unresponsive! Giving up.") - System.exit(1) - } else { - tryRegisterAllMasters() + /** + * Re-register with the master because a network failure or a master failure has occurred. + * If the re-registration attempt threshold is exceeded, the worker exits with error. + * Note that for thread-safety this should only be called from the actor. + */ + private def reregisterWithMaster(): Unit = { + Utils.tryOrExit { + connectionAttemptCount += 1 + if (registered) { + registrationRetryTimer.foreach(_.cancel()) + registrationRetryTimer = None + } else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) { + logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)") + /** + * Re-register with the active master this worker has been communicating with. If there + * is none, then it means this worker is still bootstrapping and hasn't established a + * connection with a master yet, in which case we should re-register with all masters. + * + * It is important to re-register only with the active master during failures. Otherwise, + * if the worker unconditionally attempts to re-register with all masters, the following + * race condition may arise and cause a "duplicate worker" error detailed in SPARK-4592: + * + * (1) Master A fails and Worker attempts to reconnect to all masters + * (2) Master B takes over and notifies Worker + * (3) Worker responds by registering with Master B + * (4) Meanwhile, Worker's previous reconnection attempt reaches Master B, + * causing the same Worker to register with Master B twice + * + * Instead, if we only register with the known active master, we can assume that the + * old master must have died because another master has taken over. Note that this is + * still not safe if the old master recovers within this interval, but this is a much + * less likely scenario. + */ + if (master != null) { + master ! 
RegisterWorker( + workerId, host, port, cores, memory, webUi.boundPort, publicAddress) + } else { + // We are retrying the initial registration + tryRegisterAllMasters() + } + // We have exceeded the initial registration retry threshold + // All retries from now on should use a higher interval + if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) { + registrationRetryTimer.foreach(_.cancel()) + registrationRetryTimer = Some { + context.system.scheduler.schedule(PROLONGED_REGISTRATION_RETRY_INTERVAL, + PROLONGED_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) } } + } else { + logError("All masters are unresponsive! Giving up.") + System.exit(1) } } } + def registerWithMaster() { + // DisassociatedEvent may be triggered multiple times, so don't attempt registration + // if there are outstanding registration attempts scheduled. + registrationRetryTimer match { + case None => + registered = false + tryRegisterAllMasters() + connectionAttemptCount = 0 + registrationRetryTimer = Some { + context.system.scheduler.schedule(INITIAL_REGISTRATION_RETRY_INTERVAL, + INITIAL_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) + } + case Some(_) => + logInfo("Not spawning another attempt to register with the master, since there is an" + + " attempt scheduled already.") + } + } + override def receiveWithLogging = { case RegisteredWorker(masterUrl, masterWebUiUrl) => logInfo("Successfully registered with master " + masterUrl) @@ -244,6 +320,10 @@ private[spark] class Worker( System.exit(1) } + case ReconnectWorker(masterUrl) => + logInfo(s"Master with url $masterUrl requested this worker to reconnect.") + registerWithMaster() + case LaunchExecutor(masterUrl, appId, execId, appDesc, cores_, memory_) => if (masterUrl != activeMasterUrl) { logWarning("Invalid Master (" + masterUrl + ") attempted to launch executor.") @@ -355,17 +435,21 @@ private[spark] class Worker( logInfo(s"$x Disassociated !") masterDisconnected() - case RequestWorkerState => { + case RequestWorkerState => sender ! WorkerStateResponse(host, port, workerId, executors.values.toList, finishedExecutors.values.toList, drivers.values.toList, finishedDrivers.values.toList, activeMasterUrl, cores, memory, coresUsed, memoryUsed, activeMasterWebUiUrl) - } + + case ReregisterWithMaster => + reregisterWithMaster() + } - def masterDisconnected() { + private def masterDisconnected() { logError("Connection to master failed! 
Waiting for master to reconnect...") connected = false + registerWithMaster() } def generateWorkerId(): String = { @@ -377,6 +461,7 @@ private[spark] class Worker( registrationRetryTimer.foreach(_.cancel()) executors.values.foreach(_.kill()) drivers.values.foreach(_.kill()) + shuffleService.stop() webUi.stop() metricsSystem.stop() } @@ -399,7 +484,8 @@ private[spark] object Worker extends Logging { cores: Int, memory: Int, masterUrls: Array[String], - workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = { + workDir: String, + workerNumber: Option[Int] = None): (ActorSystem, Int) = { // The LocalSparkCluster runs multiple local sparkWorkerX actor systems val conf = new SparkConf diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 1e295aaa48c30..019cd70f2a229 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -33,6 +33,7 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { var memory = inferDefaultMemory() var masters: Array[String] = null var workDir: String = null + var propertiesFile: String = null // Check for settings in environment variables if (System.getenv("SPARK_WORKER_PORT") != null) { @@ -41,21 +42,27 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { if (System.getenv("SPARK_WORKER_CORES") != null) { cores = System.getenv("SPARK_WORKER_CORES").toInt } - if (System.getenv("SPARK_WORKER_MEMORY") != null) { - memory = Utils.memoryStringToMb(System.getenv("SPARK_WORKER_MEMORY")) + if (conf.getenv("SPARK_WORKER_MEMORY") != null) { + memory = Utils.memoryStringToMb(conf.getenv("SPARK_WORKER_MEMORY")) } if (System.getenv("SPARK_WORKER_WEBUI_PORT") != null) { webUiPort = System.getenv("SPARK_WORKER_WEBUI_PORT").toInt } - if (conf.contains("spark.worker.ui.port")) { - webUiPort = conf.get("spark.worker.ui.port").toInt - } if (System.getenv("SPARK_WORKER_DIR") != null) { workDir = System.getenv("SPARK_WORKER_DIR") } parse(args.toList) + // This mutates the SparkConf, so all accesses to it must be made after this line + propertiesFile = Utils.loadDefaultSparkProperties(conf, propertiesFile) + + if (conf.contains("spark.worker.ui.port")) { + webUiPort = conf.get("spark.worker.ui.port").toInt + } + + checkWorkerMemory() + def parse(args: List[String]): Unit = args match { case ("--ip" | "-i") :: value :: tail => Utils.checkHost(value, "ip no longer supported, please use hostname " + value) @@ -87,7 +94,11 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { webUiPort = value parse(tail) - case ("--help" | "-h") :: tail => + case ("--properties-file") :: value :: tail => + propertiesFile = value + parse(tail) + + case ("--help") :: tail => printUsageAndExit(0) case value :: tail => @@ -122,7 +133,9 @@ private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) { " -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h)\n" + " -h HOST, --host HOST Hostname to listen on\n" + " -p PORT, --port PORT Port to listen on (default: random)\n" + - " --webui-port PORT Port for web UI (default: 8081)") + " --webui-port PORT Port for web UI (default: 8081)\n" + + " --properties-file FILE Path to a custom Spark properties file.\n" + + " Default is conf/spark-defaults.conf.") System.exit(exitCode) } @@ -153,4 +166,11 @@ private[spark] class 
WorkerArguments(args: Array[String], conf: SparkConf) { // Leave out 1 GB for the operating system, but don't return a negative memory size math.max(totalMb - 1024, 512) } + + def checkWorkerMemory(): Unit = { + if (memory <= 0) { + val message = "Memory can't be 0, missing a M or G on the end of the memory specification?" + throw new IllegalStateException(message) + } + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 6d0d0bbe5ecec..63a8ac817b618 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -54,7 +54,7 @@ private[spark] class WorkerWatcher(workerUrl: String) case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => logInfo(s"Successfully connected to $workerUrl") - case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound) + case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) if isWorker(remoteAddress) => // These logs may not be seen if the worker (and associated pipe) has died logError(s"Could not initialize connection to worker $workerUrl. Exiting.") diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 06061edfc0844..5f46f3b1f085e 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import scala.concurrent.Await -import akka.actor.{Actor, ActorSelection, Props} +import akka.actor.{Actor, ActorSelection, ActorSystem, Props} import akka.pattern.Patterns import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent} @@ -38,7 +38,8 @@ private[spark] class CoarseGrainedExecutorBackend( executorId: String, hostPort: String, cores: Int, - sparkProperties: Seq[(String, String)]) + sparkProperties: Seq[(String, String)], + actorSystem: ActorSystem) extends Actor with ActorLogReceive with ExecutorBackend with Logging { Utils.checkHostPort(hostPort, "Expected hostport") @@ -56,9 +57,9 @@ private[spark] class CoarseGrainedExecutorBackend( override def receiveWithLogging = { case RegisteredExecutor => logInfo("Successfully registered with driver") - // Make this host instead of hostPort ? - executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties, - false) + val (hostname, _) = Utils.parseHostPort(hostPort) + executor = new Executor(executorId, hostname, sparkProperties, cores, isLocal = false, + actorSystem) case RegisterExecutorFailed(message) => logError("Slave registration failed: " + message) @@ -130,12 +131,13 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Create a new ActorSystem using driver's Spark properties to run the backend. 
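The Worker hunk above (registerWithMaster / reregisterWithMaster) replaces the old fixed 20-second, 3-attempt schedule with a fuzzed one: a multiplier drawn once from [0.5, 1.5) gives an initial interval of round(10 * fuzz) seconds, i.e. 5 to 15 seconds for the first six attempts, and a prolonged interval of round(60 * fuzz) seconds, i.e. 30 to 90 seconds for the remaining ten. A small sketch reproducing that arithmetic with the same seeding as the diff:

    import java.util.UUID
    import scala.util.Random

    object RegistrationRetrySketch {
      def main(args: Array[String]): Unit = {
        val lowerBound = 0.500
        val rng  = new Random(UUID.randomUUID.getMostSignificantBits)
        val fuzz = rng.nextDouble + lowerBound              // uniform in [0.5, 1.5)

        val initialIntervalSec   = math.round(10 * fuzz)    // 5..15 seconds, attempts 1-6
        val prolongedIntervalSec = math.round(60 * fuzz)    // 30..90 seconds, attempts 7-16

        println(s"fuzz = $fuzz, initial = ${initialIntervalSec}s, prolonged = ${prolongedIntervalSec}s")
      }
    }

The per-worker randomness keeps a fleet of workers from hammering a recovering master at the same instant.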
val driverConf = new SparkConf().setAll(props) val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - "sparkExecutor", hostname, port, driverConf, new SecurityManager(driverConf)) + SparkEnv.executorActorSystemName, + hostname, port, driverConf, new SecurityManager(driverConf)) // set it val sparkHostPort = hostname + ":" + boundPort actorSystem.actorOf( Props(classOf[CoarseGrainedExecutorBackend], - driverUrl, executorId, sparkHostPort, cores, props), + driverUrl, executorId, sparkHostPort, cores, props, actorSystem), name = "Executor") workerUrl.foreach { url => actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher") @@ -152,6 +154,9 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { "Usage: CoarseGrainedExecutorBackend " + " [] ") System.exit(1) + + // NB: These arguments are provided by SparkDeploySchedulerBackend (for standalone mode) + // and CoarseMesosSchedulerBackend (for mesos mode). case 5 => run(args(0), args(1), args(2), args(3).toInt, args(4), None) case x if x > 5 => diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 9bbfcdc4a0b6e..835157fc520aa 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -26,21 +26,26 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal +import akka.actor.{Props, ActorSystem} + import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler._ import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{SparkUncaughtExceptionHandler, AkkaUtils, Utils} /** * Spark executor used with Mesos, YARN, and the standalone scheduler. + * In coarse-grained mode, an existing actor system is provided. */ private[spark] class Executor( executorId: String, slaveHostname: String, properties: Seq[(String, String)], - isLocal: Boolean = false) + numCores: Int, + isLocal: Boolean = false, + actorSystem: ActorSystem = null) extends Logging { // Application dependencies (added through SparkContext) that we've fetched so far on this node. @@ -68,25 +73,31 @@ private[spark] class Executor( // Setup an uncaught exception handler for non-local mode. // Make any thread terminations due to uncaught exceptions kill the entire // executor process to avoid surprising stalls. - Thread.setDefaultUncaughtExceptionHandler(ExecutorUncaughtExceptionHandler) + Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler) } val executorSource = new ExecutorSource(this, executorId) // Initialize Spark environment (using system properties read above) - conf.set("spark.executor.id", "executor." 
+ executorId) + conf.set("spark.executor.id", executorId) private val env = { if (!isLocal) { - val _env = SparkEnv.create(conf, executorId, slaveHostname, 0, - isDriver = false, isLocal = false) + val port = conf.getInt("spark.executor.port", 0) + val _env = SparkEnv.createExecutorEnv( + conf, executorId, slaveHostname, port, numCores, isLocal, actorSystem) SparkEnv.set(_env) _env.metricsSystem.registerSource(executorSource) + _env.blockManager.initialize(conf.getAppId) _env } else { SparkEnv.get } } + // Create an actor for receiving RPCs from the driver + private val executorActor = env.actorSystem.actorOf( + Props(new ExecutorActor(executorId)), "ExecutorActor") + // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager private val urlClassLoader = createClassLoader() @@ -99,6 +110,9 @@ private[spark] class Executor( // to send the result back. private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) + // Limit of bytes for total size of results (default is 1GB) + private val maxResultSize = Utils.getMaxResultSize(conf) + // Start worker thread pool val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker") @@ -123,6 +137,7 @@ private[spark] class Executor( def stop() { env.metricsSystem.report() + env.actorSystem.stop(executorActor) isStopped = true threadPool.shutdown() if (!isLocal) { @@ -147,8 +162,7 @@ private[spark] class Executor( } override def run() { - val startTime = System.currentTimeMillis() - SparkEnv.set(env) + val deserializeStartTime = System.currentTimeMillis() Thread.currentThread.setContextClassLoader(replClassLoader) val ser = SparkEnv.get.closureSerializer.newInstance() logInfo(s"Running $taskName (TID $taskId)") @@ -158,7 +172,6 @@ private[spark] class Executor( val startGCTime = gcTime try { - SparkEnv.set(env) Accumulators.clear() val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask) updateDependencies(taskFiles, taskJars) @@ -194,7 +207,7 @@ private[spark] class Executor( val afterSerialization = System.currentTimeMillis() for (m <- task.metrics) { - m.executorDeserializeTime = taskStart - startTime + m.executorDeserializeTime = taskStart - deserializeStartTime m.executorRunTime = taskFinish - taskStart m.jvmGCTime = gcTime - startGCTime m.resultSerializationTime = afterSerialization - beforeSerialization @@ -207,25 +220,27 @@ private[spark] class Executor( val resultSize = serializedDirectResult.limit // directSend = sending directly back to the driver - val (serializedResult, directSend) = { - if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { + val serializedResult = { + if (maxResultSize > 0 && resultSize > maxResultSize) { + logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " + + s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " + + s"dropping it.") + ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize)) + } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) { val blockId = TaskResultBlockId(taskId) env.blockManager.putBytes( blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER) - (ser.serialize(new IndirectTaskResult[Any](blockId)), false) + logInfo( + s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)") + ser.serialize(new IndirectTaskResult[Any](blockId, resultSize)) } else { - (serializedDirectResult, true) + logInfo(s"Finished $taskName (TID $taskId). 
$resultSize bytes result sent to driver") + serializedDirectResult } } execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) - if (directSend) { - logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver") - } else { - logInfo( - s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)") - } } catch { case ffe: FetchFailedException => { val reason = ffe.toTaskEndReason @@ -249,13 +264,13 @@ private[spark] class Executor( m.executorRunTime = serviceTime m.jvmGCTime = gcTime - startGCTime } - val reason = ExceptionFailure(t.getClass.getName, t.getMessage, t.getStackTrace, metrics) + val reason = new ExceptionFailure(t, metrics) execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) // Don't forcibly exit unless the exception was inherently fatal, to avoid // stopping other tasks unnecessarily. if (Utils.isFatalError(t)) { - ExecutorUncaughtExceptionHandler.uncaughtException(t) + SparkUncaughtExceptionHandler.uncaughtException(t) } } } finally { @@ -319,19 +334,21 @@ private[spark] class Executor( * SparkContext. Also adds any new JARs we fetched to the class loader. */ private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) { - val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) + lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) synchronized { // Fetch missing dependencies for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager, - hadoopConf) + // Fetch file with useCache mode, close cache for local mode. + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, + env.securityManager, hadoopConf, timestamp, useCache = !isLocal) currentFiles(name) = timestamp } for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager, - hadoopConf) + // Fetch file with useCache mode, close cache for local mode. + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, + env.securityManager, hadoopConf, timestamp, useCache = !isLocal) currentJars(name) = timestamp // Add it to our class loader val localName = name.split("/").last diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala new file mode 100644 index 0000000000000..41925f7e97e84 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
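The Executor.scala hunk above picks one of three routes for a finished task's serialized result: drop it (returning only an IndirectTaskResult carrying the size) when it exceeds the configured maximum result size, store it in the BlockManager for indirect fetch when it would not fit in an Akka frame, or send it directly to the driver otherwise. A condensed sketch of that decision with placeholder sizes; the limits stand in for the values the executor reads via Utils.getMaxResultSize and AkkaUtils.maxFrameSizeBytes:

    object ResultRoutingSketch {
      sealed trait Route
      case object DroppedTooLarge extends Route  // driver receives only an IndirectTaskResult with the size
      case object ViaBlockManager extends Route  // stored locally, driver fetches the block
      case object DirectToDriver  extends Route  // serialized bytes travel in the status update

      /** Mirrors the branch order above: the hard result-size limit wins over the frame-size check. */
      def route(resultSize: Long, maxResultSize: Long, akkaFrameSize: Long, reservedBytes: Long): Route =
        if (maxResultSize > 0 && resultSize > maxResultSize) DroppedTooLarge
        else if (resultSize >= akkaFrameSize - reservedBytes) ViaBlockManager
        else DirectToDriver

      def main(args: Array[String]): Unit = {
        // Illustrative numbers only.
        println(route(resultSize = 2L << 30,  maxResultSize = 1L << 30, akkaFrameSize = 10L << 20, reservedBytes = 200L * 1024))
        println(route(resultSize = 64L << 20, maxResultSize = 1L << 30, akkaFrameSize = 10L << 20, reservedBytes = 200L * 1024))
        println(route(resultSize = 512L << 10, maxResultSize = 1L << 30, akkaFrameSize = 10L << 20, reservedBytes = 200L * 1024))
      }
    }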
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import akka.actor.Actor +import org.apache.spark.Logging + +import org.apache.spark.util.{Utils, ActorLogReceive} + +/** + * Driver -> Executor message to trigger a thread dump. + */ +private[spark] case object TriggerThreadDump + +/** + * Actor that runs inside of executors to enable driver -> executor RPC. + */ +private[spark] +class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging { + + override def receiveWithLogging = { + case TriggerThreadDump => + sender ! Utils.getThreadDump() + } + +} diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala index 38be2c58b333f..52862ae0ca5e4 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala @@ -17,6 +17,8 @@ package org.apache.spark.executor +import org.apache.spark.util.SparkExitCode._ + /** * These are exit codes that executors should use to provide the master with information about * executor failures assuming that cluster management framework can capture the exit codes (but @@ -27,16 +29,6 @@ package org.apache.spark.executor */ private[spark] object ExecutorExitCode { - /** The default uncaught exception handler was reached. */ - val UNCAUGHT_EXCEPTION = 50 - - /** The default uncaught exception handler was called and an exception was encountered while - logging the exception. */ - val UNCAUGHT_EXCEPTION_TWICE = 51 - - /** The default uncaught exception handler was reached, and the uncaught exception was an - OutOfMemoryError. */ - val OOM = 52 /** DiskStore failed to create a local temporary directory after many attempts. */ val DISK_STORE_FAILED_TO_CREATE_DIR = 53 diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorUncaughtExceptionHandler.scala deleted file mode 100644 index b0e984c03964c..0000000000000 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorUncaughtExceptionHandler.scala +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.executor - -import org.apache.spark.Logging -import org.apache.spark.util.Utils - -/** - * The default uncaught exception handler for Executors terminates the whole process, to avoid - * getting into a bad state indefinitely. Since Executors are relatively lightweight, it's better - * to fail fast when things go wrong. 
- */ -private[spark] object ExecutorUncaughtExceptionHandler - extends Thread.UncaughtExceptionHandler with Logging { - - override def uncaughtException(thread: Thread, exception: Throwable) { - try { - logError("Uncaught exception in thread " + thread, exception) - - // We may have been called from a shutdown hook. If so, we must not call System.exit(). - // (If we do, we will deadlock.) - if (!Utils.inShutdown()) { - if (exception.isInstanceOf[OutOfMemoryError]) { - System.exit(ExecutorExitCode.OOM) - } else { - System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION) - } - } - } catch { - case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM) - case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE) - } - } - - def uncaughtException(exception: Throwable) { - uncaughtException(Thread.currentThread(), exception) - } -} diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index bca0b152268ad..f15e6bc33fb41 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -19,6 +19,8 @@ package org.apache.spark.executor import java.nio.ByteBuffer +import scala.collection.JavaConversions._ + import org.apache.mesos.protobuf.ByteString import org.apache.mesos.{Executor => MesosExecutor, ExecutorDriver, MesosExecutorDriver, MesosNativeLibrary} import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _} @@ -50,14 +52,23 @@ private[spark] class MesosExecutorBackend executorInfo: ExecutorInfo, frameworkInfo: FrameworkInfo, slaveInfo: SlaveInfo) { - logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue) + + // Get num cores for this task from ExecutorInfo, created in MesosSchedulerBackend. + val cpusPerTask = executorInfo.getResourcesList + .find(_.getName == "cpus") + .map(_.getScalar.getValue.toInt) + .getOrElse(0) + val executorId = executorInfo.getExecutorId.getValue + + logInfo(s"Registered with Mesos as executor ID $executorId with $cpusPerTask cpus") this.driver = driver val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) ++ Seq[(String, String)](("spark.app.id", frameworkInfo.getId.getValue)) executor = new Executor( - executorInfo.getExecutorId.getValue, + executorId, slaveInfo.getHostname, - properties) + properties, + cpusPerTask) } override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) { diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 3e49b6235aff3..51b5328cb4c8f 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -82,6 +82,12 @@ class TaskMetrics extends Serializable { */ var inputMetrics: Option[InputMetrics] = None + /** + * If this task writes data externally (e.g. to a distributed filesystem), metrics on how much + * data was written are stored here. + */ + var outputMetrics: Option[OutputMetrics] = None + /** * If this task reads from shuffle output, metrics on getting shuffle data will be collected here. * This includes read metrics aggregated over all the task's shuffle dependencies. 
@@ -157,6 +163,16 @@ object DataReadMethod extends Enumeration with Serializable { val Memory, Disk, Hadoop, Network = Value } +/** + * :: DeveloperApi :: + * Method by which output data was written. + */ +@DeveloperApi +object DataWriteMethod extends Enumeration with Serializable { + type DataWriteMethod = Value + val Hadoop = Value +} + /** * :: DeveloperApi :: * Metrics about reading input data. @@ -169,6 +185,17 @@ case class InputMetrics(readMethod: DataReadMethod.Value) { var bytesRead: Long = 0L } +/** + * :: DeveloperApi :: + * Metrics about writing output data. + */ +@DeveloperApi +case class OutputMetrics(writeMethod: DataWriteMethod.Value) { + /** + * Total bytes written + */ + var bytesWritten: Long = 0L +} /** * :: DeveloperApi :: diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala new file mode 100644 index 0000000000000..89b29af2000c8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.input + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.{BytesWritable, LongWritable} +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} + +/** + * Custom Input Format for reading and splitting flat binary files that contain records, + * each of which are a fixed size in bytes. The fixed record size is specified through + * a parameter recordLength in the Hadoop configuration. + */ +private[spark] object FixedLengthBinaryInputFormat { + /** Property name to set in Hadoop JobConfs for record length */ + val RECORD_LENGTH_PROPERTY = "org.apache.spark.input.FixedLengthBinaryInputFormat.recordLength" + + /** Retrieves the record length property from a Hadoop configuration */ + def getRecordLength(context: JobContext): Int = { + context.getConfiguration.get(RECORD_LENGTH_PROPERTY).toInt + } +} + +private[spark] class FixedLengthBinaryInputFormat + extends FileInputFormat[LongWritable, BytesWritable] { + + private var recordLength = -1 + + /** + * Override of isSplitable to ensure initial computation of the record length + */ + override def isSplitable(context: JobContext, filename: Path): Boolean = { + if (recordLength == -1) { + recordLength = FixedLengthBinaryInputFormat.getRecordLength(context) + } + if (recordLength <= 0) { + println("record length is less than 0, file cannot be split") + false + } else { + true + } + } + + /** + * This input format overrides computeSplitSize() to make sure that each split + * only contains full records. 
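The TaskMetrics hunk above adds an OutputMetrics counterpart to InputMetrics, tagged with a DataWriteMethod (only Hadoop for now). A hedged sketch of recording bytes written; it assumes a TaskMetrics instance can be constructed directly, and the byte count is a placeholder:

    import org.apache.spark.executor.{DataWriteMethod, OutputMetrics, TaskMetrics}

    object OutputMetricsSketch {
      def main(args: Array[String]): Unit = {
        // Illustration only: attach write-side metrics to a task's TaskMetrics.
        val metrics = new TaskMetrics
        val out = OutputMetrics(DataWriteMethod.Hadoop)
        metrics.outputMetrics = Some(out)

        // A writer would bump the counter as records are flushed; 1024 bytes is a placeholder.
        out.bytesWritten += 1024L
        println(s"wrote ${metrics.outputMetrics.map(_.bytesWritten).getOrElse(0L)} bytes")
      }
    }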
Each InputSplit passed to FixedLengthBinaryRecordReader + * will start at the first byte of a record, and the last byte will the last byte of a record. + */ + override def computeSplitSize(blockSize: Long, minSize: Long, maxSize: Long): Long = { + val defaultSize = super.computeSplitSize(blockSize, minSize, maxSize) + // If the default size is less than the length of a record, make it equal to it + // Otherwise, make sure the split size is as close to possible as the default size, + // but still contains a complete set of records, with the first record + // starting at the first byte in the split and the last record ending with the last byte + if (defaultSize < recordLength) { + recordLength.toLong + } else { + (Math.floor(defaultSize / recordLength) * recordLength).toLong + } + } + + /** + * Create a FixedLengthBinaryRecordReader + */ + override def createRecordReader(split: InputSplit, context: TaskAttemptContext) + : RecordReader[LongWritable, BytesWritable] = { + new FixedLengthBinaryRecordReader + } +} diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala new file mode 100644 index 0000000000000..36a1e5d475f46 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.input + +import java.io.IOException + +import org.apache.hadoop.fs.FSDataInputStream +import org.apache.hadoop.io.compress.CompressionCodecFactory +import org.apache.hadoop.io.{BytesWritable, LongWritable} +import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.input.FileSplit + +/** + * FixedLengthBinaryRecordReader is returned by FixedLengthBinaryInputFormat. + * It uses the record length set in FixedLengthBinaryInputFormat to + * read one record at a time from the given InputSplit. + * + * Each call to nextKeyValue() updates the LongWritable key and BytesWritable value. 
+ * + * key = record index (Long) + * value = the record itself (BytesWritable) + */ +private[spark] class FixedLengthBinaryRecordReader + extends RecordReader[LongWritable, BytesWritable] { + + private var splitStart: Long = 0L + private var splitEnd: Long = 0L + private var currentPosition: Long = 0L + private var recordLength: Int = 0 + private var fileInputStream: FSDataInputStream = null + private var recordKey: LongWritable = null + private var recordValue: BytesWritable = null + + override def close() { + if (fileInputStream != null) { + fileInputStream.close() + } + } + + override def getCurrentKey: LongWritable = { + recordKey + } + + override def getCurrentValue: BytesWritable = { + recordValue + } + + override def getProgress: Float = { + splitStart match { + case x if x == splitEnd => 0.0.toFloat + case _ => Math.min( + ((currentPosition - splitStart) / (splitEnd - splitStart)).toFloat, 1.0 + ).toFloat + } + } + + override def initialize(inputSplit: InputSplit, context: TaskAttemptContext) { + // the file input + val fileSplit = inputSplit.asInstanceOf[FileSplit] + + // the byte position this fileSplit starts at + splitStart = fileSplit.getStart + + // splitEnd byte marker that the fileSplit ends at + splitEnd = splitStart + fileSplit.getLength + + // the actual file we will be reading from + val file = fileSplit.getPath + // job configuration + val job = context.getConfiguration + // check compression + val codec = new CompressionCodecFactory(job).getCodec(file) + if (codec != null) { + throw new IOException("FixedLengthRecordReader does not support reading compressed files") + } + // get the record length + recordLength = FixedLengthBinaryInputFormat.getRecordLength(context) + // get the filesystem + val fs = file.getFileSystem(job) + // open the File + fileInputStream = fs.open(file) + // seek to the splitStart position + fileInputStream.seek(splitStart) + // set our current position + currentPosition = splitStart + } + + override def nextKeyValue(): Boolean = { + if (recordKey == null) { + recordKey = new LongWritable() + } + // the key is a linear index of the record, given by the + // position the record starts divided by the record length + recordKey.set(currentPosition / recordLength) + // the recordValue to place the bytes into + if (recordValue == null) { + recordValue = new BytesWritable(new Array[Byte](recordLength)) + } + // read a record if the currentPosition is less than the split end + if (currentPosition < splitEnd) { + // setup a buffer to store the record + val buffer = recordValue.getBytes + fileInputStream.readFully(buffer) + // update our current position + currentPosition = currentPosition + recordLength + // return true + return true + } + false + } +} diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala new file mode 100644 index 0000000000000..457472547fcbb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
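The record reader above yields one (record index, record bytes) pair per fixed-length record. A hedged usage sketch wiring the new format into an RDD through the existing newAPIHadoopFile API; the path and the 100-byte record length are placeholders, and because the format is private[spark] the sketch assumes it compiles under an org.apache.spark subpackage (the sketches package name is invented):

    package org.apache.spark.sketches  // any org.apache.spark subpackage can see the private[spark] format

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.io.{BytesWritable, LongWritable}
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.input.FixedLengthBinaryInputFormat

    object FixedLengthReadSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("fixed-length-demo").setMaster("local[*]"))

        // Tell the input format how long each record is (placeholder: 100 bytes per record).
        val hadoopConf = new Configuration()
        hadoopConf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, 100)

        val records = sc.newAPIHadoopFile(
          "hdfs:///data/records.bin",            // placeholder path
          classOf[FixedLengthBinaryInputFormat],
          classOf[LongWritable],
          classOf[BytesWritable],
          hadoopConf)

        // The reader reuses its Writable, so copy the bytes out before doing anything lazy with them.
        val firstRecord = records.map { case (_, bytes) => bytes.getBytes.clone() }.first()
        println(s"first record holds ${firstRecord.length} bytes")
        sc.stop()
      }
    }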
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.input + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} + +import scala.collection.JavaConversions._ + +import com.google.common.io.ByteStreams +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit} + +import org.apache.spark.annotation.Experimental + +/** + * A general format for reading whole files in as streams, byte arrays, + * or other functions to be added + */ +private[spark] abstract class StreamFileInputFormat[T] + extends CombineFileInputFormat[String, T] +{ + override protected def isSplitable(context: JobContext, file: Path): Boolean = false + + /** + * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API + * which is set through setMaxSplitSize + */ + def setMinPartitions(context: JobContext, minPartitions: Int) { + val files = listStatus(context) + val totalLen = files.map { file => + if (file.isDir) 0L else file.getLen + }.sum + + val maxSplitSize = Math.ceil(totalLen * 1.0 / files.length).toLong + super.setMaxSplitSize(maxSplitSize) + } + + def createRecordReader(split: InputSplit, taContext: TaskAttemptContext): RecordReader[String, T] + +} + +/** + * An abstract class of [[org.apache.hadoop.mapreduce.RecordReader RecordReader]] + * to reading files out as streams + */ +private[spark] abstract class StreamBasedRecordReader[T]( + split: CombineFileSplit, + context: TaskAttemptContext, + index: Integer) + extends RecordReader[String, T] { + + // True means the current file has been processed, then skip it. 
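setMinPartitions above caps each combined split at the average file length, maxSplitSize = ceil(totalLen / files.length); four files totalling 1,000,000 bytes therefore yield splits of at most 250,000 bytes. The same computation as a sketch:

    object StreamSplitSizeSketch {
      /** Same average-size computation used by StreamFileInputFormat.setMinPartitions above. */
      def maxSplitSize(fileLengths: Seq[Long]): Long =
        math.ceil(fileLengths.sum * 1.0 / fileLengths.length).toLong

      def main(args: Array[String]): Unit = {
        // Four files totalling 1,000,000 bytes -> combined splits capped at 250,000 bytes.
        println(maxSplitSize(Seq(400000L, 300000L, 200000L, 100000L)))
      }
    }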
+ private var processed = false + + private var key = "" + private var value: T = null.asInstanceOf[T] + + override def initialize(split: InputSplit, context: TaskAttemptContext) = {} + override def close() = {} + + override def getProgress = if (processed) 1.0f else 0.0f + + override def getCurrentKey = key + + override def getCurrentValue = value + + override def nextKeyValue = { + if (!processed) { + val fileIn = new PortableDataStream(split, context, index) + value = parseStream(fileIn) + fileIn.close() // if it has not been open yet, close does nothing + key = fileIn.getPath + processed = true + true + } else { + false + } + } + + /** + * Parse the stream (and close it afterwards) and return the value as in type T + * @param inStream the stream to be read in + * @return the data formatted as + */ + def parseStream(inStream: PortableDataStream): T +} + +/** + * Reads the record in directly as a stream for other objects to manipulate and handle + */ +private[spark] class StreamRecordReader( + split: CombineFileSplit, + context: TaskAttemptContext, + index: Integer) + extends StreamBasedRecordReader[PortableDataStream](split, context, index) { + + def parseStream(inStream: PortableDataStream): PortableDataStream = inStream +} + +/** + * The format for the PortableDataStream files + */ +private[spark] class StreamInputFormat extends StreamFileInputFormat[PortableDataStream] { + override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext) = { + new CombineFileRecordReader[String, PortableDataStream]( + split.asInstanceOf[CombineFileSplit], taContext, classOf[StreamRecordReader]) + } +} + +/** + * A class that allows DataStreams to be serialized and moved around by not creating them + * until they need to be read + * @note TaskAttemptContext is not serializable resulting in the confBytes construct + * @note CombineFileSplit is not serializable resulting in the splitBytes construct + */ +@Experimental +class PortableDataStream( + @transient isplit: CombineFileSplit, + @transient context: TaskAttemptContext, + index: Integer) + extends Serializable { + + // transient forces file to be reopened after being serialization + // it is also used for non-serializable classes + + @transient private var fileIn: DataInputStream = null + @transient private var isOpen = false + + private val confBytes = { + val baos = new ByteArrayOutputStream() + context.getConfiguration.write(new DataOutputStream(baos)) + baos.toByteArray + } + + private val splitBytes = { + val baos = new ByteArrayOutputStream() + isplit.write(new DataOutputStream(baos)) + baos.toByteArray + } + + @transient private lazy val split = { + val bais = new ByteArrayInputStream(splitBytes) + val nsplit = new CombineFileSplit() + nsplit.readFields(new DataInputStream(bais)) + nsplit + } + + @transient private lazy val conf = { + val bais = new ByteArrayInputStream(confBytes) + val nconf = new Configuration() + nconf.readFields(new DataInputStream(bais)) + nconf + } + /** + * Calculate the path name independently of opening the file + */ + @transient private lazy val path = { + val pathp = split.getPath(index) + pathp.toString + } + + /** + * Create a new DataInputStream from the split and context + */ + def open(): DataInputStream = { + if (!isOpen) { + val pathp = split.getPath(index) + val fs = pathp.getFileSystem(conf) + fileIn = fs.open(pathp) + isOpen = true + } + fileIn + } + + /** + * Read the file as a byte array + */ + def toArray(): Array[Byte] = { + open() + val innerBuffer = ByteStreams.toByteArray(fileIn) + 
close() + innerBuffer + } + + /** + * Close the file (if it is currently open) + */ + def close() = { + if (isOpen) { + try { + fileIn.close() + isOpen = false + } catch { + case ioe: java.io.IOException => // do nothing + } + } + } + + def getPath(): String = path +} + diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala index 4cb450577796a..d3601cca832b2 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -19,14 +19,13 @@ package org.apache.spark.input import scala.collection.JavaConversions._ +import org.apache.hadoop.conf.{Configuration, Configurable} import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.InputSplit import org.apache.hadoop.mapreduce.JobContext import org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat import org.apache.hadoop.mapreduce.RecordReader import org.apache.hadoop.mapreduce.TaskAttemptContext -import org.apache.hadoop.mapreduce.lib.input.CombineFileRecordReader -import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit /** * A [[org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat CombineFileInputFormat]] for @@ -34,23 +33,31 @@ import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit * the value is the entire content of file. */ -private[spark] class WholeTextFileInputFormat extends CombineFileInputFormat[String, String] { +private[spark] class WholeTextFileInputFormat + extends CombineFileInputFormat[String, String] with Configurable { + override protected def isSplitable(context: JobContext, file: Path): Boolean = false + private var conf: Configuration = _ + def setConf(c: Configuration) { + conf = c + } + def getConf: Configuration = conf + override def createRecordReader( split: InputSplit, context: TaskAttemptContext): RecordReader[String, String] = { - new CombineFileRecordReader[String, String]( - split.asInstanceOf[CombineFileSplit], - context, - classOf[WholeTextFileRecordReader]) + val reader = new WholeCombineFileRecordReader(split, context) + reader.setConf(conf) + reader } /** - * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API. 
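As a usage illustration for the PortableDataStream and stream-based record readers added above, here is a minimal sketch assuming the SparkContext.binaryFiles entry point that returns (path, PortableDataStream) pairs; the input directory, app name, and master are placeholders, not part of this patch:

import org.apache.spark.{SparkConf, SparkContext}

object BinaryFilesSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("binary-files-sketch").setMaster("local[2]"))
    // Each record is (file path, PortableDataStream); the underlying stream is not opened
    // until open() or toArray() is called on the executor, which keeps the records cheap to ship.
    val files = sc.binaryFiles("/tmp/binary-input")   // placeholder directory
    val sizes = files.map { case (path, stream) => (path, stream.toArray().length) }
    sizes.collect().foreach { case (path, n) => println(s"$path -> $n bytes") }
    sc.stop()
  }
}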
+ * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API, + * which is set through setMaxSplitSize */ - def setMaxSplitSize(context: JobContext, minPartitions: Int) { + def setMinPartitions(context: JobContext, minPartitions: Int) { val files = listStatus(context) val totalLen = files.map { file => if (file.isDir) 0L else file.getLen diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala index 3564ab2e2a162..6d59b24eb0596 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala @@ -17,11 +17,13 @@ package org.apache.spark.input +import org.apache.hadoop.conf.{Configuration, Configurable} import com.google.common.io.{ByteStreams, Closeables} import org.apache.hadoop.io.Text +import org.apache.hadoop.io.compress.CompressionCodecFactory import org.apache.hadoop.mapreduce.InputSplit -import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit +import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, CombineFileRecordReader} import org.apache.hadoop.mapreduce.RecordReader import org.apache.hadoop.mapreduce.TaskAttemptContext @@ -34,7 +36,13 @@ private[spark] class WholeTextFileRecordReader( split: CombineFileSplit, context: TaskAttemptContext, index: Integer) - extends RecordReader[String, String] { + extends RecordReader[String, String] with Configurable { + + private var conf: Configuration = _ + def setConf(c: Configuration) { + conf = c + } + def getConf: Configuration = conf private[this] val path = split.getPath(index) private[this] val fs = path.getFileSystem(context.getConfiguration) @@ -57,8 +65,16 @@ private[spark] class WholeTextFileRecordReader( override def nextKeyValue(): Boolean = { if (!processed) { + val conf = new Configuration + val factory = new CompressionCodecFactory(conf) + val codec = factory.getCodec(path) // infers from file ext. val fileIn = fs.open(path) - val innerBuffer = ByteStreams.toByteArray(fileIn) + val innerBuffer = if (codec != null) { + ByteStreams.toByteArray(codec.createInputStream(fileIn)) + } else { + ByteStreams.toByteArray(fileIn) + } + value = new Text(innerBuffer).toString Closeables.close(fileIn, false) processed = true @@ -68,3 +84,33 @@ private[spark] class WholeTextFileRecordReader( } } } + + +/** + * A [[org.apache.hadoop.mapreduce.RecordReader RecordReader]] for reading a single whole text file + * out in a key-value pair, where the key is the file path and the value is the entire content of + * the file. 
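The compression support added to WholeTextFileRecordReader above comes down to probing the path with CompressionCodecFactory and wrapping the raw stream when a codec is found. A standalone sketch of that read path, with a placeholder file name that is not part of this patch:

import com.google.common.io.ByteStreams
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.compress.CompressionCodecFactory

object CodecAwareReadSketch {
  def main(args: Array[String]): Unit = {
    val conf = new Configuration()
    val path = new Path("/tmp/sample.txt.gz")   // placeholder input file
    val fs = path.getFileSystem(conf)
    // getCodec infers the codec from the file extension and returns null when none matches.
    val codec = new CompressionCodecFactory(conf).getCodec(path)
    val in = fs.open(path)
    val bytes =
      if (codec != null) ByteStreams.toByteArray(codec.createInputStream(in))
      else ByteStreams.toByteArray(in)
    in.close()
    println(s"Read ${bytes.length} bytes from $path")
  }
}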
+ */ +private[spark] class WholeCombineFileRecordReader( + split: InputSplit, + context: TaskAttemptContext) + extends CombineFileRecordReader[String, String]( + split.asInstanceOf[CombineFileSplit], + context, + classOf[WholeTextFileRecordReader] + ) with Configurable { + + private var conf: Configuration = _ + def setConf(c: Configuration) { + conf = c + } + def getConf: Configuration = conf + + override def initNextRecordReader(): Boolean = { + val r = super.initNextRecordReader() + if (r) { + this.curReader.asInstanceOf[WholeTextFileRecordReader].setConf(conf) + } + r + } +} diff --git a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala similarity index 79% rename from core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala rename to core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index 0c47afae54c8b..21b782edd2a9e 100644 --- a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -15,15 +15,24 @@ * limitations under the License. */ -package org.apache.hadoop.mapred +package org.apache.spark.mapred -private[apache] +import java.lang.reflect.Modifier + +import org.apache.hadoop.mapred.{TaskAttemptID, JobID, JobConf, JobContext, TaskAttemptContext} + +private[spark] trait SparkHadoopMapRedUtil { def newJobContext(conf: JobConf, jobId: JobID): JobContext = { val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl", "org.apache.hadoop.mapred.JobContext") val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[org.apache.hadoop.mapreduce.JobID]) + // In Hadoop 1.0.x, JobContext is an interface, and JobContextImpl is package private. + // Make it accessible if it's not in order to access it. + if (!Modifier.isPublic(ctor.getModifiers)) { + ctor.setAccessible(true) + } ctor.newInstance(conf, jobId).asInstanceOf[JobContext] } @@ -31,6 +40,10 @@ trait SparkHadoopMapRedUtil { val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl", "org.apache.hadoop.mapred.TaskAttemptContext") val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID]) + // See above + if (!Modifier.isPublic(ctor.getModifiers)) { + ctor.setAccessible(true) + } ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext] } diff --git a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala similarity index 96% rename from core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala rename to core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala index 1fca5729c6092..3340673f91156 100644 --- a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala @@ -15,13 +15,14 @@ * limitations under the License. 
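The relocated SparkHadoopMapRedUtil above stays compatible across Hadoop versions by resolving whichever class name exists on the classpath and forcing a package-private constructor accessible. A small self-contained sketch of that pattern; the class names below are stand-ins rather than the Hadoop ones:

import java.lang.reflect.Modifier

object ReflectiveCtorSketch {
  // Try the first class name, fall back to the second (mirrors firstAvailableClass above).
  private def firstAvailableClass(first: String, second: String): Class[_] =
    try Class.forName(first) catch { case _: ClassNotFoundException => Class.forName(second) }

  def main(args: Array[String]): Unit = {
    // Stand-in names: on a real Hadoop classpath these would be JobContextImpl / JobContext.
    val klass = firstAvailableClass("java.util.ArrayList", "java.util.LinkedList")
    val ctor = klass.getDeclaredConstructor()
    if (!Modifier.isPublic(ctor.getModifiers)) {
      ctor.setAccessible(true)   // needed when the constructor is package-private (the Hadoop 1.0.x case)
    }
    println(ctor.newInstance())  // prints "[]" for the ArrayList stand-in
  }
}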
*/ -package org.apache.hadoop.mapreduce +package org.apache.spark.mapreduce import java.lang.{Boolean => JBoolean, Integer => JInteger} import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID} -private[apache] +private[spark] trait SparkHadoopMapReduceUtil { def newJobContext(conf: Configuration, jobId: JobID): JobContext = { val klass = firstAvailableClass( diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index e0e91724271c8..1745d52c81923 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -17,20 +17,20 @@ package org.apache.spark.network -import org.apache.spark.storage.StorageLevel - +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.storage.{BlockId, StorageLevel} +private[spark] trait BlockDataManager { /** - * Interface to get local block data. - * - * @return Some(buffer) if the block exists locally, and None if it doesn't. + * Interface to get local block data. Throws an exception if the block cannot be found or + * cannot be read successfully. */ - def getBlockData(blockId: String): Option[ManagedBuffer] + def getBlockData(blockId: BlockId): ManagedBuffer /** * Put the block locally, using the given storage level. */ - def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit + def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit } diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala deleted file mode 100644 index 34acaa563ca58..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network - -import java.util.EventListener - - -/** - * Listener callback interface for [[BlockTransferService.fetchBlocks]]. - */ -trait BlockFetchingListener extends EventListener { - - /** - * Called once per successfully fetched block. - */ - def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit - - /** - * Called upon failures. For each failure, this is called only once (i.e. not once per block). 
- */ - def onBlockFetchFailure(exception: Throwable): Unit -} diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 84d991fa6808c..dcbda5a8515dd 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -17,13 +17,19 @@ package org.apache.spark.network -import scala.concurrent.{Await, Future} -import scala.concurrent.duration.Duration +import java.io.Closeable +import java.nio.ByteBuffer -import org.apache.spark.storage.StorageLevel +import scala.concurrent.{Promise, Await, Future} +import scala.concurrent.duration.Duration +import org.apache.spark.Logging +import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer} +import org.apache.spark.network.shuffle.{ShuffleClient, BlockFetchingListener} +import org.apache.spark.storage.{BlockManagerId, BlockId, StorageLevel} -abstract class BlockTransferService { +private[spark] +abstract class BlockTransferService extends ShuffleClient with Closeable with Logging { /** * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch @@ -34,7 +40,7 @@ abstract class BlockTransferService { /** * Tear down the transfer service. */ - def stop(): Unit + def close(): Unit /** * Port number the service is listening on, available only after [[init]] is invoked. @@ -50,17 +56,15 @@ abstract class BlockTransferService { * Fetch a sequence of blocks from a remote node asynchronously, * available only after [[init]] is invoked. * - * Note that [[BlockFetchingListener.onBlockFetchSuccess]] is called once per block, - * while [[BlockFetchingListener.onBlockFetchFailure]] is called once per failure (not per block). - * * Note that this API takes a sequence so the implementation can batch requests, and does not * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as * the data of a block is fetched, rather than waiting for all blocks to be fetched. */ - def fetchBlocks( - hostName: String, + override def fetchBlocks( + host: String, port: Int, - blockIds: Seq[String], + execId: String, + blockIds: Array[String], listener: BlockFetchingListener): Unit /** @@ -69,7 +73,8 @@ abstract class BlockTransferService { def uploadBlock( hostname: String, port: Int, - blockId: String, + execId: String, + blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel): Future[Unit] @@ -78,40 +83,23 @@ abstract class BlockTransferService { * * It is also only available after [[init]] is invoked. */ - def fetchBlockSync(hostName: String, port: Int, blockId: String): ManagedBuffer = { + def fetchBlockSync(host: String, port: Int, execId: String, blockId: String): ManagedBuffer = { // A monitor for the thread to wait on. 
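The fetchBlockSync body that follows replaces the hand-rolled monitor with a Promise completed from the fetch callback. A standalone sketch of that pattern; the callback trait and fetchAsync helper here are stand-ins, not Spark's BlockFetchingListener API:

import scala.concurrent.{Await, Promise}
import scala.concurrent.duration.Duration

object CallbackToBlockingSketch {
  trait FetchCallback {
    def onSuccess(data: Array[Byte]): Unit
    def onFailure(t: Throwable): Unit
  }

  // Stand-in for an asynchronous fetch API that reports back through a callback.
  def fetchAsync(id: String, cb: FetchCallback): Unit =
    new Thread(new Runnable {
      def run(): Unit = cb.onSuccess(s"payload-for-$id".getBytes("UTF-8"))
    }).start()

  // Blocking wrapper: complete a Promise from the callback and await its Future,
  // the same shape as the fetchBlockSync implementation just below.
  def fetchSync(id: String): Array[Byte] = {
    val result = Promise[Array[Byte]]()
    fetchAsync(id, new FetchCallback {
      def onSuccess(data: Array[Byte]): Unit = result.success(data)
      def onFailure(t: Throwable): Unit = result.failure(t)
    })
    Await.result(result.future, Duration.Inf)
  }

  def main(args: Array[String]): Unit =
    println(new String(fetchSync("block_0"), "UTF-8"))
}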
- val lock = new Object - @volatile var result: Either[ManagedBuffer, Throwable] = null - fetchBlocks(hostName, port, Seq(blockId), new BlockFetchingListener { - override def onBlockFetchFailure(exception: Throwable): Unit = { - lock.synchronized { - result = Right(exception) - lock.notify() + val result = Promise[ManagedBuffer]() + fetchBlocks(host, port, execId, Array(blockId), + new BlockFetchingListener { + override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = { + result.failure(exception) } - } - override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - lock.synchronized { - result = Left(data) - lock.notify() - } - } - }) - - // Sleep until result is no longer null - lock.synchronized { - while (result == null) { - try { - lock.wait() - } catch { - case e: InterruptedException => + override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { + val ret = ByteBuffer.allocate(data.size.toInt) + ret.put(data.nioByteBuffer()) + ret.flip() + result.success(new NioManagedBuffer(ret)) } - } - } + }) - result match { - case Left(data) => data - case Right(e) => throw e - } + Await.result(result.future, Duration.Inf) } /** @@ -123,9 +111,10 @@ abstract class BlockTransferService { def uploadBlockSync( hostname: String, port: Int, - blockId: String, + execId: String, + blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel): Unit = { - Await.result(uploadBlock(hostname, port, blockId, blockData, level), Duration.Inf) + Await.result(uploadBlock(hostname, port, execId, blockId, blockData, level), Duration.Inf) } } diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala deleted file mode 100644 index a4409181ec907..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network - -import java.io._ -import java.nio.ByteBuffer -import java.nio.channels.FileChannel -import java.nio.channels.FileChannel.MapMode - -import scala.util.Try - -import com.google.common.io.ByteStreams -import io.netty.buffer.{ByteBufInputStream, ByteBuf} - -import org.apache.spark.util.{ByteBufferInputStream, Utils} - - -/** - * This interface provides an immutable view for data in the form of bytes. 
The implementation - * should specify how the data is provided: - * - * - FileSegmentManagedBuffer: data backed by part of a file - * - NioByteBufferManagedBuffer: data backed by a NIO ByteBuffer - * - NettyByteBufManagedBuffer: data backed by a Netty ByteBuf - */ -sealed abstract class ManagedBuffer { - // Note that all the methods are defined with parenthesis because their implementations can - // have side effects (io operations). - - /** Number of bytes of the data. */ - def size: Long - - /** - * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the - * returned ByteBuffer should not affect the content of this buffer. - */ - def nioByteBuffer(): ByteBuffer - - /** - * Exposes this buffer's data as an InputStream. The underlying implementation does not - * necessarily check for the length of bytes read, so the caller is responsible for making sure - * it does not go over the limit. - */ - def inputStream(): InputStream -} - - -/** - * A [[ManagedBuffer]] backed by a segment in a file - */ -final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long) - extends ManagedBuffer { - - override def size: Long = length - - override def nioByteBuffer(): ByteBuffer = { - var channel: FileChannel = null - try { - channel = new RandomAccessFile(file, "r").getChannel - channel.map(MapMode.READ_ONLY, offset, length) - } catch { - case e: IOException => - Try(channel.size).toOption match { - case Some(fileLen) => - throw new IOException(s"Error in reading $this (actual file length $fileLen)", e) - case None => - throw new IOException(s"Error in opening $this", e) - } - } finally { - if (channel != null) { - Utils.tryLog(channel.close()) - } - } - } - - override def inputStream(): InputStream = { - var is: FileInputStream = null - try { - is = new FileInputStream(file) - is.skip(offset) - ByteStreams.limit(is, length) - } catch { - case e: IOException => - if (is != null) { - Utils.tryLog(is.close()) - } - Try(file.length).toOption match { - case Some(fileLen) => - throw new IOException(s"Error in reading $this (actual file length $fileLen)", e) - case None => - throw new IOException(s"Error in opening $this", e) - } - case e: Throwable => - if (is != null) { - Utils.tryLog(is.close()) - } - throw e - } - } - - override def toString: String = s"${getClass.getName}($file, $offset, $length)" -} - - -/** - * A [[ManagedBuffer]] backed by [[java.nio.ByteBuffer]]. - */ -final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer { - - override def size: Long = buf.remaining() - - override def nioByteBuffer() = buf.duplicate() - - override def inputStream() = new ByteBufferInputStream(buf) -} - - -/** - * A [[ManagedBuffer]] backed by a Netty [[ByteBuf]]. - */ -final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer { - - override def size: Long = buf.readableBytes() - - override def nioByteBuffer() = buf.nioBuffer() - - override def inputStream() = new ByteBufInputStream(buf) - - // TODO(rxin): Promote this to top level ManagedBuffer interface and add documentation for it. 
- def release(): Unit = buf.release() -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala new file mode 100644 index 0000000000000..b089da8596e2b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.nio.ByteBuffer + +import scala.collection.JavaConversions._ + +import org.apache.spark.Logging +import org.apache.spark.network.BlockDataManager +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} +import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager} +import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock} +import org.apache.spark.serializer.Serializer +import org.apache.spark.storage.{BlockId, StorageLevel} + +/** + * Serves requests to open blocks by simply registering one chunk per block requested. + * Handles opening and uploading arbitrary BlockManager blocks. + * + * Opened blocks are registered with the "one-for-one" strategy, meaning each Transport-layer Chunk + * is equivalent to one Spark-level shuffle block. + */ +class NettyBlockRpcServer( + serializer: Serializer, + blockManager: BlockDataManager) + extends RpcHandler with Logging { + + private val streamManager = new OneForOneStreamManager() + + override def receive( + client: TransportClient, + messageBytes: Array[Byte], + responseContext: RpcResponseCallback): Unit = { + val message = BlockTransferMessage.Decoder.fromByteArray(messageBytes) + logTrace(s"Received request: $message") + + message match { + case openBlocks: OpenBlocks => + val blocks: Seq[ManagedBuffer] = + openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) + val streamId = streamManager.registerStream(blocks.iterator) + logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") + responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray) + + case uploadBlock: UploadBlock => + // StorageLevel is serialized as bytes using our JavaSerializer. 
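In the UploadBlock branch that continues just below, NettyBlockRpcServer rebuilds the StorageLevel from bytes produced by Spark's JavaSerializer. A minimal round-trip sketch of that step, assuming only a Spark classpath:

import java.nio.ByteBuffer
import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.StorageLevel

object StorageLevelRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val ser = new JavaSerializer(new SparkConf()).newInstance()
    // Sender side: StorageLevel -> bytes (the role played by uploadBlock.metadata).
    val levelBytes: ByteBuffer = ser.serialize(StorageLevel.MEMORY_AND_DISK)
    // Receiver side: bytes -> StorageLevel, as in the UploadBlock case below.
    val level = ser.deserialize[StorageLevel](levelBytes)
    println(level)
  }
}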
+ val level: StorageLevel = + serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata)) + val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData)) + blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level) + responseContext.onSuccess(new Array[Byte](0)) + } + } + + override def getStreamManager(): StreamManager = streamManager +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala new file mode 100644 index 0000000000000..0027cbb0ff1fb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.  You may obtain a copy of the License at + * + *    http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import scala.collection.JavaConversions._ +import scala.concurrent.{Future, Promise} + +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.network._ +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory} +import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap} +import org.apache.spark.network.server._ +import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher} +import org.apache.spark.network.shuffle.protocol.UploadBlock +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.storage.{BlockId, StorageLevel} +import org.apache.spark.util.Utils + +/** + * A BlockTransferService that uses Netty to fetch a set of blocks at a time. + */ +class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager, numCores: Int) + extends BlockTransferService { + + // TODO: Don't use Java serialization, use a more cross-version compatible serialization format. 
+ private val serializer = new JavaSerializer(conf) + private val authEnabled = securityManager.isAuthenticationEnabled() + private val transportConf = SparkTransportConf.fromSparkConf(conf, numCores) + + private[this] var transportContext: TransportContext = _ + private[this] var server: TransportServer = _ + private[this] var clientFactory: TransportClientFactory = _ + private[this] var appId: String = _ + + override def init(blockDataManager: BlockDataManager): Unit = { + val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = { + val nettyRpcHandler = new NettyBlockRpcServer(serializer, blockDataManager) + if (!authEnabled) { + (nettyRpcHandler, None) + } else { + (new SaslRpcHandler(nettyRpcHandler, securityManager), + Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager))) + } + } + transportContext = new TransportContext(transportConf, rpcHandler) + clientFactory = transportContext.createClientFactory(bootstrap.toList) + server = transportContext.createServer() + appId = conf.getAppId + logInfo("Server created on " + server.getPort) + } + + override def fetchBlocks( + host: String, + port: Int, + execId: String, + blockIds: Array[String], + listener: BlockFetchingListener): Unit = { + logTrace(s"Fetch blocks from $host:$port (executor id $execId)") + try { + val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { + override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { + val client = clientFactory.createClient(host, port) + new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start() + } + } + + val maxRetries = transportConf.maxIORetries() + if (maxRetries > 0) { + // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's + // a bug in this code. We should remove the if statement once we're sure of the stability. + new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start() + } else { + blockFetchStarter.createAndStart(blockIds, listener) + } + } catch { + case e: Exception => + logError("Exception while beginning fetchBlocks", e) + blockIds.foreach(listener.onBlockFetchFailure(_, e)) + } + } + + override def hostName: String = Utils.localHostName() + + override def port: Int = server.getPort + + override def uploadBlock( + hostname: String, + port: Int, + execId: String, + blockId: BlockId, + blockData: ManagedBuffer, + level: StorageLevel): Future[Unit] = { + val result = Promise[Unit]() + val client = clientFactory.createClient(hostname, port) + + // StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded + // using our binary protocol. + val levelBytes = serializer.newInstance().serialize(level).array() + + // Convert or copy nio buffer into array in order to serialize it. 
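The step just below converts a possibly direct NIO buffer into a byte array before serializing it. A self-contained sketch of that conversion, mirroring the hasArray check that follows:

import java.nio.ByteBuffer

object ByteBufferToArraySketch {
  def toArray(buf: ByteBuffer): Array[Byte] =
    if (buf.hasArray) {
      buf.array()                          // heap buffer: reuse the backing array
    } else {
      val data = new Array[Byte](buf.remaining())
      buf.get(data)                        // direct buffer: copy the remaining bytes out
      data
    }

  def main(args: Array[String]): Unit = {
    val heap = ByteBuffer.wrap("hello".getBytes("UTF-8"))
    val direct = ByteBuffer.allocateDirect(5)
    direct.put("world".getBytes("UTF-8"))
    direct.flip()
    println(new String(toArray(heap), "UTF-8") + " " + new String(toArray(direct), "UTF-8"))
  }
}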
+ val nioBuffer = blockData.nioByteBuffer() + val array = if (nioBuffer.hasArray) { + nioBuffer.array() + } else { + val data = new Array[Byte](nioBuffer.remaining()) + nioBuffer.get(data) + data + } + + client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteArray, + new RpcResponseCallback { + override def onSuccess(response: Array[Byte]): Unit = { + logTrace(s"Successfully uploaded block $blockId") + result.success() + } + override def onFailure(e: Throwable): Unit = { + logError(s"Error while uploading block $blockId", e) + result.failure(e) + } + }) + + result.future + } + + override def close(): Unit = { + server.close() + clientFactory.close() + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala deleted file mode 100644 index b5870152c5a64..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import org.apache.spark.SparkConf - -/** - * A central location that tracks all the settings we exposed to users. - */ -private[spark] -class NettyConfig(conf: SparkConf) { - - /** Port the server listens on. Default to a random port. */ - private[netty] val serverPort = conf.getInt("spark.shuffle.io.port", 0) - - /** IO mode: nio, oio, epoll, or auto (try epoll first and then nio). */ - private[netty] val ioMode = conf.get("spark.shuffle.io.mode", "nio").toLowerCase - - /** Connect timeout in secs. Default 60 secs. */ - private[netty] val connectTimeoutMs = conf.getInt("spark.shuffle.io.connectionTimeout", 60) * 1000 - - /** - * Percentage of the desired amount of time spent for I/O in the child event loops. - * Only applicable in nio and epoll. - */ - private[netty] val ioRatio = conf.getInt("spark.shuffle.io.netty.ioRatio", 80) - - /** Requested maximum length of the queue of incoming connections. */ - private[netty] val backLog: Option[Int] = conf.getOption("spark.shuffle.io.backLog").map(_.toInt) - - /** - * Receive buffer size (SO_RCVBUF). - * Note: the optimal size for receive buffer and send buffer should be - * latency * network_bandwidth. - * Assuming latency = 1ms, network_bandwidth = 10Gbps - * buffer size should be ~ 1.25MB - */ - private[netty] val receiveBuf: Option[Int] = - conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt) - - /** Send buffer size (SO_SNDBUF). 
*/ - private[netty] val sendBuf: Option[Int] = - conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt) -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala b/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala deleted file mode 100644 index 0d7695072a7b1..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty - -import org.apache.spark.storage.{BlockId, FileSegment} - -trait PathResolver { - /** Get the file segment in which the given block resides. */ - def getBlockLocation(blockId: BlockId): FileSegment -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala new file mode 100644 index 0000000000000..cef203006d685 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import org.apache.spark.SparkConf +import org.apache.spark.network.util.{TransportConf, ConfigProvider} + +/** + * Provides a utility for transforming from a SparkConf inside a Spark JVM (e.g., Executor, + * Driver, or a standalone shuffle service) into a TransportConf with details on our environment + * like the number of cores that are allocated to this JVM. + */ +object SparkTransportConf { + /** + * Specifies an upper bound on the number of Netty threads that Spark requires by default. + * In practice, only 2-4 cores should be required to transfer roughly 10 Gb/s, and each core + * that we use will have an initial overhead of roughly 32 MB of off-heap memory, which comes + * at a premium. + * + * Thus, this value should still retain maximum throughput and reduce wasted off-heap memory + * allocation. 
It can be overridden by setting the number of serverThreads and clientThreads + * manually in Spark's configuration. + */ + private val MAX_DEFAULT_NETTY_THREADS = 8 + + /** + * Utility for creating a [[TransportConf]] from a [[SparkConf]]. + * @param numUsableCores if nonzero, this will restrict the server and client threads to only + * use the given number of cores, rather than all of the machine's cores. + * This restriction will only occur if these properties are not already set. + */ + def fromSparkConf(_conf: SparkConf, numUsableCores: Int = 0): TransportConf = { + val conf = _conf.clone + + // Specify thread configuration based on our JVM's allocation of cores (rather than necessarily + // assuming we have all the machine's cores). + // NB: Only set if serverThreads/clientThreads not already set. + val numThreads = defaultNumThreads(numUsableCores) + conf.set("spark.shuffle.io.serverThreads", + conf.get("spark.shuffle.io.serverThreads", numThreads.toString)) + conf.set("spark.shuffle.io.clientThreads", + conf.get("spark.shuffle.io.clientThreads", numThreads.toString)) + + new TransportConf(new ConfigProvider { + override def get(name: String): String = conf.get(name) + }) + } + + /** + * Returns the default number of threads for both the Netty client and server thread pools. + * If numUsableCores is 0, we will use the Runtime to get an approximate number of available cores. + */ + private def defaultNumThreads(numUsableCores: Int): Int = { + val availableCores = + if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors() + math.min(availableCores, MAX_DEFAULT_NETTY_THREADS) + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala deleted file mode 100644 index e28219dd7745b..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements.  See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License.  You may obtain a copy of the License at - * - *    http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
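The defaultNumThreads helper above is the whole sizing policy: take the explicit core count when one is given, otherwise ask the JVM, and cap the result at MAX_DEFAULT_NETTY_THREADS. A tiny standalone restatement that can be exercised directly:

object NettyThreadDefaultSketch {
  private val MaxDefaultNettyThreads = 8   // same cap as MAX_DEFAULT_NETTY_THREADS above

  def defaultNumThreads(numUsableCores: Int): Int = {
    val availableCores =
      if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
    math.min(availableCores, MaxDefaultNettyThreads)
  }

  def main(args: Array[String]): Unit = {
    println(defaultNumThreads(2))    // 2: an explicit allocation below the cap is used as-is
    println(defaultNumThreads(32))   // 8: capped
    println(defaultNumThreads(0))    // machine-dependent, but never more than 8
  }
}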
- */ - -package org.apache.spark.network.netty.client - -import java.util.EventListener - - -trait BlockClientListener extends EventListener { - - def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit - - def onFetchFailure(blockId: String, errorMsg: String): Unit - -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala deleted file mode 100644 index 5aea7ba2f3673..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala +++ /dev/null @@ -1,132 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import java.util.concurrent.TimeoutException - -import io.netty.bootstrap.Bootstrap -import io.netty.buffer.PooledByteBufAllocator -import io.netty.channel.socket.SocketChannel -import io.netty.channel.{ChannelFutureListener, ChannelFuture, ChannelInitializer, ChannelOption} -import io.netty.handler.codec.LengthFieldBasedFrameDecoder -import io.netty.handler.codec.string.StringEncoder -import io.netty.util.CharsetUtil - -import org.apache.spark.Logging - -/** - * Client for fetching data blocks from [[org.apache.spark.network.netty.server.BlockServer]]. - * Use [[BlockFetchingClientFactory]] to instantiate this client. - * - * The constructor blocks until a connection is successfully established. - * - * See [[org.apache.spark.network.netty.server.BlockServer]] for client/server protocol. - * - * Concurrency: thread safe and can be called from multiple threads. - */ -@throws[TimeoutException] -private[spark] -class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, port: Int) - extends Logging { - - private val handler = new BlockFetchingClientHandler - - /** Netty Bootstrap for creating the TCP connection. 
*/ - private val bootstrap: Bootstrap = { - val b = new Bootstrap - b.group(factory.workerGroup) - .channel(factory.socketChannelClass) - // Use pooled buffers to reduce temporary buffer allocation - .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) - // Disable Nagle's Algorithm since we don't want packets to wait - .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE) - .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE) - .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, factory.conf.connectTimeoutMs) - - b.handler(new ChannelInitializer[SocketChannel] { - override def initChannel(ch: SocketChannel): Unit = { - ch.pipeline - .addLast("encoder", new StringEncoder(CharsetUtil.UTF_8)) - // maxFrameLength = 2G, lengthFieldOffset = 0, lengthFieldLength = 4 - .addLast("framedLengthDecoder", new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 4)) - .addLast("handler", handler) - } - }) - b - } - - /** Netty ChannelFuture for the connection. */ - private val cf: ChannelFuture = bootstrap.connect(hostname, port) - if (!cf.awaitUninterruptibly(factory.conf.connectTimeoutMs)) { - throw new TimeoutException( - s"Connecting to $hostname:$port timed out (${factory.conf.connectTimeoutMs} ms)") - } - - /** - * Ask the remote server for a sequence of blocks, and execute the callback. - * - * Note that this is asynchronous and returns immediately. Upstream caller should throttle the - * rate of fetching; otherwise we could run out of memory. - * - * @param blockIds sequence of block ids to fetch. - * @param listener callback to fire on fetch success / failure. - */ - def fetchBlocks(blockIds: Seq[String], listener: BlockClientListener): Unit = { - // It's best to limit the number of "write" calls since it needs to traverse the whole pipeline. - // It's also best to limit the number of "flush" calls since it requires system calls. - // Let's concatenate the string and then call writeAndFlush once. - // This is also why this implementation might be more efficient than multiple, separate - // fetch block calls. - var startTime: Long = 0 - logTrace { - startTime = System.nanoTime - s"Sending request $blockIds to $hostname:$port" - } - - blockIds.foreach { blockId => - handler.addRequest(blockId, listener) - } - - val writeFuture = cf.channel().writeAndFlush(blockIds.mkString("\n") + "\n") - writeFuture.addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture): Unit = { - if (future.isSuccess) { - logTrace { - val timeTaken = (System.nanoTime - startTime).toDouble / 1000000 - s"Sending request $blockIds to $hostname:$port took $timeTaken ms" - } - } else { - // Fail all blocks. 
- val errorMsg = - s"Failed to send request $blockIds to $hostname:$port: ${future.cause.getMessage}" - logError(errorMsg, future.cause) - blockIds.foreach { blockId => - listener.onFetchFailure(blockId, errorMsg) - handler.removeRequest(blockId) - } - } - } - }) - } - - def waitForClose(): Unit = { - cf.channel().closeFuture().sync() - } - - def close(): Unit = cf.channel().close() -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala deleted file mode 100644 index 2b28402c52b49..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import io.netty.channel.epoll.{EpollEventLoopGroup, EpollSocketChannel} -import io.netty.channel.nio.NioEventLoopGroup -import io.netty.channel.oio.OioEventLoopGroup -import io.netty.channel.socket.nio.NioSocketChannel -import io.netty.channel.socket.oio.OioSocketChannel -import io.netty.channel.{EventLoopGroup, Channel} - -import org.apache.spark.SparkConf -import org.apache.spark.network.netty.NettyConfig -import org.apache.spark.util.Utils - -/** - * Factory for creating [[BlockFetchingClient]] by using createClient. This factory reuses - * the worker thread pool for Netty. - * - * Concurrency: createClient is safe to be called from multiple threads concurrently. - */ -private[spark] -class BlockFetchingClientFactory(val conf: NettyConfig) { - - def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf)) - - /** A thread factory so the threads are named (for debugging). */ - val threadFactory = Utils.namedThreadFactory("spark-shuffle-client") - - /** The following two are instantiated by the [[init]] method, depending ioMode. */ - var socketChannelClass: Class[_ <: Channel] = _ - var workerGroup: EventLoopGroup = _ - - init() - - /** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */ - private def init(): Unit = { - def initOio(): Unit = { - socketChannelClass = classOf[OioSocketChannel] - workerGroup = new OioEventLoopGroup(0, threadFactory) - } - def initNio(): Unit = { - socketChannelClass = classOf[NioSocketChannel] - workerGroup = new NioEventLoopGroup(0, threadFactory) - } - def initEpoll(): Unit = { - socketChannelClass = classOf[EpollSocketChannel] - workerGroup = new EpollEventLoopGroup(0, threadFactory) - } - - conf.ioMode match { - case "nio" => initNio() - case "oio" => initOio() - case "epoll" => initEpoll() - case "auto" => - // For auto mode, first try epoll (only available on Linux), then nio. 
- try { - initEpoll() - } catch { - // TODO: Should we log the throwable? But that always happen on non-Linux systems. - // Perhaps the right thing to do is to check whether the system is Linux, and then only - // call initEpoll on Linux. - case e: Throwable => initNio() - } - } - } - - /** - * Create a new BlockFetchingClient connecting to the given remote host / port. - * - * This blocks until a connection is successfully established. - * - * Concurrency: This method is safe to call from multiple threads. - */ - def createClient(remoteHost: String, remotePort: Int): BlockFetchingClient = { - new BlockFetchingClient(this, remoteHost, remotePort) - } - - def stop(): Unit = { - if (workerGroup != null) { - workerGroup.shutdownGracefully() - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala deleted file mode 100644 index 83265b164299d..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala +++ /dev/null @@ -1,103 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import io.netty.buffer.ByteBuf -import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} - -import org.apache.spark.Logging - - -/** - * Handler that processes server responses. It uses the protocol documented in - * [[org.apache.spark.network.netty.server.BlockServer]]. - * - * Concurrency: thread safe and can be called from multiple threads. - */ -private[client] -class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] with Logging { - - /** Tracks the list of outstanding requests and their listeners on success/failure. 
*/ - private val outstandingRequests = java.util.Collections.synchronizedMap { - new java.util.HashMap[String, BlockClientListener] - } - - def addRequest(blockId: String, listener: BlockClientListener): Unit = { - outstandingRequests.put(blockId, listener) - } - - def removeRequest(blockId: String): Unit = { - outstandingRequests.remove(blockId) - } - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - val errorMsg = s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}" - logError(errorMsg, cause) - - // Fire the failure callback for all outstanding blocks - outstandingRequests.synchronized { - val iter = outstandingRequests.entrySet().iterator() - while (iter.hasNext) { - val entry = iter.next() - entry.getValue.onFetchFailure(entry.getKey, errorMsg) - } - outstandingRequests.clear() - } - - ctx.close() - } - - override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) { - val totalLen = in.readInt() - val blockIdLen = in.readInt() - val blockIdBytes = new Array[Byte](math.abs(blockIdLen)) - in.readBytes(blockIdBytes) - val blockId = new String(blockIdBytes) - val blockSize = totalLen - math.abs(blockIdLen) - 4 - - def server = ctx.channel.remoteAddress.toString - - // blockIdLen is negative when it is an error message. - if (blockIdLen < 0) { - val errorMessageBytes = new Array[Byte](blockSize) - in.readBytes(errorMessageBytes) - val errorMsg = new String(errorMessageBytes) - logTrace(s"Received block $blockId ($blockSize B) with error $errorMsg from $server") - - val listener = outstandingRequests.get(blockId) - if (listener == null) { - // Ignore callback - logWarning(s"Got a response for block $blockId but it is not in our outstanding requests") - } else { - outstandingRequests.remove(blockId) - listener.onFetchFailure(blockId, errorMsg) - } - } else { - logTrace(s"Received block $blockId ($blockSize B) from $server") - - val listener = outstandingRequests.get(blockId) - if (listener == null) { - // Ignore callback - logWarning(s"Got a response for block $blockId but it is not in our outstanding requests") - } else { - outstandingRequests.remove(blockId) - listener.onFetchSuccess(blockId, new ReferenceCountedBuffer(in)) - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala b/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala deleted file mode 100644 index 9740ee64d1f2d..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -/** - * A simple iterator that lazily initializes the underlying iterator. 
- * - * The use case is that sometimes we might have many iterators open at the same time, and each of - * the iterator might initialize its own buffer (e.g. decompression buffer, deserialization buffer). - * This could lead to too many buffers open. If this iterator is used, we lazily initialize those - * buffers. - */ -private[spark] -class LazyInitIterator(createIterator: => Iterator[Any]) extends Iterator[Any] { - - lazy val proxy = createIterator - - override def hasNext: Boolean = { - val gotNext = proxy.hasNext - if (!gotNext) { - close() - } - gotNext - } - - override def next(): Any = proxy.next() - - def close(): Unit = Unit -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala b/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala deleted file mode 100644 index ea1abf5eccc26..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.client - -import java.io.InputStream -import java.nio.ByteBuffer - -import io.netty.buffer.{ByteBuf, ByteBufInputStream} - - -/** - * A buffer abstraction based on Netty's ByteBuf so we don't expose Netty. - * This is a Scala value class. - * - * The buffer's life cycle is NOT managed by the JVM, and thus requiring explicit declaration of - * reference by the retain method and release method. - */ -private[spark] -class ReferenceCountedBuffer(val underlying: ByteBuf) extends AnyVal { - - /** Return the nio ByteBuffer view of the underlying buffer. */ - def byteBuffer(): ByteBuffer = underlying.nioBuffer - - /** Creates a new input stream that starts from the current position of the buffer. */ - def inputStream(): InputStream = new ByteBufInputStream(underlying) - - /** Increment the reference counter by one. */ - def retain(): Unit = underlying.retain() - - /** Decrement the reference counter by one and release the buffer if the ref count is 0. */ - def release(): Unit = underlying.release() -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala deleted file mode 100644 index 162e9cc6828d4..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -/** - * Header describing a block. This is used only in the server pipeline. - * - * [[BlockServerHandler]] creates this, and [[BlockHeaderEncoder]] encodes it. - * - * @param blockSize length of the block content, excluding the length itself. - * If positive, this is the header for a block (not part of the header). - * If negative, this is the header and content for an error message. - * @param blockId block id - * @param error some error message from reading the block - */ -private[server] -class BlockHeader(val blockSize: Int, val blockId: String, val error: Option[String] = None) diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala deleted file mode 100644 index 8e4dda4ef8595..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import io.netty.buffer.ByteBuf -import io.netty.channel.ChannelHandlerContext -import io.netty.handler.codec.MessageToByteEncoder - -/** - * A simple encoder for BlockHeader. See [[BlockServer]] for the server to client protocol. 
- */ -private[server] -class BlockHeaderEncoder extends MessageToByteEncoder[BlockHeader] { - override def encode(ctx: ChannelHandlerContext, msg: BlockHeader, out: ByteBuf): Unit = { - // message = message length (4 bytes) + block id length (4 bytes) + block id + block data - // message length = block id length (4 bytes) + size of block id + size of block data - val blockIdBytes = msg.blockId.getBytes - msg.error match { - case Some(errorMsg) => - val errorBytes = errorMsg.getBytes - out.writeInt(4 + blockIdBytes.length + errorBytes.size) - out.writeInt(-blockIdBytes.length) // use negative block id length to represent errors - out.writeBytes(blockIdBytes) // next is blockId itself - out.writeBytes(errorBytes) // error message - case None => - out.writeInt(4 + blockIdBytes.length + msg.blockSize) - out.writeInt(blockIdBytes.length) // First 4 bytes is blockId length - out.writeBytes(blockIdBytes) // next is blockId itself - // msg of size blockSize will be written by ServerHandler - } - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala deleted file mode 100644 index 7b2f9a8d4dfd0..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import java.net.InetSocketAddress - -import io.netty.bootstrap.ServerBootstrap -import io.netty.buffer.PooledByteBufAllocator -import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption} -import io.netty.channel.epoll.{EpollEventLoopGroup, EpollServerSocketChannel} -import io.netty.channel.nio.NioEventLoopGroup -import io.netty.channel.oio.OioEventLoopGroup -import io.netty.channel.socket.SocketChannel -import io.netty.channel.socket.nio.NioServerSocketChannel -import io.netty.channel.socket.oio.OioServerSocketChannel -import io.netty.handler.codec.LineBasedFrameDecoder -import io.netty.handler.codec.string.StringDecoder -import io.netty.util.CharsetUtil - -import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.network.netty.NettyConfig -import org.apache.spark.storage.BlockDataProvider -import org.apache.spark.util.Utils - - -/** - * Server for serving Spark data blocks. - * This should be used together with [[org.apache.spark.network.netty.client.BlockFetchingClient]]. - * - * Protocol for requesting blocks (client to server): - * One block id per line, e.g. to request 3 blocks: "block1\nblock2\nblock3\n" - * - * Protocol for sending blocks (server to client): - * frame-length (4 bytes), block-id-length (4 bytes), block-id, block-data. 
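The encode method above is the server half of the same framing: the first int is the frame length excluding itself, and a negated block-id length flags an error payload. A rough self-contained sketch of that layout using a plain ByteBuffer instead of Netty's ByteBuf (helper name illustrative; the real encoder leaves block data to be written separately, e.g. as a FileRegion):

import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets.UTF_8

// Left = error message, Right = block data.
def encodeFrame(blockId: String, payload: Either[String, Array[Byte]]): ByteBuffer = {
  val idBytes = blockId.getBytes(UTF_8)
  val body = payload.fold(_.getBytes(UTF_8), identity)
  val buf = ByteBuffer.allocate(4 + 4 + idBytes.length + body.length)
  buf.putInt(4 + idBytes.length + body.length)                        // frame length, excluding itself
  buf.putInt(if (payload.isLeft) -idBytes.length else idBytes.length) // sign encodes error vs data
  buf.put(idBytes)
  buf.put(body)
  buf.flip()
  buf
}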
- * - * frame-length should not include the length of itself. - * If block-id-length is negative, then this is an error message rather than block-data. The real - * length is the absolute value of the frame-length. - * - */ -private[spark] -class BlockServer(conf: NettyConfig, dataProvider: BlockDataProvider) extends Logging { - - def this(sparkConf: SparkConf, dataProvider: BlockDataProvider) = { - this(new NettyConfig(sparkConf), dataProvider) - } - - def port: Int = _port - - def hostName: String = _hostName - - private var _port: Int = conf.serverPort - private var _hostName: String = "" - private var bootstrap: ServerBootstrap = _ - private var channelFuture: ChannelFuture = _ - - init() - - /** Initialize the server. */ - private def init(): Unit = { - bootstrap = new ServerBootstrap - val bossThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-boss") - val workerThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-worker") - - // Use only one thread to accept connections, and 2 * num_cores for worker. - def initNio(): Unit = { - val bossGroup = new NioEventLoopGroup(1, bossThreadFactory) - val workerGroup = new NioEventLoopGroup(0, workerThreadFactory) - workerGroup.setIoRatio(conf.ioRatio) - bootstrap.group(bossGroup, workerGroup).channel(classOf[NioServerSocketChannel]) - } - def initOio(): Unit = { - val bossGroup = new OioEventLoopGroup(1, bossThreadFactory) - val workerGroup = new OioEventLoopGroup(0, workerThreadFactory) - bootstrap.group(bossGroup, workerGroup).channel(classOf[OioServerSocketChannel]) - } - def initEpoll(): Unit = { - val bossGroup = new EpollEventLoopGroup(1, bossThreadFactory) - val workerGroup = new EpollEventLoopGroup(0, workerThreadFactory) - workerGroup.setIoRatio(conf.ioRatio) - bootstrap.group(bossGroup, workerGroup).channel(classOf[EpollServerSocketChannel]) - } - - conf.ioMode match { - case "nio" => initNio() - case "oio" => initOio() - case "epoll" => initEpoll() - case "auto" => - // For auto mode, first try epoll (only available on Linux), then nio. - try { - initEpoll() - } catch { - // TODO: Should we log the throwable? But that always happen on non-Linux systems. - // Perhaps the right thing to do is to check whether the system is Linux, and then only - // call initEpoll on Linux. - case e: Throwable => initNio() - } - } - - // Use pooled buffers to reduce temporary buffer allocation - bootstrap.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) - bootstrap.childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT) - - // Various (advanced) user-configured settings. 
- conf.backLog.foreach { backLog => - bootstrap.option[java.lang.Integer](ChannelOption.SO_BACKLOG, backLog) - } - conf.receiveBuf.foreach { receiveBuf => - bootstrap.option[java.lang.Integer](ChannelOption.SO_RCVBUF, receiveBuf) - } - conf.sendBuf.foreach { sendBuf => - bootstrap.option[java.lang.Integer](ChannelOption.SO_SNDBUF, sendBuf) - } - - bootstrap.childHandler(new ChannelInitializer[SocketChannel] { - override def initChannel(ch: SocketChannel): Unit = { - ch.pipeline - .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024 - .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8)) - .addLast("blockHeaderEncoder", new BlockHeaderEncoder) - .addLast("handler", new BlockServerHandler(dataProvider)) - } - }) - - channelFuture = bootstrap.bind(new InetSocketAddress(_port)) - channelFuture.sync() - - val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress] - _port = addr.getPort - _hostName = addr.getHostName - } - - /** Shutdown the server. */ - def stop(): Unit = { - if (channelFuture != null) { - channelFuture.channel().close().awaitUninterruptibly() - channelFuture = null - } - if (bootstrap != null && bootstrap.group() != null) { - bootstrap.group().shutdownGracefully() - } - if (bootstrap != null && bootstrap.childGroup() != null) { - bootstrap.childGroup().shutdownGracefully() - } - bootstrap = null - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala deleted file mode 100644 index cc70bd0c5c477..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import io.netty.channel.ChannelInitializer -import io.netty.channel.socket.SocketChannel -import io.netty.handler.codec.LineBasedFrameDecoder -import io.netty.handler.codec.string.StringDecoder -import io.netty.util.CharsetUtil -import org.apache.spark.storage.BlockDataProvider - - -/** Channel initializer that sets up the pipeline for the BlockServer. 
*/ -private[netty] -class BlockServerChannelInitializer(dataProvider: BlockDataProvider) - extends ChannelInitializer[SocketChannel] { - - override def initChannel(ch: SocketChannel): Unit = { - ch.pipeline - .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024 - .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8)) - .addLast("blockHeaderEncoder", new BlockHeaderEncoder) - .addLast("handler", new BlockServerHandler(dataProvider)) - } -} diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala deleted file mode 100644 index 40dd5e5d1a2ac..0000000000000 --- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.netty.server - -import java.io.FileInputStream -import java.nio.ByteBuffer -import java.nio.channels.FileChannel - -import io.netty.buffer.Unpooled -import io.netty.channel._ - -import org.apache.spark.Logging -import org.apache.spark.storage.{FileSegment, BlockDataProvider} - - -/** - * A handler that processes requests from clients and writes block data back. - * - * The messages should have been processed by a LineBasedFrameDecoder and a StringDecoder first - * so channelRead0 is called once per line (i.e. per block id). - */ -private[server] -class BlockServerHandler(dataProvider: BlockDataProvider) - extends SimpleChannelInboundHandler[String] with Logging { - - override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { - logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause) - ctx.close() - } - - override def channelRead0(ctx: ChannelHandlerContext, blockId: String): Unit = { - def client = ctx.channel.remoteAddress.toString - - // A helper function to send error message back to the client. - def respondWithError(error: String): Unit = { - ctx.writeAndFlush(new BlockHeader(-1, blockId, Some(error))).addListener( - new ChannelFutureListener { - override def operationComplete(future: ChannelFuture) { - if (!future.isSuccess) { - // TODO: Maybe log the success case as well. - logError(s"Error sending error back to $client", future.cause) - ctx.close() - } - } - } - ) - } - - def writeFileSegment(segment: FileSegment): Unit = { - // Send error message back if the block is too large. Even though we are capable of sending - // large (2G+) blocks, the receiving end cannot handle it so let's fail fast. - // Once we fixed the receiving end to be able to process large blocks, this should be removed. - // Also make sure we update BlockHeaderEncoder to support length > 2G. 
- - // See [[BlockHeaderEncoder]] for the way length is encoded. - if (segment.length + blockId.length + 4 > Int.MaxValue) { - respondWithError(s"Block $blockId size ($segment.length) greater than 2G") - return - } - - var fileChannel: FileChannel = null - try { - fileChannel = new FileInputStream(segment.file).getChannel - } catch { - case e: Exception => - logError( - s"Error opening channel for $blockId in ${segment.file} for request from $client", e) - respondWithError(e.getMessage) - } - - // Found the block. Send it back. - if (fileChannel != null) { - // Write the header and block data. In the case of failures, the listener on the block data - // write should close the connection. - ctx.write(new BlockHeader(segment.length.toInt, blockId)) - - val region = new DefaultFileRegion(fileChannel, segment.offset, segment.length) - ctx.writeAndFlush(region).addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture) { - if (future.isSuccess) { - logTrace(s"Sent block $blockId (${segment.length} B) back to $client") - } else { - logError(s"Error sending block $blockId to $client; closing connection", future.cause) - ctx.close() - } - } - }) - } - } - - def writeByteBuffer(buf: ByteBuffer): Unit = { - ctx.write(new BlockHeader(buf.remaining, blockId)) - ctx.writeAndFlush(Unpooled.wrappedBuffer(buf)).addListener(new ChannelFutureListener { - override def operationComplete(future: ChannelFuture) { - if (future.isSuccess) { - logTrace(s"Sent block $blockId (${buf.remaining} B) back to $client") - } else { - logError(s"Error sending block $blockId to $client; closing connection", future.cause) - ctx.close() - } - } - }) - } - - logTrace(s"Received request from $client to fetch block $blockId") - - var blockData: Either[FileSegment, ByteBuffer] = null - - // First make sure we can find the block. If not, send error back to the user. 
- try { - blockData = dataProvider.getBlockData(blockId) - } catch { - case e: Exception => - logError(s"Error opening block $blockId for request from $client", e) - respondWithError(e.getMessage) - return - } - - blockData match { - case Left(segment) => writeFileSegment(segment) - case Right(buf) => writeByteBuffer(buf) - } - - } // end of channelRead0 -} diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala index f368209980f93..c2d9578be7ebb 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala @@ -20,11 +20,15 @@ package org.apache.spark.network.nio import java.net._ import java.nio._ import java.nio.channels._ +import java.util.concurrent.ConcurrentLinkedQueue import java.util.LinkedList -import org.apache.spark._ - +import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.util.control.NonFatal + +import org.apache.spark._ +import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} private[nio] abstract class Connection(val channel: SocketChannel, val selector: Selector, @@ -51,7 +55,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, @volatile private var closed = false var onCloseCallback: Connection => Unit = null - var onExceptionCallback: (Connection, Exception) => Unit = null + val onExceptionCallbacks = new ConcurrentLinkedQueue[(Connection, Throwable) => Unit] var onKeyInterestChangeCallback: (Connection, Int) => Unit = null val remoteAddress = getRemoteAddress() @@ -130,20 +134,24 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector, onCloseCallback = callback } - def onException(callback: (Connection, Exception) => Unit) { - onExceptionCallback = callback + def onException(callback: (Connection, Throwable) => Unit) { + onExceptionCallbacks.add(callback) } def onKeyInterestChange(callback: (Connection, Int) => Unit) { onKeyInterestChangeCallback = callback } - def callOnExceptionCallback(e: Exception) { - if (onExceptionCallback != null) { - onExceptionCallback(this, e) - } else { - logError("Error in connection to " + getRemoteConnectionManagerId() + - " and OnExceptionCallback not registered", e) + def callOnExceptionCallbacks(e: Throwable) { + onExceptionCallbacks foreach { + callback => + try { + callback(this, e) + } catch { + case NonFatal(e) => { + logWarning("Ignored error in onExceptionCallback", e) + } + } } } @@ -323,7 +331,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } catch { case e: Exception => { logError("Error connecting to " + address, e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) } } } @@ -348,7 +356,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } catch { case e: Exception => { logWarning("Error finishing connection to " + address, e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) } } true @@ -393,7 +401,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } catch { case e: Exception => { logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) close() return false } @@ -420,7 +428,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, case e: Exception => logError("Exception while reading SendingConnection 
to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) close() } @@ -577,7 +585,7 @@ private[spark] class ReceivingConnection( } catch { case e: Exception => { logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e) - callOnExceptionCallback(e) + callOnExceptionCallbacks(e) close() return false } diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 01cd27a907eea..df4b085d2251e 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -18,22 +18,28 @@ package org.apache.spark.network.nio import java.io.IOException +import java.lang.ref.WeakReference import java.net._ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ import java.util.concurrent.atomic.AtomicInteger import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit} -import java.util.{Timer, TimerTask} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue} import scala.concurrent.duration._ import scala.concurrent.{Await, ExecutionContext, Future, Promise} import scala.language.postfixOps +import com.google.common.base.Charsets.UTF_8 +import io.netty.util.{Timeout, TimerTask, HashedWheelTimer} + import org.apache.spark._ +import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer} import org.apache.spark.util.Utils +import scala.util.Try +import scala.util.control.NonFatal private[nio] class ConnectionManager( port: Int, @@ -51,19 +57,29 @@ private[nio] class ConnectionManager( class MessageStatus( val message: Message, val connectionManagerId: ConnectionManagerId, - completionHandler: MessageStatus => Unit) { + completionHandler: Try[Message] => Unit) { + + def success(ackMessage: Message) { + if (ackMessage == null) { + failure(new NullPointerException) + } + else { + completionHandler(scala.util.Success(ackMessage)) + } + } - /** This is non-None if message has been ack'd */ - var ackMessage: Option[Message] = None + def failWithoutAck() { + completionHandler(scala.util.Failure(new IOException("Failed without being ACK'd"))) + } - def markDone(ackMessage: Option[Message]) { - this.ackMessage = ackMessage - completionHandler(this) + def failure(e: Throwable) { + completionHandler(scala.util.Failure(e)) } } private val selector = SelectorProvider.provider.openSelector() - private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true) + private val ackTimeoutMonitor = + new HashedWheelTimer(Utils.namedThreadFactory("AckTimeoutMonitor")) private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60) @@ -72,14 +88,32 @@ private[nio] class ConnectionManager( conf.getInt("spark.core.connection.handler.threads.max", 60), conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-message-executor")) + Utils.namedThreadFactory("handle-message-executor")) { + + override def afterExecute(r: Runnable, t: Throwable): Unit = { + super.afterExecute(r, t) + if (t != null && NonFatal(t)) { + logError("Error in handleMessageExecutor is not handled properly", t) + } + } + + } private val handleReadWriteExecutor = new ThreadPoolExecutor( conf.getInt("spark.core.connection.io.threads.min", 4), 
conf.getInt("spark.core.connection.io.threads.max", 32), conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-read-write-executor")) + Utils.namedThreadFactory("handle-read-write-executor")) { + + override def afterExecute(r: Runnable, t: Throwable): Unit = { + super.afterExecute(r, t) + if (t != null && NonFatal(t)) { + logError("Error in handleReadWriteExecutor is not handled properly", t) + } + } + + } // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : // which should be executed asap @@ -88,7 +122,16 @@ private[nio] class ConnectionManager( conf.getInt("spark.core.connection.connect.threads.max", 8), conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS, new LinkedBlockingDeque[Runnable](), - Utils.namedThreadFactory("handle-connect-executor")) + Utils.namedThreadFactory("handle-connect-executor")) { + + override def afterExecute(r: Runnable, t: Throwable): Unit = { + super.afterExecute(r, t) + if (t != null && NonFatal(t)) { + logError("Error in handleConnectExecutor is not handled properly", t) + } + } + + } private val serverChannel = ServerSocketChannel.open() // used to track the SendingConnections waiting to do SASL negotiation @@ -98,7 +141,10 @@ private[nio] class ConnectionManager( new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection] - private val messageStatuses = new HashMap[Int, MessageStatus] + // Tracks sent messages for which we are awaiting acknowledgements. Entries are added to this + // map when messages are sent and are removed when acknowledgement messages are received or when + // acknowledgement timeouts expire + private val messageStatuses = new HashMap[Int, MessageStatus] // [MessageId, MessageStatus] private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)] private val registerRequests = new SynchronizedQueue[SendingConnection] @@ -153,17 +199,24 @@ private[nio] class ConnectionManager( } handleReadWriteExecutor.execute(new Runnable { override def run() { - var register: Boolean = false try { - register = conn.write() - } finally { - writeRunnableStarted.synchronized { - writeRunnableStarted -= key - val needReregister = register || conn.resetForceReregister() - if (needReregister && conn.changeInterestForWrite()) { - conn.registerInterest() + var register: Boolean = false + try { + register = conn.write() + } finally { + writeRunnableStarted.synchronized { + writeRunnableStarted -= key + val needReregister = register || conn.resetForceReregister() + if (needReregister && conn.changeInterestForWrite()) { + conn.registerInterest() + } } } + } catch { + case NonFatal(e) => { + logError("Error when writing to " + conn.getRemoteConnectionManagerId(), e) + conn.callOnExceptionCallbacks(e) + } } } } ) @@ -187,16 +240,23 @@ private[nio] class ConnectionManager( } handleReadWriteExecutor.execute(new Runnable { override def run() { - var register: Boolean = false try { - register = conn.read() - } finally { - readRunnableStarted.synchronized { - readRunnableStarted -= key - if (register && conn.changeInterestForRead()) { - conn.registerInterest() + var register: Boolean = false + try { + register = conn.read() + } finally { + readRunnableStarted.synchronized { + readRunnableStarted -= key + if 
(register && conn.changeInterestForRead()) { + conn.registerInterest() + } } } + } catch { + case NonFatal(e) => { + logError("Error when reading from " + conn.getRemoteConnectionManagerId(), e) + conn.callOnExceptionCallbacks(e) + } } } } ) @@ -213,19 +273,25 @@ private[nio] class ConnectionManager( handleConnectExecutor.execute(new Runnable { override def run() { + try { + var tries: Int = 10 + while (tries >= 0) { + if (conn.finishConnect(false)) return + // Sleep ? + Thread.sleep(1) + tries -= 1 + } - var tries: Int = 10 - while (tries >= 0) { - if (conn.finishConnect(false)) return - // Sleep ? - Thread.sleep(1) - tries -= 1 + // fallback to previous behavior : we should not really come here since this method was + // triggered since channel became connectable : but at times, the first finishConnect need + // not succeed : hence the loop to retry a few 'times'. + conn.finishConnect(true) + } catch { + case NonFatal(e) => { + logError("Error when finishConnect for " + conn.getRemoteConnectionManagerId(), e) + conn.callOnExceptionCallbacks(e) + } } - - // fallback to previous behavior : we should not really come here since this method was - // triggered since channel became connectable : but at times, the first finishConnect need - // not succeed : hence the loop to retry a few 'times'. - conn.finishConnect(true) } } ) } @@ -246,16 +312,16 @@ private[nio] class ConnectionManager( handleConnectExecutor.execute(new Runnable { override def run() { try { - conn.callOnExceptionCallback(e) + conn.callOnExceptionCallbacks(e) } catch { // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) + case NonFatal(e) => logDebug("Ignoring exception", e) } try { conn.close() } catch { // ignore exceptions - case e: Exception => logDebug("Ignoring exception", e) + case NonFatal(e) => logDebug("Ignoring exception", e) } } }) @@ -448,7 +514,7 @@ private[nio] class ConnectionManager( messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId) .foreach(status => { logInfo("Notifying " + status) - status.markDone(None) + status.failWithoutAck() }) messageStatuses.retain((i, status) => { @@ -477,7 +543,7 @@ private[nio] class ConnectionManager( for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) { logInfo("Notifying " + s) - s.markDone(None) + s.failWithoutAck() } messageStatuses.retain((i, status) => { @@ -492,7 +558,7 @@ private[nio] class ConnectionManager( } } - def handleConnectionError(connection: Connection, e: Exception) { + def handleConnectionError(connection: Connection, e: Throwable) { logInfo("Handling connection error on connection to " + connection.getRemoteConnectionManagerId()) removeConnection(connection) @@ -510,9 +576,17 @@ private[nio] class ConnectionManager( val runnable = new Runnable() { val creationTime = System.currentTimeMillis def run() { - logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") - handleMessage(connectionManagerId, message, connection) - logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") + try { + logDebug("Handler thread delay is " + (System.currentTimeMillis - creationTime) + " ms") + handleMessage(connectionManagerId, message, connection) + logDebug("Handling delay is " + (System.currentTimeMillis - creationTime) + " ms") + } catch { + case NonFatal(e) => { + logError("Error when handling messages from " + + connection.getRemoteConnectionManagerId(), e) + connection.callOnExceptionCallbacks(e) + } + } } } 
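Both the try/NonFatal wrappers in these runnables and the afterExecute overrides added to the thread pools earlier in this hunk follow the same idea: a task failure should be logged and routed to the connection's exception callbacks rather than silently killing a worker thread. A hedged sketch of the pool-side half (class name and println logging are stand-ins for the patch's logError):

import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit}
import scala.util.control.NonFatal

class LoggingThreadPool(core: Int, max: Int, keepAliveSec: Long)
  extends ThreadPoolExecutor(core, max, keepAliveSec, TimeUnit.SECONDS,
    new LinkedBlockingDeque[Runnable]()) {

  // afterExecute sees the Throwable a bare Runnable threw; without this hook the
  // error disappears with the worker thread.
  override def afterExecute(r: Runnable, t: Throwable): Unit = {
    super.afterExecute(r, t)
    if (t != null && NonFatal(t)) {
      println(s"Task failed with unhandled error: $t")
    }
  }
}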
handleMessageExecutor.execute(runnable) @@ -532,7 +606,7 @@ private[nio] class ConnectionManager( } else { var replyToken : Array[Byte] = null try { - replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken) + replyToken = waitingConn.sparkSaslClient.response(securityMsg.getToken) if (waitingConn.isSaslComplete()) { logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId) connectionsAwaitingSasl -= waitingConn.connectionId @@ -566,7 +640,7 @@ private[nio] class ConnectionManager( connection.synchronized { if (connection.sparkSaslServer == null) { logDebug("Creating sasl Server") - connection.sparkSaslServer = new SparkSaslServer(securityManager) + connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager) } } replyToken = connection.sparkSaslServer.response(securityMsg.getToken) @@ -651,7 +725,7 @@ private[nio] class ConnectionManager( messageStatuses.get(bufferMessage.ackId) match { case Some(status) => { messageStatuses -= bufferMessage.ackId - status.markDone(Some(message)) + status.success(message) } case None => { /** @@ -691,9 +765,7 @@ private[nio] class ConnectionManager( } catch { case e: Exception => { logError(s"Exception was thrown while processing message", e) - val m = Message.createBufferMessage(bufferMessage.id) - m.hasError = true - ackMessage = Some(m) + ackMessage = Some(Message.createErrorMessage(e, bufferMessage.id)) } } finally { sendMessage(connectionManagerId, ackMessage.getOrElse { @@ -712,7 +784,7 @@ private[nio] class ConnectionManager( if (!conn.isSaslComplete()) { conn.synchronized { if (conn.sparkSaslClient == null) { - conn.sparkSaslClient = new SparkSaslClient(securityManager) + conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager) var firstResponse: Array[Byte] = null try { firstResponse = conn.sparkSaslClient.firstToken() @@ -770,6 +842,12 @@ private[nio] class ConnectionManager( val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue) val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId, newConnectionId, securityManager) + newConnection.onException { + case (conn, e) => { + logError("Exception while sending message.", e) + reportSendingMessageFailure(message.id, e) + } + } logTrace("creating new sending connection: " + newConnectionId) registerRequests.enqueue(newConnection) @@ -782,13 +860,36 @@ private[nio] class ConnectionManager( "connectionid: " + connection.connectionId) if (authEnabled) { - checkSendAuthFirst(connectionManagerId, connection) + try { + checkSendAuthFirst(connectionManagerId, connection) + } catch { + case NonFatal(e) => { + reportSendingMessageFailure(message.id, e) + } + } } logDebug("Sending [" + message + "] to [" + connectionManagerId + "]") connection.send(message) wakeupSelector() } + private def reportSendingMessageFailure(messageId: Int, e: Throwable): Unit = { + // need to tell sender it failed + messageStatuses.synchronized { + val s = messageStatuses.get(messageId) + s match { + case Some(msgStatus) => { + messageStatuses -= messageId + logInfo("Notifying " + msgStatus.connectionManagerId) + msgStatus.failure(e) + } + case None => { + logError("no messageStatus for failed message id: " + messageId) + } + } + } + } + private def wakeupSelector() { selector.wakeup() } @@ -803,29 +904,62 @@ private[nio] class ConnectionManager( : Future[Message] = { val promise = Promise[Message]() - val timeoutTask = new TimerTask { - override def run(): Unit = { + // It's important that the 
TimerTask doesn't capture a reference to `message`, which can cause + // memory leaks since cancelled TimerTasks won't necessarily be garbage collected until the time + // at which they would originally be scheduled to run. Therefore, extract the message id + // from outside of the TimerTask closure (see SPARK-4393 for more context). + val messageId = message.id + // Keep a weak reference to the promise so that the completed promise may be garbage-collected + val promiseReference = new WeakReference(promise) + val timeoutTask: TimerTask = new TimerTask { + override def run(timeout: Timeout): Unit = { messageStatuses.synchronized { - messageStatuses.remove(message.id).foreach ( s => { - promise.failure( - new IOException("sendMessageReliably failed because ack " + - s"was not received within $ackTimeout sec")) - }) + messageStatuses.remove(messageId).foreach { s => + val e = new IOException("sendMessageReliably failed because ack " + + s"was not received within $ackTimeout sec") + val p = promiseReference.get + if (p != null) { + // Attempt to fail the promise with a Timeout exception + if (!p.tryFailure(e)) { + // If we reach here, then someone else has already signalled success or failure + // on this promise, so log a warning: + logError("Ignore error because promise is completed", e) + } + } else { + // The WeakReference was empty, which should never happen because + // sendMessageReliably's caller should have a strong reference to promise.future; + logError("Promise was garbage collected; this should never happen!", e) + } + } } } } + val timeoutTaskHandle = ackTimeoutMonitor.newTimeout(timeoutTask, ackTimeout, TimeUnit.SECONDS) + val status = new MessageStatus(message, connectionManagerId, s => { - timeoutTask.cancel() - s.ackMessage match { - case None => // Indicates a failure where we either never sent or never got ACK'd - promise.failure(new IOException("sendMessageReliably failed without being ACK'd")) - case Some(ackMessage) => + timeoutTaskHandle.cancel() + s match { + case scala.util.Failure(e) => + // Indicates a failure where we either never sent or never got ACK'd + if (!promise.tryFailure(e)) { + logWarning("Ignore error because promise is completed", e) + } + case scala.util.Success(ackMessage) => if (ackMessage.hasError) { - promise.failure( - new IOException("sendMessageReliably failed with ACK that signalled a remote error")) + val errorMsgByteBuf = ackMessage.asInstanceOf[BufferMessage].buffers.head + val errorMsgBytes = new Array[Byte](errorMsgByteBuf.limit()) + errorMsgByteBuf.get(errorMsgBytes) + val errorMsg = new String(errorMsgBytes, UTF_8) + val e = new IOException( + s"sendMessageReliably failed with ACK that signalled a remote error: $errorMsg") + if (!promise.tryFailure(e)) { + logWarning("Ignore error because promise is completed", e) + } } else { - promise.success(ackMessage) + if (!promise.trySuccess(ackMessage)) { + logWarning("Drop ackMessage because promise is completed") + } } } }) @@ -833,7 +967,6 @@ private[nio] class ConnectionManager( messageStatuses += ((message.id, status)) } - ackTimeoutMonitor.schedule(timeoutTask, ackTimeout * 1000) sendMessage(connectionManagerId, message) promise.future } @@ -843,7 +976,7 @@ private[nio] class ConnectionManager( } def stop() { - ackTimeoutMonitor.cancel() + ackTimeoutMonitor.stop() selectorThread.interrupt() selectorThread.join() selector.close() diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala index 
0b874c2891255..fb4a979b824c3 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/Message.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/Message.scala @@ -22,6 +22,9 @@ import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer +import com.google.common.base.Charsets.UTF_8 + +import org.apache.spark.util.Utils private[nio] abstract class Message(val typ: Long, val id: Int) { var senderAddress: InetSocketAddress = null @@ -84,6 +87,19 @@ private[nio] object Message { createBufferMessage(new Array[ByteBuffer](0), ackId) } + /** + * Create a "negative acknowledgment" to notify a sender that an error occurred + * while processing its message. The exception's stacktrace will be formatted + * as a string, serialized into a byte array, and sent as the message payload. + */ + def createErrorMessage(exception: Exception, ackId: Int): BufferMessage = { + val exceptionString = Utils.exceptionString(exception) + val serializedExceptionString = ByteBuffer.wrap(exceptionString.getBytes(UTF_8)) + val errorMessage = createBufferMessage(serializedExceptionString, ackId) + errorMessage.hasError = true + errorMessage + } + def create(header: MessageChunkHeader): Message = { val newMessage: Message = header.typ match { case BUFFER_MESSAGE => new BufferMessage(header.id, diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index b389b9a2022c6..b2aec160635c7 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -19,12 +19,14 @@ package org.apache.spark.network.nio import java.nio.ByteBuffer -import scala.concurrent.Future - -import org.apache.spark.{SparkException, Logging, SecurityManager, SparkConf} import org.apache.spark.network._ +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} + +import scala.concurrent.Future /** @@ -71,20 +73,21 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa /** * Tear down the transfer service. 
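Looping back to the ack-timeout change in the ConnectionManager hunk above: the timer task captures only the message id and a WeakReference to the promise, so a cancelled timeout cannot pin the message or a completed promise in memory until its scheduled time (SPARK-4393), and tryFailure keeps any race with a late ack benign. A minimal sketch with illustrative names:

import java.io.IOException
import java.lang.ref.WeakReference
import java.util.concurrent.TimeUnit
import scala.concurrent.Promise
import io.netty.util.{HashedWheelTimer, Timeout, TimerTask}

def scheduleAckTimeout(
    timer: HashedWheelTimer,
    messageId: Int,
    promise: Promise[Array[Byte]],
    timeoutSec: Long): Timeout = {
  val promiseRef = new WeakReference(promise)   // avoid a strong reference from the timer wheel
  timer.newTimeout(new TimerTask {
    override def run(timeout: Timeout): Unit = {
      val p = promiseRef.get()
      if (p != null) {
        // No-op if an ack already completed the promise.
        p.tryFailure(new IOException(s"message $messageId was not ACK'd within $timeoutSec sec"))
      }
    }
  }, timeoutSec, TimeUnit.SECONDS)
}

The caller holds the returned Timeout and cancels it once the ack arrives, as the MessageStatus completion handler above does with timeoutTaskHandle.cancel().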
*/ - override def stop(): Unit = { + override def close(): Unit = { if (cm != null) { cm.stop() } } override def fetchBlocks( - hostName: String, + host: String, port: Int, - blockIds: Seq[String], + execId: String, + blockIds: Array[String], listener: BlockFetchingListener): Unit = { checkInit() - val cmId = new ConnectionManagerId(hostName, port) + val cmId = new ConnectionManagerId(host, port) val blockMessageArray = new BlockMessageArray(blockIds.map { blockId => BlockMessage.fromGetBlock(GetBlock(BlockId(blockId))) }) @@ -96,21 +99,33 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa val bufferMessage = message.asInstanceOf[BufferMessage] val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage) - for (blockMessage <- blockMessageArray) { - if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) { - listener.onBlockFetchFailure( - new SparkException(s"Unexpected message ${blockMessage.getType} received from $cmId")) - } else { - val blockId = blockMessage.getId - val networkSize = blockMessage.getData.limit() - listener.onBlockFetchSuccess( - blockId.toString, new NioByteBufferManagedBuffer(blockMessage.getData)) + // SPARK-4064: In some cases(eg. Remote block was removed) blockMessageArray may be empty. + if (blockMessageArray.isEmpty) { + blockIds.foreach { id => + listener.onBlockFetchFailure(id, new SparkException(s"Received empty message from $cmId")) + } + } else { + for (blockMessage: BlockMessage <- blockMessageArray) { + val msgType = blockMessage.getType + if (msgType != BlockMessage.TYPE_GOT_BLOCK) { + if (blockMessage.getId != null) { + listener.onBlockFetchFailure(blockMessage.getId.toString, + new SparkException(s"Unexpected message $msgType received from $cmId")) + } + } else { + val blockId = blockMessage.getId + val networkSize = blockMessage.getData.limit() + listener.onBlockFetchSuccess( + blockId.toString, new NioManagedBuffer(blockMessage.getData)) + } } } }(cm.futureExecContext) future.onFailure { case exception => - listener.onBlockFetchFailure(exception) + blockIds.foreach { blockId => + listener.onBlockFetchFailure(blockId, exception) + } }(cm.futureExecContext) } @@ -122,12 +137,13 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa override def uploadBlock( hostname: String, port: Int, - blockId: String, + execId: String, + blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel) : Future[Unit] = { checkInit() - val msg = PutBlock(BlockId(blockId), blockData.nioByteBuffer(), level) + val msg = PutBlock(blockId, blockData.nioByteBuffer(), level) val blockMessageArray = new BlockMessageArray(BlockMessage.fromPutBlock(msg)) val remoteCmId = new ConnectionManagerId(hostName, port) val reply = cm.sendMessageReliably(remoteCmId, blockMessageArray.toBufferMessage) @@ -149,19 +165,15 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get) Some(new BlockMessageArray(responseMessages).toBufferMessage) } catch { - case e: Exception => { + case e: Exception => logError("Exception handling buffer message", e) - val errorMessage = Message.createBufferMessage(msg.id) - errorMessage.hasError = true - Some(errorMessage) - } + Some(Message.createErrorMessage(e, msg.id)) } case otherMessage: Any => - logError("Unknown type message received: " + otherMessage) - val errorMessage = Message.createBufferMessage(msg.id) - errorMessage.hasError = true - Some(errorMessage) + 
val errorMsg = s"Received unknown message type: ${otherMessage.getClass.getName}" + logError(errorMsg) + Some(Message.createErrorMessage(new UnsupportedOperationException(errorMsg), msg.id)) } } @@ -170,13 +182,13 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa case BlockMessage.TYPE_PUT_BLOCK => val msg = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) logDebug("Received [" + msg + "]") - putBlock(msg.id.toString, msg.data, msg.level) + putBlock(msg.id, msg.data, msg.level) None case BlockMessage.TYPE_GET_BLOCK => val msg = new GetBlock(blockMessage.getId) logDebug("Received [" + msg + "]") - val buffer = getBlock(msg.id.toString) + val buffer = getBlock(msg.id) if (buffer == null) { return None } @@ -186,20 +198,20 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa } } - private def putBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) { + private def putBlock(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) { val startTimeMs = System.currentTimeMillis() logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes) - blockDataManager.putBlockData(blockId, new NioByteBufferManagedBuffer(bytes), level) + blockDataManager.putBlockData(blockId, new NioManagedBuffer(bytes), level) logDebug("PutBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + " with data size: " + bytes.limit) } - private def getBlock(blockId: String): ByteBuffer = { + private def getBlock(blockId: BlockId): ByteBuffer = { val startTimeMs = System.currentTimeMillis() logDebug("GetBlock " + blockId + " started from " + startTimeMs) - val buffer = blockDataManager.getBlockData(blockId).orNull + val buffer = blockDataManager.getBlockData(blockId) logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + " and got buffer " + buffer) - if (buffer == null) null else buffer.nioByteBuffer() + buffer.nioByteBuffer() } } diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index e2fc9c649925e..436dbed1730bc 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -44,5 +44,5 @@ package org.apache package object spark { // For package docs only - val SPARK_VERSION = "1.2.0-SNAPSHOT" + val SPARK_VERSION = "1.3.0-SNAPSHOT" } diff --git a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala index 3155dfe165664..637492a97551b 100644 --- a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala @@ -17,7 +17,7 @@ package org.apache.spark.partial -import cern.jet.stat.Probability +import org.apache.commons.math3.distribution.NormalDistribution /** * An ApproximateEvaluator for counts. @@ -46,7 +46,8 @@ private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double) val mean = (sum + 1 - p) / p val variance = (sum + 1) * (1 - p) / (p * p) val stdev = math.sqrt(variance) - val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2) + val confFactor = new NormalDistribution(). 
+ inverseCumulativeProbability(1 - (1 - confidence) / 2) val low = mean - confFactor * stdev val high = mean + confFactor * stdev new BoundedDouble(mean, confidence, low, high) diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala index 8bb78123e3c9c..3ef3cc219dec6 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala @@ -24,7 +24,7 @@ import scala.collection.Map import scala.collection.mutable.HashMap import scala.reflect.ClassTag -import cern.jet.stat.Probability +import org.apache.commons.math3.distribution.NormalDistribution import org.apache.spark.util.collection.OpenHashMap @@ -55,7 +55,8 @@ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, conf new HashMap[T, BoundedDouble] } else { val p = outputsMerged.toDouble / totalOutputs - val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2) + val confFactor = new NormalDistribution(). + inverseCumulativeProbability(1 - (1 - confidence) / 2) val result = new JHashMap[T, BoundedDouble](sums.size) sums.foreach { case (key, sum) => val mean = (sum + 1 - p) / p diff --git a/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala index d24959cba8727..787a21a61fdcf 100644 --- a/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala @@ -17,7 +17,7 @@ package org.apache.spark.partial -import cern.jet.stat.Probability +import org.apache.commons.math3.distribution.{NormalDistribution, TDistribution} import org.apache.spark.util.StatCounter @@ -45,9 +45,10 @@ private[spark] class MeanEvaluator(totalOutputs: Int, confidence: Double) val stdev = math.sqrt(counter.sampleVariance / counter.count) val confFactor = { if (counter.count > 100) { - Probability.normalInverse(1 - (1 - confidence) / 2) + new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2) } else { - Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt) + val degreesOfFreedom = (counter.count - 1).toInt + new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2) } } val low = mean - confFactor * stdev diff --git a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala index 92915ee66d29f..828bf96c2c0bd 100644 --- a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala +++ b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala @@ -17,7 +17,7 @@ package org.apache.spark.partial -import cern.jet.stat.Probability +import org.apache.commons.math3.distribution.{TDistribution, NormalDistribution} /** * A utility class for caching Student's T distribution values for a given confidence level @@ -25,8 +25,10 @@ import cern.jet.stat.Probability * confidence intervals for many keys. 
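Across these evaluators the colt Probability calls are replaced by the same commons-math3 quantile computation; condensed into one hedged helper (function name illustrative, assumes at least two samples):

import org.apache.commons.math3.distribution.{NormalDistribution, TDistribution}

def confFactor(confidence: Double, sampleCount: Long): Double = {
  // Two-sided interval: take the (1 - (1 - confidence) / 2) quantile.
  val p = 1 - (1 - confidence) / 2
  if (sampleCount > 100) {
    new NormalDistribution().inverseCumulativeProbability(p)   // standard normal approximation
  } else {
    new TDistribution(sampleCount - 1).inverseCumulativeProbability(p)
  }
}

For example, confFactor(0.95, 10000) is roughly 1.96, the familiar two-sided normal z-value.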
*/ private[spark] class StudentTCacher(confidence: Double) { + val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation - val normalApprox = Probability.normalInverse(1 - (1 - confidence) / 2) + + val normalApprox = new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2) val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0) def get(sampleSize: Long): Double = { @@ -35,7 +37,8 @@ private[spark] class StudentTCacher(confidence: Double) { } else { val size = sampleSize.toInt if (cache(size) < 0) { - cache(size) = Probability.studentTInverse(1 - confidence, size - 1) + val tDist = new TDistribution(size - 1) + cache(size) = tDist.inverseCumulativeProbability(1 - (1 - confidence) / 2) } cache(size) } diff --git a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala index d5336284571d2..1753c2561b678 100644 --- a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala @@ -17,7 +17,7 @@ package org.apache.spark.partial -import cern.jet.stat.Probability +import org.apache.commons.math3.distribution.{TDistribution, NormalDistribution} import org.apache.spark.util.StatCounter @@ -55,9 +55,10 @@ private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double) val sumStdev = math.sqrt(sumVar) val confFactor = { if (counter.count > 100) { - Probability.normalInverse(1 - (1 - confidence) / 2) + new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2) } else { - Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt) + val degreesOfFreedom = (counter.count - 1).toInt + new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2) } } val low = sumEstimate - confFactor * sumStdev diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index b62f3fbdc4a15..9f9f10b7ebc3a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -24,14 +24,11 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.reflect.ClassTag import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} -import org.apache.spark.annotation.Experimental /** - * :: Experimental :: * A set of asynchronous RDD actions available through an implicit conversion. * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions. */ -@Experimental class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging { /** @@ -78,16 +75,18 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi // greater than totalParts because we actually cap it at totalParts in runJob. var numPartsToTry = 1 if (partsScanned > 0) { - // If we didn't find any rows after the first iteration, just try all partitions next. + // If we didn't find any rows after the previous iteration, quadruple and retry. // Otherwise, interpolate the number of partitions we need to try, but overestimate it - // by 50%. + // by 50%. We also cap the estimation in the end. 
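The replacement lines that follow implement roughly this estimation, shown here as a standalone sketch (function name illustrative):

def estimateNumPartsToTry(partsScanned: Int, resultsSoFar: Int, num: Int): Int = {
  if (resultsSoFar == 0) {
    // Nothing found yet: grow the scan geometrically rather than jumping to all partitions.
    partsScanned * 4
  } else {
    // Interpolate how many partitions `num` rows should need, overestimate by 50%,
    // subtract what was already scanned, and cap the growth at 4x per round.
    val interpolated = math.max(1, (1.5 * num * partsScanned / resultsSoFar).toInt - partsScanned)
    math.min(interpolated, partsScanned * 4)
  }
}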
if (results.size == 0) { - numPartsToTry = totalParts - 1 + numPartsToTry = partsScanned * 4 } else { - numPartsToTry = (1.5 * num * partsScanned / results.size).toInt + // the left side of max is >=1 whenever partsScanned >= 2 + numPartsToTry = Math.max(1, + (1.5 * num * partsScanned / results.size).toInt - partsScanned) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) } } - numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions val left = num - results.size val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala new file mode 100644 index 0000000000000..6e66ddbdef788 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.hadoop.conf.{ Configurable, Configuration } +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapreduce._ +import org.apache.spark.input.StreamFileInputFormat +import org.apache.spark.{ Partition, SparkContext } + +private[spark] class BinaryFileRDD[T]( + sc: SparkContext, + inputFormatClass: Class[_ <: StreamFileInputFormat[T]], + keyClass: Class[String], + valueClass: Class[T], + @transient conf: Configuration, + minPartitions: Int) + extends NewHadoopRDD[String, T](sc, inputFormatClass, keyClass, valueClass, conf) { + + override def getPartitions: Array[Partition] = { + val inputFormat = inputFormatClass.newInstance + inputFormat match { + case configurable: Configurable => + configurable.setConf(conf) + case _ => + } + val jobContext = newJobContext(conf, jobId) + inputFormat.setMinPartitions(jobContext, minPartitions) + val rawSplits = inputFormat.getSplits(jobContext).toArray + val result = new Array[Partition](rawSplits.size) + for (i <- 0 until rawSplits.size) { + result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + result + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index 2673ec22509e9..fffa1911f5bc2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -84,5 +84,9 @@ class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds "Attempted to use %s after its blocks have been removed!".format(toString)) } } + + protected def getBlockIdLocations(): Map[BlockId, Seq[String]] = { + locations_ + } } diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala 
index 4908711d17db7..1cbd684224b7c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -22,6 +22,7 @@ import java.io.{IOException, ObjectOutputStream} import scala.reflect.ClassTag import org.apache.spark._ +import org.apache.spark.util.Utils private[spark] class CartesianPartition( @@ -36,7 +37,7 @@ class CartesianPartition( override val index: Int = idx @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { + private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { // Update the reference to parent split at the time of task serialization s1 = rdd1.partitions(s1Index) s2 = rdd2.partitions(s2Index) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index fabb882cdd4b3..ffc0a8a6d67eb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -27,6 +27,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer} +import org.apache.spark.util.Utils import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleHandle @@ -39,7 +40,7 @@ private[spark] case class NarrowCoGroupSplitDep( ) extends CoGroupSplitDep { @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { + private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { // Update the reference to parent split at the time of task serialization split = rdd.partitions(splitIndex) oos.defaultWriteObject() diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 11ebafbf6d457..9fab1d78abb04 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -25,6 +25,7 @@ import scala.language.existentials import scala.reflect.ClassTag import org.apache.spark._ +import org.apache.spark.util.Utils /** * Class that captures a coalesced RDD by essentially keeping track of parent partitions @@ -42,7 +43,7 @@ private[spark] case class CoalescedRDDPartition( var parents: Seq[Partition] = parentsIndices.map(rdd.partitions(_)) @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { + private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { // Update the reference to parent partition at the time of task serialization parents = parentsIndices.map(rdd.partitions(_)) oos.defaultWriteObject() diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 6b63eb23e9ee1..a157e36e2286e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -46,7 +46,6 @@ import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD import org.apache.spark.util.{NextIterator, Utils} import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation} - /** * A Spark split class that wraps around a Hadoop InputSplit. 
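Editor's note: the CartesianRDD, CoGroupedRDD and CoalescedRDD hunks above all switch their custom writeObject methods to a Utils.tryOrIOException wrapper. As a hedged illustration of the idea only (this is not Spark's actual helper), such a wrapper re-throws unexpected failures as IOException so the serialization machinery reports them cleanly:

import java.io.IOException

// Illustrative stand-in (not Spark's Utils.tryOrIOException): run the
// serialization block and re-throw any unexpected failure as an IOException,
// which ObjectOutputStream / ObjectInputStream know how to report.
object SerializationHelpers {
  def tryOrIOExceptionSketch(block: => Unit): Unit = {
    try {
      block
    } catch {
      case ioe: IOException => throw ioe           // already the expected type
      case e: Exception => throw new IOException(e)
    }
  }
}

Custom writeObject/readObject methods are invoked reflectively during Java serialization, so without a wrapper like this an unexpected exception can surface far from its actual cause.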
*/ @@ -132,27 +131,47 @@ class HadoopRDD[K, V]( // used to build JobTracker ID private val createTime = new Date() + private val shouldCloneJobConf = sc.conf.get("spark.hadoop.cloneConf", "false").toBoolean + // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. protected def getJobConf(): JobConf = { val conf: Configuration = broadcastedConf.value.value - if (conf.isInstanceOf[JobConf]) { - // A user-broadcasted JobConf was provided to the HadoopRDD, so always use it. - conf.asInstanceOf[JobConf] - } else if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) { - // getJobConf() has been called previously, so there is already a local cache of the JobConf - // needed by this RDD. - HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf] - } else { - // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the - // local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). - // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects. - // Synchronize to prevent ConcurrentModificationException (Spark-1097, Hadoop-10456). + if (shouldCloneJobConf) { + // Hadoop Configuration objects are not thread-safe, which may lead to various problems if + // one job modifies a configuration while another reads it (SPARK-2546). This problem occurs + // somewhat rarely because most jobs treat the configuration as though it's immutable. One + // solution, implemented here, is to clone the Configuration object. Unfortunately, this + // clone can be very expensive. To avoid unexpected performance regressions for workloads and + // Hadoop versions that do not suffer from these thread-safety issues, this cloning is + // disabled by default. HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { + logDebug("Cloning Hadoop Configuration") val newJobConf = new JobConf(conf) - initLocalJobConfFuncOpt.map(f => f(newJobConf)) - HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) + if (!conf.isInstanceOf[JobConf]) { + initLocalJobConfFuncOpt.map(f => f(newJobConf)) + } newJobConf } + } else { + if (conf.isInstanceOf[JobConf]) { + logDebug("Re-using user-broadcasted JobConf") + conf.asInstanceOf[JobConf] + } else if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) { + logDebug("Re-using cached JobConf") + HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf] + } else { + // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the + // local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). + // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects. + // Synchronize to prevent ConcurrentModificationException (SPARK-1097, HADOOP-10456). + HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { + logDebug("Creating new JobConf and caching it for later re-use") + val newJobConf = new JobConf(conf) + initLocalJobConfFuncOpt.map(f => f(newJobConf)) + HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) + newJobConf + } + } } } @@ -192,11 +211,25 @@ class HadoopRDD[K, V]( val split = theSplit.asInstanceOf[HadoopPartition] logInfo("Input split: " + split.inputSplit) - var reader: RecordReader[K, V] = null val jobConf = getJobConf() + + val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) + // Find a function that will return the FileSystem bytes read by this thread. 
Do this before + // creating RecordReader, because RecordReader's constructor might read some bytes + val bytesReadCallback = if (split.inputSplit.value.isInstanceOf[FileSplit]) { + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback( + split.inputSplit.value.asInstanceOf[FileSplit].getPath, jobConf) + } else { + None + } + if (bytesReadCallback.isDefined) { + context.taskMetrics.inputMetrics = Some(inputMetrics) + } + + var reader: RecordReader[K, V] = null val inputFormat = getInputFormat(jobConf) HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime), - context.getStageId, theSplit.index, context.getAttemptId.toInt, jobConf) + context.stageId, theSplit.index, context.attemptId.toInt, jobConf) reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) // Register an on-task-completion callback to close the input stream. @@ -204,18 +237,7 @@ class HadoopRDD[K, V]( val key: K = reader.createKey() val value: V = reader.createValue() - // Set the task input metrics. - val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) - try { - /* bytesRead may not exactly equal the bytes read by a task: split boundaries aren't - * always at record boundaries, so tasks may need to read into other splits to complete - * a record. */ - inputMetrics.bytesRead = split.inputSplit.value.getLength() - } catch { - case e: java.io.IOException => - logWarning("Unable to get input size to set InputMetrics for task", e) - } - context.taskMetrics.inputMetrics = Some(inputMetrics) + var recordsSinceMetricsUpdate = 0 override def getNext() = { try { @@ -224,12 +246,36 @@ class HadoopRDD[K, V]( case eof: EOFException => finished = true } + + // Update bytes read metric every few records + if (recordsSinceMetricsUpdate == HadoopRDD.RECORDS_BETWEEN_BYTES_READ_METRIC_UPDATES + && bytesReadCallback.isDefined) { + recordsSinceMetricsUpdate = 0 + val bytesReadFn = bytesReadCallback.get + inputMetrics.bytesRead = bytesReadFn() + } else { + recordsSinceMetricsUpdate += 1 + } (key, value) } override def close() { try { reader.close() + if (bytesReadCallback.isDefined) { + val bytesReadFn = bytesReadCallback.get + inputMetrics.bytesRead = bytesReadFn() + } else if (split.inputSplit.value.isInstanceOf[FileSplit]) { + // If we can't get the bytes read from the FS stats, fall back to the split size, + // which may be inaccurate. + try { + inputMetrics.bytesRead = split.inputSplit.value.getLength + context.taskMetrics.inputMetrics = Some(inputMetrics) + } catch { + case e: java.io.IOException => + logWarning("Unable to get input size to set InputMetrics for task", e) + } + } } catch { case e: Exception => { if (!Utils.inShutdown()) { @@ -276,9 +322,15 @@ class HadoopRDD[K, V]( } private[spark] object HadoopRDD extends Logging { - /** Constructing Configuration objects is not threadsafe, use this lock to serialize. */ + /** + * Configuration's constructor is not threadsafe (see SPARK-1097 and HADOOP-10456). + * Therefore, we synchronize on this lock before calling new JobConf() or new Configuration(). + */ val CONFIGURATION_INSTANTIATION_LOCK = new Object() + /** Update the input bytes read metric each time this number of records has been read */ + val RECORDS_BETWEEN_BYTES_READ_METRIC_UPDATES = 256 + /** * The three methods below are helpers for accessing the local map, a property of the SparkEnv of * the local process. 
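Editor's note: both HadoopRDD above and NewHadoopRDD later in this patch refresh the bytes-read metric only once every RECORDS_BETWEEN_BYTES_READ_METRIC_UPDATES records, because querying the FileSystem statistics per record would be wasteful. A minimal standalone sketch of that throttling pattern (class and method names here are illustrative, not Spark API):

// Query the relatively expensive byte-counter callback only every `interval`
// records, then once more when the reader is closed.
class ThrottledBytesRead(interval: Int, readBytes: () => Long) {
  private var sinceUpdate = 0
  var bytesRead: Long = 0L

  def recordProcessed(): Unit = {
    if (sinceUpdate == interval) {
      sinceUpdate = 0
      bytesRead = readBytes()
    } else {
      sinceUpdate += 1
    }
  }

  def close(): Unit = { bytesRead = readBytes() }  // final update on close()
}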
diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index 0e38f224ac81d..642a12c1edf6c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -21,8 +21,11 @@ import java.sql.{Connection, ResultSet} import scala.reflect.ClassTag -import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} +import org.apache.spark.api.java.JavaSparkContext.fakeClassTag +import org.apache.spark.api.java.function.{Function => JFunction} +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.util.NextIterator +import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition { override def index = idx @@ -125,5 +128,82 @@ object JdbcRDD { def resultSetToObjectArray(rs: ResultSet): Array[Object] = { Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1)) } -} + trait ConnectionFactory extends Serializable { + @throws[Exception] + def getConnection: Connection + } + + /** + * Create an RDD that executes an SQL query on a JDBC connection and reads results. + * For usage example, see test case JavaAPISuite.testJavaJdbcRDD. + * + * @param connectionFactory a factory that returns an open Connection. + * The RDD takes care of closing the connection. + * @param sql the text of the query. + * The query must contain two ? placeholders for parameters used to partition the results. + * E.g. "select title, author from books where ? <= id and id <= ?" + * @param lowerBound the minimum value of the first placeholder + * @param upperBound the maximum value of the second placeholder + * The lower and upper bounds are inclusive. + * @param numPartitions the number of partitions. + * Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2, + * the query would be executed twice, once with (1, 10) and once with (11, 20) + * @param mapRow a function from a ResultSet to a single row of the desired result type(s). + * This should only call getInt, getString, etc; the RDD takes care of calling next. + * The default maps a ResultSet to an array of Object. + */ + def create[T]( + sc: JavaSparkContext, + connectionFactory: ConnectionFactory, + sql: String, + lowerBound: Long, + upperBound: Long, + numPartitions: Int, + mapRow: JFunction[ResultSet, T]): JavaRDD[T] = { + + val jdbcRDD = new JdbcRDD[T]( + sc.sc, + () => connectionFactory.getConnection, + sql, + lowerBound, + upperBound, + numPartitions, + (resultSet: ResultSet) => mapRow.call(resultSet))(fakeClassTag) + + new JavaRDD[T](jdbcRDD)(fakeClassTag) + } + + /** + * Create an RDD that executes an SQL query on a JDBC connection and reads results. Each row is + * converted into a `Object` array. For usage example, see test case JavaAPISuite.testJavaJdbcRDD. + * + * @param connectionFactory a factory that returns an open Connection. + * The RDD takes care of closing the connection. + * @param sql the text of the query. + * The query must contain two ? placeholders for parameters used to partition the results. + * E.g. "select title, author from books where ? <= id and id <= ?" + * @param lowerBound the minimum value of the first placeholder + * @param upperBound the maximum value of the second placeholder + * The lower and upper bounds are inclusive. + * @param numPartitions the number of partitions. 
+ * Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2, + * the query would be executed twice, once with (1, 10) and once with (11, 20) + */ + def create( + sc: JavaSparkContext, + connectionFactory: ConnectionFactory, + sql: String, + lowerBound: Long, + upperBound: Long, + numPartitions: Int): JavaRDD[Array[Object]] = { + + val mapRow = new JFunction[ResultSet, Array[Object]] { + override def call(resultSet: ResultSet): Array[Object] = { + resultSetToObjectArray(resultSet) + } + } + + create(sc, connectionFactory, sql, lowerBound, upperBound, numPartitions, mapRow) + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 0cccdefc5ee09..e55d03d391e03 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -25,6 +25,7 @@ import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.lib.input.FileSplit import org.apache.spark.annotation.DeveloperApi import org.apache.spark.input.WholeTextFileInputFormat @@ -34,8 +35,10 @@ import org.apache.spark.Partition import org.apache.spark.SerializableWritable import org.apache.spark.{SparkContext, TaskContext} import org.apache.spark.executor.{DataReadMethod, InputMetrics} +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.util.Utils +import org.apache.spark.deploy.SparkHadoopUtil private[spark] class NewHadoopPartition( rddId: Int, @@ -105,6 +108,20 @@ class NewHadoopRDD[K, V]( val split = theSplit.asInstanceOf[NewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) val conf = confBroadcast.value.value + + val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) + // Find a function that will return the FileSystem bytes read by this thread. Do this before + // creating RecordReader, because RecordReader's constructor might read some bytes + val bytesReadCallback = if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) { + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback( + split.serializableHadoopSplit.value.asInstanceOf[FileSplit].getPath, conf) + } else { + None + } + if (bytesReadCallback.isDefined) { + context.taskMetrics.inputMetrics = Some(inputMetrics) + } + val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) val format = inputFormatClass.newInstance @@ -117,22 +134,11 @@ class NewHadoopRDD[K, V]( split.serializableHadoopSplit.value, hadoopAttemptContext) reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) - val inputMetrics = new InputMetrics(DataReadMethod.Hadoop) - try { - /* bytesRead may not exactly equal the bytes read by a task: split boundaries aren't - * always at record boundaries, so tasks may need to read into other splits to complete - * a record. */ - inputMetrics.bytesRead = split.serializableHadoopSplit.value.getLength() - } catch { - case e: Exception => - logWarning("Unable to get input split size in order to set task input bytes", e) - } - context.taskMetrics.inputMetrics = Some(inputMetrics) - // Register an on-task-completion callback to close the input stream. 
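Editor's note: for context on the new Java-friendly JdbcRDD.create helpers above, here is a hedged sketch of the Scala-side constructor they wrap; the JDBC URL, table name and bounds are placeholders, not values taken from this patch:

import java.sql.DriverManager
import org.apache.spark.SparkContext
import org.apache.spark.rdd.JdbcRDD

object JdbcRddSketch {
  // Placeholder connection string and query; the two '?' markers are the
  // partitioning bounds that the RDD substitutes per partition.
  def booksRDD(sc: SparkContext) = new JdbcRDD(
    sc,
    () => DriverManager.getConnection("jdbc:h2:mem:testdb"),
    "select title, author from books where ? <= id and id <= ?",
    lowerBound = 1,
    upperBound = 20,
    numPartitions = 2,
    mapRow = rs => (rs.getString(1), rs.getString(2)))
}

From Java, the equivalent call goes through JdbcRDD.create with a ConnectionFactory, as documented in the new scaladoc above.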
context.addTaskCompletionListener(context => close()) var havePair = false var finished = false + var recordsSinceMetricsUpdate = 0 override def hasNext: Boolean = { if (!finished && !havePair) { @@ -147,12 +153,39 @@ class NewHadoopRDD[K, V]( throw new java.util.NoSuchElementException("End of stream") } havePair = false + + // Update bytes read metric every few records + if (recordsSinceMetricsUpdate == HadoopRDD.RECORDS_BETWEEN_BYTES_READ_METRIC_UPDATES + && bytesReadCallback.isDefined) { + recordsSinceMetricsUpdate = 0 + val bytesReadFn = bytesReadCallback.get + inputMetrics.bytesRead = bytesReadFn() + } else { + recordsSinceMetricsUpdate += 1 + } + (reader.getCurrentKey, reader.getCurrentValue) } private def close() { try { reader.close() + + // Update metrics with final amount + if (bytesReadCallback.isDefined) { + val bytesReadFn = bytesReadCallback.get + inputMetrics.bytesRead = bytesReadFn() + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) { + // If we can't get the bytes read from the FS stats, fall back to the split size, + // which may be inaccurate. + try { + inputMetrics.bytesRead = split.serializableHadoopSplit.value.getLength + context.taskMetrics.inputMetrics = Some(inputMetrics) + } catch { + case e: java.io.IOException => + logWarning("Unable to get input size to set InputMetrics for task", e) + } + } } catch { case e: Exception => { if (!Utils.inShutdown()) { @@ -233,7 +266,7 @@ private[spark] class WholeTextFileRDD( case _ => } val jobContext = newJobContext(conf, jobId) - inputFormat.setMaxSplitSize(jobContext, minPartitions) + inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) for (i <- 0 until rawSplits.size) { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 0d97506450a7f..8c2c959e73bb6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -28,18 +28,20 @@ import scala.reflect.ClassTag import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus import org.apache.hadoop.conf.{Configurable, Configuration} -import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, -RecordWriter => NewRecordWriter, SparkHadoopMapReduceUtil} +RecordWriter => NewRecordWriter} import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.executor.{DataWriteMethod, OutputMetrics} +import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer import org.apache.spark.util.Utils @@ -315,8 +317,15 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) @deprecated("Use reduceByKeyLocally", "1.0.0") def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func) - /** Count the number of elements for each key, and return the result to the master as 
a Map. */ - def countByKey(): Map[K, Long] = self.map(_._1).countByValue() + /** + * Count the number of elements for each key, collecting the results to a local Map. + * + * Note that this method should only be used if the resulting map is expected to be small, as + * the whole thing is loaded into the driver's memory. + * To handle very large results, consider using rdd.mapValues(_ => 1L).reduceByKey(_ + _), which + * returns an RDD[T, Long] instead of a map. + */ + def countByKey(): Map[K, Long] = self.mapValues(_ => 1L).reduceByKey(_ + _).collect().toMap /** * :: Experimental :: @@ -954,30 +963,40 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => { + val config = wrappedConf.value // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. - val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt + val attemptNumber = (context.attemptId % Int.MaxValue).toInt /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.getPartitionId, + val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, attemptNumber) - val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) + val hadoopContext = newTaskAttemptContext(config, attemptId) val format = outfmt.newInstance format match { - case c: Configurable => c.setConf(wrappedConf.value) + case c: Configurable => c.setConf(config) case _ => () } val committer = format.getOutputCommitter(hadoopContext) committer.setupTask(hadoopContext) + + val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config) + val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] try { + var recordsWritten = 0L while (iter.hasNext) { val pair = iter.next() writer.write(pair._1, pair._2) + + // Update bytes written metric every few records + maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) + recordsWritten += 1 } } finally { writer.close(hadoopContext) } committer.commitTask(hadoopContext) + bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() } 1 } : Int @@ -998,6 +1017,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def saveAsHadoopDataset(conf: JobConf) { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf + val wrappedConf = new SerializableWritable(hadoopConf) val outputFormatInstance = hadoopConf.getOutputFormat val keyClass = hadoopConf.getOutputKeyClass val valueClass = hadoopConf.getOutputValueClass @@ -1025,29 +1045,56 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.preSetup() val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => { + val config = wrappedConf.value // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. 
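Editor's note: the new countByKey() documentation above (and the matching countByValue() change later in this patch) warns that the result is materialized in driver memory. A small hedged example of the two options, assuming the Spark 1.x pair-RDD implicits are in scope:

import org.apache.spark.SparkContext._   // pair-RDD implicits on Spark 1.x
import org.apache.spark.rdd.RDD

object KeyCountSketch {
  // countByKey() materializes one entry per distinct key on the driver, so it
  // only suits small key spaces; the reduceByKey form keeps counts distributed.
  def keyCounts(pairs: RDD[(String, Int)]): Unit = {
    val small = pairs.countByKey()                                  // driver-side map
    val large: RDD[(String, Long)] = pairs.mapValues(_ => 1L).reduceByKey(_ + _)
    println(small)
    println(large.take(5).toSeq)
  }
}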
- val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt + val attemptNumber = (context.attemptId % Int.MaxValue).toInt + + val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config) - writer.setup(context.getStageId, context.getPartitionId, attemptNumber) + writer.setup(context.stageId, context.partitionId, attemptNumber) writer.open() try { - var count = 0 + var recordsWritten = 0L while (iter.hasNext) { val record = iter.next() - count += 1 writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef]) + + // Update bytes written metric every few records + maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten) + recordsWritten += 1 } } finally { writer.close() } writer.commit() + bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() } } self.context.runJob(self, writeToFile) writer.commitJob() } + private def initHadoopOutputMetrics(context: TaskContext, config: Configuration) + : (OutputMetrics, Option[() => Long]) = { + val bytesWrittenCallback = Option(config.get("mapreduce.output.fileoutputformat.outputdir")) + .map(new Path(_)) + .flatMap(SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(_, config)) + val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) + if (bytesWrittenCallback.isDefined) { + context.taskMetrics.outputMetrics = Some(outputMetrics) + } + (outputMetrics, bytesWrittenCallback) + } + + private def maybeUpdateOutputMetrics(bytesWrittenCallback: Option[() => Long], + outputMetrics: OutputMetrics, recordsWritten: Long): Unit = { + if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0 + && bytesWrittenCallback.isDefined) { + bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() } + } + } + /** * Return an RDD with the keys of each tuple. 
*/ @@ -1064,3 +1111,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) private[spark] def keyOrdering: Option[Ordering[K]] = Option(ord) } + +private[spark] object PairRDDFunctions { + val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256 +} diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index 66c71bf7e8bb5..87b22de6ae697 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -48,7 +48,7 @@ private[spark] class ParallelCollectionPartition[T: ClassTag]( override def index: Int = slice @throws(classOf[IOException]) - private def writeObject(out: ObjectOutputStream): Unit = { + private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { val sfactory = SparkEnv.get.serializer @@ -67,7 +67,7 @@ private[spark] class ParallelCollectionPartition[T: ClassTag]( } @throws(classOf[IOException]) - private def readObject(in: ObjectInputStream): Unit = { + private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { val sfactory = SparkEnv.get.serializer sfactory match { diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala index 0c2cd7a24783b..92b0641d0fb6e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala @@ -22,6 +22,7 @@ import java.io.{IOException, ObjectOutputStream} import scala.reflect.ClassTag import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext} +import org.apache.spark.util.Utils /** * Class representing partitions of PartitionerAwareUnionRDD, which maintains the list of @@ -38,7 +39,7 @@ class PartitionerAwareUnionRDDPartition( override def hashCode(): Int = idx @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { + private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { // Update the reference to parent partition at the time of task serialization parents = rdds.map(_.partitions(index)).toArray oos.defaultWriteObject() diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index 5d77d37378458..56ac7a69be0d3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -131,7 +131,6 @@ private[spark] class PipedRDD[T: ClassTag]( // Start a thread to feed the process input from our parent's iterator new Thread("stdin writer for " + command) { override def run() { - SparkEnv.set(env) val out = new PrintWriter(proc.getOutputStream) // input the pipe context firstly diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 2aba40d152e3e..3add4a76192ca 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -21,6 +21,7 @@ import java.util.{Properties, Random} import scala.collection.{mutable, Map} import scala.collection.mutable.ArrayBuffer +import scala.language.implicitConversions import scala.reflect.{classTag, ClassTag} import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus @@ -28,6 +29,7 @@ import org.apache.hadoop.io.BytesWritable import 
org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.io.NullWritable import org.apache.hadoop.io.Text +import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.TextOutputFormat import org.apache.spark._ @@ -43,7 +45,8 @@ import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{BoundedPriorityQueue, Utils, CallSite} import org.apache.spark.util.collection.OpenHashMap -import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils} +import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, BernoulliCellSampler, + SamplingUtils} /** * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable, @@ -375,7 +378,8 @@ abstract class RDD[T: ClassTag]( val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => - new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), true, seed) + new PartitionwiseSampledRDD[T, T]( + this, new BernoulliCellSampler[T](x(0), x(1)), true, seed) }.toArray } @@ -927,32 +931,15 @@ abstract class RDD[T: ClassTag]( } /** - * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final - * combine step happens locally on the master, equivalent to running a single reduce task. + * Return the count of each unique value in this RDD as a local map of (value, count) pairs. + * + * Note that this method should only be used if the resulting map is expected to be small, as + * the whole thing is loaded into the driver's memory. + * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which + * returns an RDD[T, Long] instead of a map. */ def countByValue()(implicit ord: Ordering[T] = null): Map[T, Long] = { - if (elementClassTag.runtimeClass.isArray) { - throw new SparkException("countByValue() does not support arrays") - } - // TODO: This should perhaps be distributed by default. - val countPartition = (iter: Iterator[T]) => { - val map = new OpenHashMap[T,Long] - iter.foreach { - t => map.changeValue(t, 1L, _ + 1L) - } - Iterator(map) - }: Iterator[OpenHashMap[T,Long]] - val mergeMaps = (m1: OpenHashMap[T,Long], m2: OpenHashMap[T,Long]) => { - m2.foreach { case (key, value) => - m1.changeValue(key, value, _ + value) - } - m1 - }: OpenHashMap[T,Long] - val myResult = mapPartitions(countPartition).reduce(mergeMaps) - // Convert to a Scala mutable map - val mutableResult = scala.collection.mutable.Map[T,Long]() - myResult.foreach { case (k, v) => mutableResult.put(k, v) } - mutableResult + map(value => (value, null)).countByKey() } /** @@ -1079,15 +1066,17 @@ abstract class RDD[T: ClassTag]( // greater than totalParts because we actually cap it at totalParts in runJob. var numPartsToTry = 1 if (partsScanned > 0) { - // If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise, + // If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise, // interpolate the number of partitions we need to try, but overestimate it by 50%. + // We also cap the estimation in the end. 
if (buf.size == 0) { numPartsToTry = partsScanned * 4 } else { - numPartsToTry = (1.5 * num * partsScanned / buf.size).toInt + // the left side of max is >=1 whenever partsScanned >= 2 + numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1) + numPartsToTry = Math.min(numPartsToTry, partsScanned * 4) } } - numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions val left = num - buf.size val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) @@ -1109,7 +1098,7 @@ abstract class RDD[T: ClassTag]( } /** - * Returns the top K (largest) elements from this RDD as defined by the specified + * Returns the top k (largest) elements from this RDD as defined by the specified * implicit Ordering[T]. This does the opposite of [[takeOrdered]]. For example: * {{{ * sc.parallelize(Seq(10, 4, 2, 12, 3)).top(1) @@ -1119,14 +1108,14 @@ abstract class RDD[T: ClassTag]( * // returns Array(6, 5) * }}} * - * @param num the number of top elements to return + * @param num k, the number of top elements to return * @param ord the implicit ordering for T * @return an array of top elements */ def top(num: Int)(implicit ord: Ordering[T]): Array[T] = takeOrdered(num)(ord.reverse) /** - * Returns the first K (smallest) elements from this RDD as defined by the specified + * Returns the first k (smallest) elements from this RDD as defined by the specified * implicit Ordering[T] and maintains the ordering. This does the opposite of [[top]]. * For example: * {{{ @@ -1137,7 +1126,7 @@ abstract class RDD[T: ClassTag]( * // returns Array(2, 3) * }}} * - * @param num the number of top elements to return + * @param num k, the number of elements to return * @param ord the implicit ordering for T * @return an array of top elements */ @@ -1215,7 +1204,7 @@ abstract class RDD[T: ClassTag]( */ def checkpoint() { if (context.checkpointDir.isEmpty) { - throw new Exception("Checkpoint directory has not been set in the SparkContext") + throw new SparkException("Checkpoint directory has not been set in the SparkContext") } else if (checkpointData.isEmpty) { checkpointData = Some(new RDDCheckpointData(this)) checkpointData.get.markForCheckpoint() @@ -1322,7 +1311,7 @@ abstract class RDD[T: ClassTag]( def debugSelf (rdd: RDD[_]): Seq[String] = { import Utils.bytesToString - val persistence = storageLevel.description + val persistence = if (storageLevel != StorageLevel.NONE) storageLevel.description else "" val storageInfo = rdd.context.getRDDStorageInfo.filter(_.id == rdd.id).map(info => " CachedPartitions: %d; MemorySize: %s; TachyonSize: %s; DiskSize: %s".format( info.numCachedPartitions, bytesToString(info.memSize), @@ -1396,3 +1385,31 @@ abstract class RDD[T: ClassTag]( new JavaRDD(this)(elementClassTag) } } + +object RDD { + + // The following implicit functions were in SparkContext before 1.2 and users had to + // `import SparkContext._` to enable them. Now we move them here to make the compiler find + // them automatically. However, we still keep the old functions in SparkContext for backward + // compatibility and forward to the following functions directly. 
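Editor's note: the take() change above (mirrored in the asynchronous variant earlier in this patch) replaces the old unbounded estimate with one clamped between 1 and four times the partitions already scanned. A standalone sketch of that estimation in plain Scala, not a Spark API:

object TakeScanEstimate {
  // How many more partitions take(num) should scan on the next pass: quadruple
  // if nothing was found yet, otherwise interpolate from the hit rate,
  // overestimate by 50%, and clamp to the range [1, 4 * partsScanned].
  def nextScanSize(num: Int, partsScanned: Int, foundSoFar: Int): Int = {
    if (partsScanned == 0) {
      1                                   // first pass always tries one partition
    } else if (foundSoFar == 0) {
      partsScanned * 4
    } else {
      val interpolated = (1.5 * num * partsScanned / foundSoFar).toInt - partsScanned
      math.min(math.max(1, interpolated), partsScanned * 4)
    }
  }
}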
+ + implicit def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)]) + (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) = { + new PairRDDFunctions(rdd) + } + + implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd) + + implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( + rdd: RDD[(K, V)]) = + new SequenceFileRDDFunctions(rdd) + + implicit def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag]( + rdd: RDD[(K, V)]) = + new OrderedRDDFunctions[K, V, (K, V)](rdd) + + implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd) + + implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) = + new DoubleRDDFunctions(rdd.map(x => num.toDouble(x))) +} diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala index b097c30f8c231..9e8cee5331cf8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala @@ -21,8 +21,7 @@ import java.util.Random import scala.reflect.ClassTag -import cern.jet.random.Poisson -import cern.jet.random.engine.DRand +import org.apache.commons.math3.distribution.PoissonDistribution import org.apache.spark.{Partition, TaskContext} @@ -53,9 +52,11 @@ private[spark] class SampledRDD[T: ClassTag]( if (withReplacement) { // For large datasets, the expected number of occurrences of each element in a sample with // replacement is Poisson(frac). We use that to get a count for each element. - val poisson = new Poisson(frac, new DRand(split.seed)) + val poisson = new PoissonDistribution(frac) + poisson.reseedRandomGenerator(split.seed) + firstParent[T].iterator(split.prev, context).flatMap { element => - val count = poisson.nextInt() + val count = poisson.sample() if (count == 0) { Iterator.empty // Avoid object allocation when we return 0 items, which is quite often } else { diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index 0c97eb0aaa51f..aece683ff3199 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -24,6 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.util.Utils /** * Partition for UnionRDD. 
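Editor's note: because the conversions above now live on the RDD companion object, Scala's implicit scope finds them automatically. A hedged usage sketch (method and object names are illustrative):

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

object ImplicitResolutionSketch {
  // Pair-RDD operations resolve through RDD.rddToPairRDDFunctions on the
  // companion object, so no `import SparkContext._` is needed (it still works).
  def grouped(sc: SparkContext): RDD[(String, Iterable[Int])] = {
    val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))
    pairs.groupByKey()
  }
}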
@@ -48,7 +49,7 @@ private[spark] class UnionPartition[T: ClassTag]( override val index: Int = idx @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { + private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { // Update the reference to parent split at the time of task serialization parentPartition = rdd.partitions(parentRddPartitionIndex) oos.defaultWriteObject() diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index f3d30f6c9b32f..996f2cd3f34a3 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -22,6 +22,7 @@ import java.io.{IOException, ObjectOutputStream} import scala.reflect.ClassTag import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext} +import org.apache.spark.util.Utils private[spark] class ZippedPartitionsPartition( idx: Int, @@ -34,7 +35,7 @@ private[spark] class ZippedPartitionsPartition( def partitions = partitionValues @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { + private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { // Update the reference to parent split at the time of task serialization partitionValues = rdds.map(rdd => rdd.partitions(idx)) oos.defaultWriteObject() diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala index e2c301603b4a5..8c43a559409f2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala @@ -39,21 +39,24 @@ class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long) private[spark] class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, Long)](prev) { - override def getPartitions: Array[Partition] = { + /** The start index of each partition. */ + @transient private val startIndices: Array[Long] = { val n = prev.partitions.size - val startIndices: Array[Long] = - if (n == 0) { - Array[Long]() - } else if (n == 1) { - Array(0L) - } else { - prev.context.runJob( - prev, - Utils.getIteratorSize _, - 0 until n - 1, // do not need to count the last partition - false - ).scanLeft(0L)(_ + _) - } + if (n == 0) { + Array[Long]() + } else if (n == 1) { + Array(0L) + } else { + prev.context.runJob( + prev, + Utils.getIteratorSize _, + 0 until n - 1, // do not need to count the last partition + allowLocal = false + ).scanLeft(0L)(_ + _) + } + } + + override def getPartitions: Array[Partition] = { firstParent[T].partitions.map(x => new ZippedWithIndexRDDPartition(x, startIndices(x.index))) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8135cdbb4c31f..cb8ccfbdbdcbb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -124,6 +124,9 @@ class DAGScheduler( /** If enabled, we may run certain actions like take() and first() locally. */ private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false) + /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. 
*/ + private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false) + private def initializeEventProcessActor() { // blocking the thread until supervisor is started, which ensures eventProcessActor is // not null before any job is submitted @@ -446,7 +449,6 @@ class DAGScheduler( } // data structures based on StageId stageIdToStage -= stageId - logDebug("After removal of stage %d, remaining stages = %d" .format(stageId, stageIdToStage.size)) } @@ -630,18 +632,17 @@ class DAGScheduler( protected def runLocallyWithinThread(job: ActiveJob) { var jobResult: JobResult = JobSucceeded try { - SparkEnv.set(env) val rdd = job.finalStage.rdd val split = rdd.partitions(job.partitions(0)) val taskContext = - new TaskContext(job.finalStage.id, job.partitions(0), 0, true) - TaskContext.setTaskContext(taskContext) + new TaskContextImpl(job.finalStage.id, job.partitions(0), 0, true) + TaskContextHelper.setTaskContext(taskContext) try { val result = job.func(taskContext, rdd.iterator(split, taskContext)) job.listener.taskSucceeded(0, result) } finally { taskContext.markTaskCompleted() - TaskContext.unset() + TaskContextHelper.unset() } } catch { case e: Exception => @@ -749,14 +750,15 @@ class DAGScheduler( localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1 if (shouldRunLocally) { // Compute very short actions like first() or take() with no parent stages locally. - listenerBus.post(SparkListenerJobStart(job.jobId, Array[Int](), properties)) + listenerBus.post(SparkListenerJobStart(job.jobId, Seq.empty, properties)) runLocally(job) } else { jobIdToActiveJob(jobId) = job activeJobs += job finalStage.resultOfJob = Some(job) - listenerBus.post(SparkListenerJobStart(job.jobId, jobIdToStageIds(jobId).toArray, - properties)) + val stageIds = jobIdToStageIds(jobId).toArray + val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo)) + listenerBus.post(SparkListenerJobStart(job.jobId, stageInfos, properties)) submitStage(finalStage) } } @@ -899,6 +901,34 @@ class DAGScheduler( } } + /** Merge updates from a task to our local accumulator values */ + private def updateAccumulators(event: CompletionEvent): Unit = { + val task = event.task + val stage = stageIdToStage(task.stageId) + if (event.accumUpdates != null) { + try { + Accumulators.add(event.accumUpdates) + event.accumUpdates.foreach { case (id, partialValue) => + val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]] + // To avoid UI cruft, ignore cases where value wasn't updated + if (acc.name.isDefined && partialValue != acc.zero) { + val name = acc.name.get + val stringPartialValue = Accumulators.stringifyPartialValue(partialValue) + val stringValue = Accumulators.stringifyValue(acc.value) + stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue) + event.taskInfo.accumulables += + AccumulableInfo(id, name, Some(stringPartialValue), stringValue) + } + } + } catch { + // If we see an exception during accumulator update, just log the + // error and move on. + case e: Exception => + logError(s"Failed to update accumulators for $task", e) + } + } + } + /** * Responds to a task finishing. This is called inside the event loop so it assumes that it can * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside. 
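Editor's note: the DAGScheduler change above adds a test-only escape hatch: with spark.test.noStageRetry set, a FetchFailed aborts the stage instead of silently resubmitting it. A hedged example of enabling it for a local test run (master and app name are placeholders):

import org.apache.spark.{SparkConf, SparkContext}

object NoStageRetryDemo {
  // Test-only setting from this patch: surface fetch failures rather than
  // retrying the failed stage.
  def makeContext(): SparkContext = {
    val conf = new SparkConf()
      .setMaster("local[2]")                 // placeholder local master
      .setAppName("no-stage-retry-demo")     // placeholder app name
      .set("spark.test.noStageRetry", "true")
    new SparkContext(conf)
  }
}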
@@ -939,27 +969,6 @@ class DAGScheduler( } event.reason match { case Success => - if (event.accumUpdates != null) { - try { - Accumulators.add(event.accumUpdates) - event.accumUpdates.foreach { case (id, partialValue) => - val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]] - // To avoid UI cruft, ignore cases where value wasn't updated - if (acc.name.isDefined && partialValue != acc.zero) { - val name = acc.name.get - val stringPartialValue = Accumulators.stringifyPartialValue(partialValue) - val stringValue = Accumulators.stringifyValue(acc.value) - stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue) - event.taskInfo.accumulables += - AccumulableInfo(id, name, Some(stringPartialValue), stringValue) - } - } - } catch { - // If we see an exception during accumulator update, just log the error and move on. - case e: Exception => - logError(s"Failed to update accumulators for $task", e) - } - } listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType, event.reason, event.taskInfo, event.taskMetrics)) stage.pendingTasks -= task @@ -968,6 +977,7 @@ class DAGScheduler( stage.resultOfJob match { case Some(job) => if (!job.finished(rt.outputId)) { + updateAccumulators(event) job.finished(rt.outputId) = true job.numFinished += 1 // If the whole job has finished, remove it @@ -992,6 +1002,7 @@ class DAGScheduler( } case smt: ShuffleMapTask => + updateAccumulators(event) val status = event.result.asInstanceOf[MapStatus] val execId = status.location.executorId logDebug("ShuffleMapTask finished on " + execId) @@ -1051,7 +1062,7 @@ class DAGScheduler( logInfo("Resubmitted " + task + ", so marking it as still running") stage.pendingTasks += task - case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => + case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) => val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleToMapStage(shuffleId) @@ -1061,11 +1072,13 @@ class DAGScheduler( if (runningStages.contains(failedStage)) { logInfo(s"Marking $failedStage (${failedStage.name}) as failed " + s"due to a fetch failure from $mapStage (${mapStage.name})") - markStageAsFinished(failedStage, Some("Fetch failure")) + markStageAsFinished(failedStage, Some(failureMessage)) runningStages -= failedStage } - if (failedStages.isEmpty && eventProcessActor != null) { + if (disallowStageRetryForTest) { + abortStage(failedStage, "Fetch failure will not retry stage due to testing config") + } else if (failedStages.isEmpty && eventProcessActor != null) { // Don't schedule an event to resubmit failed stages if failed isn't empty, because // in that case the event will already have been scheduled. eventProcessActor may be // null during unit tests. 
@@ -1078,7 +1091,6 @@ class DAGScheduler( } failedStages += failedStage failedStages += mapStage - // Mark the map whose fetch failed as broken in the map stage if (mapId != -1) { mapStage.removeOutputLoc(mapId, bmAddress) @@ -1087,10 +1099,10 @@ class DAGScheduler( // TODO: mark the executor as failed only if there were lots of fetch failures on it if (bmAddress != null) { - handleExecutorLost(bmAddress.executorId, Some(task.epoch)) + handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch)) } - case ExceptionFailure(className, description, stackTrace, metrics) => + case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) => // Do nothing here, left up to the TaskScheduler to decide how to handle user failures case TaskResultLost => @@ -1107,25 +1119,35 @@ class DAGScheduler( * Responds to an executor being lost. This is called inside the event loop, so it assumes it can * modify the scheduler's internal state. Use executorLost() to post a loss event from outside. * + * We will also assume that we've lost all shuffle blocks associated with the executor if the + * executor serves its own blocks (i.e., we're not using external shuffle) OR a FetchFailed + * occurred, in which case we presume all shuffle data related to this executor to be lost. + * * Optionally the epoch during which the failure was caught can be passed to avoid allowing * stray fetch failures from possibly retriggering the detection of a node as lost. */ - private[scheduler] def handleExecutorLost(execId: String, maybeEpoch: Option[Long] = None) { + private[scheduler] def handleExecutorLost( + execId: String, + fetchFailed: Boolean, + maybeEpoch: Option[Long] = None) { val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch) if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) { failedEpoch(execId) = currentEpoch logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch)) blockManagerMaster.removeExecutor(execId) - // TODO: This will be really slow if we keep accumulating shuffle map stages - for ((shuffleId, stage) <- shuffleToMapStage) { - stage.removeOutputsOnExecutor(execId) - val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray - mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true) - } - if (shuffleToMapStage.isEmpty) { - mapOutputTracker.incrementEpoch() + + if (!env.blockManager.externalShuffleServiceEnabled || fetchFailed) { + // TODO: This will be really slow if we keep accumulating shuffle map stages + for ((shuffleId, stage) <- shuffleToMapStage) { + stage.removeOutputsOnExecutor(execId) + val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray + mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true) + } + if (shuffleToMapStage.isEmpty) { + mapOutputTracker.incrementEpoch() + } + clearCacheLocs() } - clearCacheLocs() } else { logDebug("Additional executor lost message for " + execId + "(epoch " + currentEpoch + ")") @@ -1383,7 +1405,7 @@ private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGSchedule dagScheduler.handleExecutorAdded(execId, host) case ExecutorLost(execId) => - dagScheduler.handleExecutorLost(execId) + dagScheduler.handleExecutorLost(execId, fetchFailed = false) case BeginEvent(task, taskInfo) => dagScheduler.handleBeginEvent(task, taskInfo) diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala 
b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 100c9ba9b7809..597dbc884913c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -142,7 +142,7 @@ private[spark] object EventLoggingListener extends Logging { val SPARK_VERSION_PREFIX = "SPARK_VERSION_" val COMPRESSION_CODEC_PREFIX = "COMPRESSION_CODEC_" val APPLICATION_COMPLETE = "APPLICATION_COMPLETE" - val LOG_FILE_PERMISSIONS = FsPermission.createImmutable(Integer.parseInt("770", 8).toShort) + val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort) // A cache for compression codecs to avoid creating the same codec many times private val codecMap = new mutable.HashMap[String, CompressionCodec] diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 54904bffdf10b..3bb54855bae44 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -158,6 +158,11 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener " INPUT_BYTES=" + metrics.bytesRead case None => "" } + val outputMetrics = taskMetrics.outputMetrics match { + case Some(metrics) => + " OUTPUT_BYTES=" + metrics.bytesWritten + case None => "" + } val shuffleReadMetrics = taskMetrics.shuffleReadMetrics match { case Some(metrics) => " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched + @@ -173,7 +178,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener " SHUFFLE_WRITE_TIME=" + metrics.shuffleWriteTime case None => "" } - stageLogInfo(stageId, status + info + executorRunTime + gcTime + inputMetrics + + stageLogInfo(stageId, status + info + executorRunTime + gcTime + inputMetrics + outputMetrics + shuffleReadMetrics + writeMetrics) } @@ -215,7 +220,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId + " STAGE_ID=" + taskEnd.stageId stageLogInfo(taskEnd.stageId, taskStatus) - case FetchFailed(bmAddress, shuffleId, mapId, reduceId) => + case FetchFailed(bmAddress, shuffleId, mapId, reduceId, message) => taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" + taskEnd.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" + mapId + " REDUCE_ID=" + reduceId diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index e25096ea92d70..01d5943d777f3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -19,7 +19,10 @@ package org.apache.spark.scheduler import java.io.{Externalizable, ObjectInput, ObjectOutput} +import org.roaringbitmap.RoaringBitmap + import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.Utils /** * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the @@ -29,7 +32,12 @@ private[spark] sealed trait MapStatus { /** Location where this task was run. */ def location: BlockManagerId - /** Estimated size for the reduce block, in bytes. */ + /** + * Estimated size for the reduce block, in bytes. + * + * If a block is non-empty, then this method MUST return a non-zero size. 
This invariant is + * necessary for correctness, since block fetchers are allowed to skip zero-size blocks. + */ def getSizeForBlock(reduceId: Int): Long } @@ -38,7 +46,7 @@ private[spark] object MapStatus { def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { if (uncompressedSizes.length > 2000) { - new HighlyCompressedMapStatus(loc, uncompressedSizes) + HighlyCompressedMapStatus(loc, uncompressedSizes) } else { new CompressedMapStatus(loc, uncompressedSizes) } @@ -98,13 +106,13 @@ private[spark] class CompressedMapStatus( MapStatus.decompressSize(compressedSizes(reduceId)) } - override def writeExternal(out: ObjectOutput): Unit = { + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) out.writeInt(compressedSizes.length) out.write(compressedSizes) } - override def readExternal(in: ObjectInput): Unit = { + override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { loc = BlockManagerId(in) val len = in.readInt() compressedSizes = new Array[Byte](len) @@ -112,35 +120,80 @@ private[spark] class CompressedMapStatus( } } - /** - * A [[MapStatus]] implementation that only stores the average size of the blocks. + * A [[MapStatus]] implementation that only stores the average size of non-empty blocks, + * plus a bitmap for tracking which blocks are non-empty. During serialization, this bitmap + * is compressed. * - * @param loc location where the task is being executed. - * @param avgSize average size of all the blocks + * @param loc location where the task is being executed + * @param numNonEmptyBlocks the number of non-empty blocks + * @param emptyBlocks a bitmap tracking which blocks are empty + * @param avgSize average size of the non-empty blocks */ -private[spark] class HighlyCompressedMapStatus( +private[spark] class HighlyCompressedMapStatus private ( private[this] var loc: BlockManagerId, + private[this] var numNonEmptyBlocks: Int, + private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long) extends MapStatus with Externalizable { - def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { - this(loc, uncompressedSizes.sum / uncompressedSizes.length) - } + // loc could be null when the default constructor is called during deserialization + require(loc == null || avgSize > 0 || numNonEmptyBlocks == 0, + "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, 0L) // For deserialization only + protected def this() = this(null, -1, null, -1) // For deserialization only override def location: BlockManagerId = loc - override def getSizeForBlock(reduceId: Int): Long = avgSize + override def getSizeForBlock(reduceId: Int): Long = { + if (emptyBlocks.contains(reduceId)) { + 0 + } else { + avgSize + } + } - override def writeExternal(out: ObjectOutput): Unit = { + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) + emptyBlocks.writeExternal(out) out.writeLong(avgSize) } - override def readExternal(in: ObjectInput): Unit = { + override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { loc = BlockManagerId(in) + emptyBlocks = new RoaringBitmap() + emptyBlocks.readExternal(in) avgSize = in.readLong() } } + +private[spark] object HighlyCompressedMapStatus { + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = { + // We must keep track of which blocks are empty so that we don't report a zero-sized + // block as being non-empty 
(or vice-versa) when using the average block size. + var i = 0 + var numNonEmptyBlocks: Int = 0 + var totalSize: Long = 0 + // From a compression standpoint, it shouldn't matter whether we track empty or non-empty + // blocks. From a performance standpoint, we benefit from tracking empty blocks because + // we expect that there will be far fewer of them, so we will perform fewer bitmap insertions. + val emptyBlocks = new RoaringBitmap() + val totalNumBlocks = uncompressedSizes.length + while (i < totalNumBlocks) { + var size = uncompressedSizes(i) + if (size > 0) { + numNonEmptyBlocks += 1 + totalSize += size + } else { + emptyBlocks.add(i) + } + i += 1 + } + val avgSize = if (numNonEmptyBlocks > 0) { + totalSize / numNonEmptyBlocks + } else { + 0 + } + new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize) + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 86afe3bd5265f..b62b0c1312693 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -56,8 +56,15 @@ case class SparkListenerTaskEnd( extends SparkListenerEvent @DeveloperApi -case class SparkListenerJobStart(jobId: Int, stageIds: Seq[Int], properties: Properties = null) - extends SparkListenerEvent +case class SparkListenerJobStart( + jobId: Int, + stageInfos: Seq[StageInfo], + properties: Properties = null) + extends SparkListenerEvent { + // Note: this is here for backwards-compatibility with older versions of this event which + // only stored stageIds and not StageInfos: + val stageIds: Seq[Int] = stageInfos.map(_.stageId) +} @DeveloperApi case class SparkListenerJobEnd(jobId: Int, jobResult: JobResult) extends SparkListenerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index 071568cdfb429..cc13f57a49b89 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -102,6 +102,11 @@ private[spark] class Stage( } } + /** + * Removes all shuffle outputs associated with this executor. Note that this will also remove + * outputs which are served by an external shuffle server (if one exists), as they are still + * registered with this execId. 
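The HighlyCompressedMapStatus rework above is the heart of this hunk: empty reduce partitions are recorded in a RoaringBitmap, and every non-empty partition is reported at the average non-empty size, which preserves the new invariant that a non-empty block never reports a zero size. A self-contained sketch of the scheme (the idea in isolation, not Spark's class):

    import org.roaringbitmap.RoaringBitmap

    object AverageSizeSketch {
      /** Build (emptyBlocks, avgSize) from the per-reduce uncompressed sizes. */
      def build(sizes: Array[Long]): (RoaringBitmap, Long) = {
        val emptyBlocks = new RoaringBitmap()
        var nonEmpty = 0
        var total = 0L
        var i = 0
        while (i < sizes.length) {
          if (sizes(i) > 0) { nonEmpty += 1; total += sizes(i) } else emptyBlocks.add(i)
          i += 1
        }
        (emptyBlocks, if (nonEmpty > 0) total / nonEmpty else 0L)
      }

      /** Exact zero for known-empty blocks, the average for everything else. */
      def sizeFor(emptyBlocks: RoaringBitmap, avgSize: Long, reduceId: Int): Long =
        if (emptyBlocks.contains(reduceId)) 0L else avgSize
    }

For example, sizes of (0, 100, 300) yield an average of 200 for blocks 1 and 2 and an exact 0 for block 0, so a fetcher that skips zero-size blocks still behaves correctly.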
+ */ def removeOutputsOnExecutor(execId: String) { var becameUnavailable = false for (partition <- 0 until numPartitions) { @@ -131,4 +136,9 @@ private[spark] class Stage( override def toString = "Stage " + id override def hashCode(): Int = id + + override def equals(other: Any): Boolean = other match { + case stage: Stage => stage != null && stage.id == id + case _ => false + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index c6e47c84a0cb2..2552d03d18d06 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap -import org.apache.spark.TaskContext +import org.apache.spark.{TaskContextHelper, TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.ByteBufferInputStream @@ -45,8 +45,8 @@ import org.apache.spark.util.Utils private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { final def run(attemptId: Long): T = { - context = new TaskContext(stageId, partitionId, attemptId, false) - TaskContext.setTaskContext(context) + context = new TaskContextImpl(stageId, partitionId, attemptId, false) + TaskContextHelper.setTaskContext(context) context.taskMetrics.hostname = Utils.localHostName() taskThread = Thread.currentThread() if (_killed) { @@ -56,7 +56,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex runTask(context) } finally { context.markTaskCompleted() - TaskContext.unset() + TaskContextHelper.unset() } } @@ -70,7 +70,7 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex var metrics: Option[TaskMetrics] = None // Task context, to be initialized in run(). - @transient protected var context: TaskContext = _ + @transient protected var context: TaskContextImpl = _ // The actual Thread on which the task is running, if any. Initialized in run(). @volatile @transient private var taskThread: Thread = _ diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala index d49d8fb887007..1f114a0207f7b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala @@ -31,8 +31,8 @@ import org.apache.spark.util.Utils private[spark] sealed trait TaskResult[T] /** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */ -private[spark] -case class IndirectTaskResult[T](blockId: BlockId) extends TaskResult[T] with Serializable +private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int) + extends TaskResult[T] with Serializable /** A TaskResult that contains the task's return value and accumulator updates. 
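The Task.run change above swaps the concrete TaskContext for a TaskContextImpl and routes registration through TaskContextHelper.setTaskContext / TaskContextHelper.unset; the helper itself is not part of this diff. Conceptually it is a thread-local holder along the following lines (a generic sketch under that assumption, not Spark's implementation):

    class ThreadLocalHolder[T] {
      private val local = new ThreadLocal[T]
      def set(value: T): Unit = local.set(value)
      def get: Option[T] = Option(local.get)
      def unset(): Unit = local.remove()
    }

As in run() above, the important discipline is to set the context before runTask and always clear it in a finally block, so a reused executor thread never observes a stale context.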
*/ private[spark] @@ -42,7 +42,7 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long def this() = this(null.asInstanceOf[ByteBuffer], null, null) - override def writeExternal(out: ObjectOutput) { + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { out.writeInt(valueBytes.remaining); Utils.writeByteBuffer(valueBytes, out) @@ -55,7 +55,7 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long out.writeObject(metrics) } - override def readExternal(in: ObjectInput) { + override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { val blen = in.readInt() val byteVal = new Array[Byte](blen) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 3f345ceeaaf7a..819b51e12ad8c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -47,9 +47,18 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul getTaskResultExecutor.execute(new Runnable { override def run(): Unit = Utils.logUncaughtExceptions { try { - val result = serializer.get().deserialize[TaskResult[_]](serializedData) match { - case directResult: DirectTaskResult[_] => directResult - case IndirectTaskResult(blockId) => + val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match { + case directResult: DirectTaskResult[_] => + if (!taskSetManager.canFetchMoreResults(serializedData.limit())) { + return + } + (directResult, serializedData.limit()) + case IndirectTaskResult(blockId, size) => + if (!taskSetManager.canFetchMoreResults(size)) { + // dropped by executor if size is larger than maxResultSize + sparkEnv.blockManager.master.removeBlock(blockId) + return + } logDebug("Fetching indirect task result for TID %s".format(tid)) scheduler.handleTaskGettingResult(taskSetManager, tid) val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId) @@ -64,9 +73,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]]( serializedTaskResult.get) sparkEnv.blockManager.master.removeBlock(blockId) - deserializedResult + (deserializedResult, size) } - result.metrics.resultSize = serializedData.limit() + + result.metrics.resultSize = size scheduler.handleSuccessfulTask(taskSetManager, tid, result) } catch { case cnf: ClassNotFoundException => @@ -93,7 +103,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul } } catch { case cnd: ClassNotFoundException => - // Log an error but keep going here -- the task failed, so not catastropic if we can't + // Log an error but keep going here -- the task failed, so not catastrophic if we can't // deserialize the reason. val loader = Utils.getContextOrSparkClassLoader logError( diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index a129a434c9a1a..f095915352b17 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -23,7 +23,7 @@ import org.apache.spark.storage.BlockManagerId /** * Low-level task scheduler interface, currently implemented exclusively by TaskSchedulerImpl. 
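Several Externalizable implementations in this diff (CompressedMapStatus and HighlyCompressedMapStatus above, DirectTaskResult here, JavaSerializer further down) now wrap writeExternal/readExternal in Utils.tryOrIOException. The helper lives in Utils and is not shown in this patch; its apparent purpose is to turn arbitrary failures into IOExceptions so that they surface as ordinary stream errors rather than escaping the Externalizable contract. A minimal sketch under that assumption:

    import java.io.IOException
    import scala.util.control.NonFatal

    object IOWrapping {
      def tryOrIOExceptionSketch[T](block: => T): T = {
        try {
          block
        } catch {
          case e: IOException => throw e                  // already the expected type
          case NonFatal(e)    => throw new IOException(e) // wrap everything else
        }
      }
    }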
- * This interface allows plugging in different task schedulers. Each TaskScheduler schedulers tasks + * This interface allows plugging in different task schedulers. Each TaskScheduler schedules tasks * for a single SparkContext. These schedulers get sets of tasks submitted to them from the * DAGScheduler for each stage, and are responsible for sending the tasks to the cluster, running * them, retrying if there are failures, and mitigating stragglers. They return events to the @@ -41,7 +41,7 @@ private[spark] trait TaskScheduler { // Invoked after system has successfully initialized (typically in spark context). // Yarn uses this to bootstrap allocation of resources based on preferred locations, - // wait for slave registerations, etc. + // wait for slave registrations, etc. def postStartHook() { } // Disconnect from the cluster. diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 4dc550413c13c..cd3c015321e85 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -34,7 +34,6 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.util.Utils import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId -import akka.actor.Props /** * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend. @@ -216,13 +215,12 @@ private[spark] class TaskSchedulerImpl( * that tasks are balanced across the cluster. */ def resourceOffers(offers: Seq[WorkerOffer]): Seq[Seq[TaskDescription]] = synchronized { - SparkEnv.set(sc.env) - // Mark each slave as alive and remember its hostname // Also track if new executor is added var newExecAvail = false for (o <- offers) { executorIdToHost(o.executorId) = o.host + activeExecutorIds += o.executorId if (!executorsByHost.contains(o.host)) { executorsByHost(o.host) = new HashSet[String]() executorAdded(o.executorId, o.host) @@ -263,7 +261,6 @@ private[spark] class TaskSchedulerImpl( val tid = task.taskId taskIdToTaskSetId(tid) = taskSet.taskSet.id taskIdToExecutorId(tid) = execId - activeExecutorIds += execId executorsByHost(host) += execId availableCpus(i) -= CPUS_PER_TASK assert(availableCpus(i) >= 0) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index a6c23fc85a1b0..cabdc655f89bf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -23,13 +23,12 @@ import java.util.Arrays import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet -import scala.math.max -import scala.math.min +import scala.math.{min, max} import org.apache.spark._ -import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.TaskMetrics -import org.apache.spark.util.{Clock, SystemClock} +import org.apache.spark.TaskState.TaskState +import org.apache.spark.util.{Clock, SystemClock, Utils} /** * Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. 
This class keeps track of @@ -68,6 +67,9 @@ private[spark] class TaskSetManager( val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75) val SPECULATION_MULTIPLIER = conf.getDouble("spark.speculation.multiplier", 1.5) + // Limit of bytes for total size of results (default is 1GB) + val maxResultSize = Utils.getMaxResultSize(conf) + // Serializer for closures and tasks. val env = SparkEnv.get val ser = env.closureSerializer.newInstance() @@ -89,6 +91,8 @@ private[spark] class TaskSetManager( var stageId = taskSet.stageId var name = "TaskSet_" + taskSet.stageId.toString var parent: Pool = null + var totalResultSize = 0L + var calculatedTasks = 0 val runningTasksSet = new HashSet[Long] override def runningTasks = runningTasksSet.size @@ -515,12 +519,33 @@ private[spark] class TaskSetManager( index } + /** + * Marks the task as getting result and notifies the DAG Scheduler + */ def handleTaskGettingResult(tid: Long) = { val info = taskInfos(tid) info.markGettingResult() sched.dagScheduler.taskGettingResult(info) } + /** + * Check whether has enough quota to fetch the result with `size` bytes + */ + def canFetchMoreResults(size: Long): Boolean = synchronized { + totalResultSize += size + calculatedTasks += 1 + if (maxResultSize > 0 && totalResultSize > maxResultSize) { + val msg = s"Total size of serialized results of ${calculatedTasks} tasks " + + s"(${Utils.bytesToString(totalResultSize)}) is bigger than spark.driver.maxResultSize " + + s"(${Utils.bytesToString(maxResultSize)})" + logError(msg) + abort(msg) + false + } else { + true + } + } + /** * Marks the task as successful and notifies the DAGScheduler that a task has ended. */ @@ -687,10 +712,11 @@ private[spark] class TaskSetManager( addPendingTask(index, readding=true) } - // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage. + // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage, + // and we are not using an external shuffle server which could serve the shuffle outputs. // The reason is the next stage wouldn't be able to fetch the data from this dead executor // so we would need to rerun these tasks on other executors. 
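The TaskSetManager bookkeeping above is small but easy to get wrong, so here is the accounting in isolation (a sketch that mirrors canFetchMoreResults, not a replacement for it): each fetched result adds its serialized size to a running total, a non-positive limit disables the check, and the first result that pushes the total past spark.driver.maxResultSize aborts the task set.

    class ResultSizeQuota(maxResultSize: Long) {
      private var totalResultSize = 0L
      private var calculatedTasks = 0

      /** Returns false once the accumulated result size exceeds the limit. */
      def tryAdd(size: Long): Boolean = synchronized {
        totalResultSize += size
        calculatedTasks += 1
        maxResultSize <= 0 || totalResultSize <= maxResultSize
      }
    }

With the 1 GB default mentioned in the comment above, results keep passing until the cumulative total crosses the limit, after which tryAdd returns false and the caller aborts the task set rather than buffering more data on the driver.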
- if (tasks(0).isInstanceOf[ShuffleMapTask]) { + if (tasks(0).isInstanceOf[ShuffleMapTask] && !env.blockManager.externalShuffleServiceEnabled) { for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index if (successful(index)) { @@ -706,7 +732,7 @@ private[spark] class TaskSetManager( } // Also re-enqueue any tasks that were running on the node for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { - handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure) + handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(execId)) } // recalculate valid locality levels and waits when executor is lost recomputeLocality() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index fb8160abc59db..1da6fe976da5b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -66,7 +66,19 @@ private[spark] object CoarseGrainedClusterMessages { case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage - case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase :String) + // Exchanged between the driver and the AM in Yarn client mode + case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase: String) extends CoarseGrainedClusterMessage + // Messages exchanged between the driver and the cluster manager for executor allocation + // In Yarn mode, these are exchanged between the driver and the AM + + case object RegisterClusterManager extends CoarseGrainedClusterMessage + + // Request executors by specifying the new total number of executors desired + // This includes executors already pending or running + case class RequestExecutors(requestedTotal: Int) extends CoarseGrainedClusterMessage + + case class KillExecutors(executorIds: Seq[String]) extends CoarseGrainedClusterMessage + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 59aed6b72fe42..88b196ac64368 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -31,7 +31,6 @@ import org.apache.spark.{SparkEnv, Logging, SparkException, TaskState} import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils} -import org.apache.spark.ui.JettyUtils /** * A scheduler backend that waits for coarse grained executors to connect to it through Akka. @@ -42,11 +41,12 @@ import org.apache.spark.ui.JettyUtils * (spark.deploy.*). 
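The messages added to CoarseGrainedClusterMessages above define the executor-allocation protocol between the driver and the cluster manager (the ApplicationMaster in Yarn mode): register once, then exchange a desired total and explicit kill lists. As a rough illustration of the receiving side, a handler actor only needs to answer those two requests; the sketch below is hypothetical (the acknowledgement logic is a placeholder) and assumes the message classes defined above are in scope:

    import akka.actor.Actor

    class AllocationHandlerSketch extends Actor {
      override def receive = {
        case RequestExecutors(requestedTotal) =>
          // Reconcile the requested total against executors already running or pending.
          sender ! acknowledge()
        case KillExecutors(executorIds) =>
          // Ask the cluster manager to release the named executors.
          sender ! acknowledge()
      }

      // Placeholder: a real handler would call into the cluster manager here.
      private def acknowledge(): Boolean = true
    }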
*/ private[spark] -class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: ActorSystem) +class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSystem: ActorSystem) extends SchedulerBackend with Logging { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed var totalCoreCount = new AtomicInteger(0) + // Total number of executors that are currently registered var totalRegisteredExecutors = new AtomicInteger(0) val conf = scheduler.sc.conf private val timeout = AkkaUtils.askTimeout(conf) @@ -61,10 +61,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A conf.getInt("spark.scheduler.maxRegisteredResourcesWaitingTime", 30000) val createTime = System.currentTimeMillis() + private val executorDataMap = new HashMap[String, ExecutorData] + + // Number of executors requested from the cluster manager that have not registered yet + private var numPendingExecutors = 0 + + // Executors we have requested the cluster manager to kill that have not died yet + private val executorsPendingToRemove = new HashSet[String] + class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive { override protected def log = CoarseGrainedSchedulerBackend.this.log private val addressToExecutorId = new HashMap[Address, String] - private val executorDataMap = new HashMap[String, ExecutorData] override def preStart() { // Listen for remote client disconnection events, since they don't go through Akka's watch() @@ -84,12 +91,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A } else { logInfo("Registered executor: " + sender + " with ID " + executorId) sender ! RegisteredExecutor - executorDataMap.put(executorId, new ExecutorData(sender, sender.path.address, - Utils.parseHostPort(hostPort)._1, cores, cores)) addressToExecutorId(sender.path.address) = executorId totalCoreCount.addAndGet(cores) totalRegisteredExecutors.addAndGet(1) + val (host, _) = Utils.parseHostPort(hostPort) + val data = new ExecutorData(sender, sender.path.address, host, cores, cores) + // This must be synchronized because variables mutated + // in this block are read when requesting executors + CoarseGrainedSchedulerBackend.this.synchronized { + executorDataMap.put(executorId, data) + if (numPendingExecutors > 0) { + numPendingExecutors -= 1 + logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") + } + } makeOffers() } @@ -111,7 +127,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A makeOffers() case KillTask(taskId, executorId, interruptThread) => - executorDataMap(executorId).executorActor ! KillTask(taskId, executorId, interruptThread) + executorDataMap.get(executorId) match { + case Some(executorInfo) => + executorInfo.executorActor ! KillTask(taskId, executorId, interruptThread) + case None => + // Ignoring the task kill since the executor is not registered. + logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.") + } case StopDriver => sender ! true @@ -128,10 +150,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A removeExecutor(executorId, reason) sender ! true - case AddWebUIFilter(filterName, filterParams, proxyBase) => - addWebUIFilter(filterName, filterParams, proxyBase) - sender ! 
true - case DisassociatedEvent(_, address, _) => addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disassociated")) @@ -183,13 +201,19 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A } // Remove a disconnected slave from the cluster - def removeExecutor(executorId: String, reason: String) { + def removeExecutor(executorId: String, reason: String): Unit = { executorDataMap.get(executorId) match { case Some(executorInfo) => - executorDataMap -= executorId + // This must be synchronized because variables mutated + // in this block are read when requesting executors + CoarseGrainedSchedulerBackend.this.synchronized { + executorDataMap -= executorId + executorsPendingToRemove -= executorId + } totalCoreCount.addAndGet(-executorInfo.totalCores) + totalRegisteredExecutors.addAndGet(-1) scheduler.executorLost(executorId, SlaveLost(reason)) - case None => logError(s"Asked to remove non existant executor $executorId") + case None => logError(s"Asked to remove non-existent executor $executorId") } } } @@ -274,21 +298,62 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A false } - // Add filters to the SparkUI - def addWebUIFilter(filterName: String, filterParams: Map[String, String], proxyBase: String) { - if (proxyBase != null && proxyBase.nonEmpty) { - System.setProperty("spark.ui.proxyBase", proxyBase) - } + /** + * Return the number of executors currently registered with this backend. + */ + def numExistingExecutors: Int = executorDataMap.size + + /** + * Request an additional number of executors from the cluster manager. + * Return whether the request is acknowledged. + */ + final def requestExecutors(numAdditionalExecutors: Int): Boolean = synchronized { + logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager") + logDebug(s"Number of pending executors is now $numPendingExecutors") + numPendingExecutors += numAdditionalExecutors + // Account for executors pending to be added or removed + val newTotal = numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size + doRequestTotalExecutors(newTotal) + } - val hasFilter = (filterName != null && filterName.nonEmpty && - filterParams != null && filterParams.nonEmpty) - if (hasFilter) { - logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase") - conf.set("spark.ui.filters", filterName) - filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) } - scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) } + /** + * Request executors from the cluster manager by specifying the total number desired, + * including existing pending and running executors. + * + * The semantics here guarantee that we do not over-allocate executors for this application, + * since a later request overrides the value of any prior request. The alternative interface + * of requesting a delta of executors risks double counting new executors when there are + * insufficient resources to satisfy the first request. We make the assumption here that the + * cluster manager will eventually fulfill all requests when resources free up. + * + * Return whether the request is acknowledged. + */ + protected def doRequestTotalExecutors(requestedTotal: Int): Boolean = false + + /** + * Request that the cluster manager kill the specified executors. + * Return whether the kill request is acknowledged. 
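The delta-to-total conversion in requestExecutors above is what prevents double counting: the driver always sends the cluster manager a desired total rather than an increment, so repeating a request is idempotent. A worked example of the arithmetic, mirroring the fields used above:

    // Suppose 10 executors are registered, 3 earlier requests are still pending,
    // and 2 executors have been asked to die but have not exited yet.
    val numExistingExecutors = 10
    val numPendingExecutors = 3 + 5   // 5 more just requested via requestExecutors(5)
    val executorsPendingToRemove = 2
    val newTotal = numExistingExecutors + numPendingExecutors - executorsPendingToRemove
    // newTotal == 16; re-sending 16 is safe, whereas re-sending "+5" could be counted
    // twice if the first request had not been satisfied when the second arrived.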
+ */ + final def killExecutors(executorIds: Seq[String]): Boolean = { + logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}") + val filteredExecutorIds = new ArrayBuffer[String] + executorIds.foreach { id => + if (executorDataMap.contains(id)) { + filteredExecutorIds += id + } else { + logWarning(s"Executor to kill $id does not exist!") + } } + executorsPendingToRemove ++= filteredExecutorIds + doKillExecutors(filteredExecutorIds) } + + /** + * Kill the given list of executors through the cluster manager. + * Return whether the kill request is acknowledged. + */ + protected def doKillExecutors(executorIds: Seq[String]): Boolean = false + } private[spark] object CoarseGrainedSchedulerBackend { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index ed209d195ec9d..8c7de75600b5f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -51,7 +51,8 @@ private[spark] class SparkDeploySchedulerBackend( conf.get("spark.driver.host"), conf.get("spark.driver.port"), CoarseGrainedSchedulerBackend.ACTOR_NAME) - val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{WORKER_URL}}") + val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{APP_ID}}", + "{{WORKER_URL}}") val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions") .map(Utils.splitCommandString).getOrElse(Seq.empty) val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath").toSeq.flatMap { cp => diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala new file mode 100644 index 0000000000000..50721b9d6cd6c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import akka.actor.{Actor, ActorRef, Props} +import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} + +import org.apache.spark.SparkContext +import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.scheduler.TaskSchedulerImpl +import org.apache.spark.ui.JettyUtils +import org.apache.spark.util.AkkaUtils + +/** + * Abstract Yarn scheduler backend that contains common logic + * between the client and cluster Yarn scheduler backends. 
+ */ +private[spark] abstract class YarnSchedulerBackend( + scheduler: TaskSchedulerImpl, + sc: SparkContext) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) { + + if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { + minRegisteredRatio = 0.8 + } + + protected var totalExpectedExecutors = 0 + + private val yarnSchedulerActor: ActorRef = + actorSystem.actorOf( + Props(new YarnSchedulerActor), + name = YarnSchedulerBackend.ACTOR_NAME) + + private implicit val askTimeout = AkkaUtils.askTimeout(sc.conf) + + /** + * Request executors from the ApplicationMaster by specifying the total number desired. + * This includes executors already pending or running. + */ + override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { + AkkaUtils.askWithReply[Boolean]( + RequestExecutors(requestedTotal), yarnSchedulerActor, askTimeout) + } + + /** + * Request that the ApplicationMaster kill the specified executors. + */ + override def doKillExecutors(executorIds: Seq[String]): Boolean = { + AkkaUtils.askWithReply[Boolean]( + KillExecutors(executorIds), yarnSchedulerActor, askTimeout) + } + + override def sufficientResourcesRegistered(): Boolean = { + totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio + } + + /** + * Add filters to the SparkUI. + */ + private def addWebUIFilter( + filterName: String, + filterParams: Map[String, String], + proxyBase: String): Unit = { + if (proxyBase != null && proxyBase.nonEmpty) { + System.setProperty("spark.ui.proxyBase", proxyBase) + } + + val hasFilter = + filterName != null && filterName.nonEmpty && + filterParams != null && filterParams.nonEmpty + if (hasFilter) { + logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase") + conf.set("spark.ui.filters", filterName) + filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) } + scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) } + } + } + + /** + * An actor that communicates with the ApplicationMaster. + */ + private class YarnSchedulerActor extends Actor { + private var amActor: Option[ActorRef] = None + + override def preStart(): Unit = { + // Listen for disassociation events + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + } + + override def receive = { + case RegisterClusterManager => + logInfo(s"ApplicationMaster registered as $sender") + amActor = Some(sender) + + case r: RequestExecutors => + amActor match { + case Some(actor) => + sender ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout) + case None => + logWarning("Attempted to request executors before the AM has registered!") + sender ! false + } + + case k: KillExecutors => + amActor match { + case Some(actor) => + sender ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout) + case None => + logWarning("Attempted to kill executors before the AM has registered!") + sender ! false + } + + case AddWebUIFilter(filterName, filterParams, proxyBase) => + addWebUIFilter(filterName, filterParams, proxyBase) + sender ! 
true + + case d: DisassociatedEvent => + if (amActor.isDefined && sender == amActor.get) { + logWarning(s"ApplicationMaster has disassociated: $d") + } + } + } +} + +private[spark] object YarnSchedulerBackend { + val ACTOR_NAME = "YarnScheduler" +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 90828578cd88f..5289661eb896b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -31,6 +31,7 @@ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTas import org.apache.spark.{Logging, SparkContext, SparkEnv, SparkException} import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.util.Utils /** * A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds @@ -92,7 +93,7 @@ private[spark] class CoarseMesosSchedulerBackend( setDaemon(true) override def run() { val scheduler = CoarseMesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build() + val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() driver = new MesosSchedulerDriver(scheduler, fwInfo, master) try { { val ret = driver.run() @@ -120,16 +121,18 @@ private[spark] class CoarseMesosSchedulerBackend( environment.addVariables( Environment.Variable.newBuilder().setName("SPARK_CLASSPATH").setValue(cp).build()) } - val extraJavaOpts = conf.getOption("spark.executor.extraJavaOptions") + val extraJavaOpts = conf.get("spark.executor.extraJavaOptions", "") - val libraryPathOption = "spark.executor.extraLibraryPath" - val extraLibraryPath = conf.getOption(libraryPathOption).map(p => s"-Djava.library.path=$p") - val extraOpts = Seq(extraJavaOpts, extraLibraryPath).flatten.mkString(" ") + // Set the environment variable through a command prefix + // to append to the existing value of the variable + val prefixEnv = conf.getOption("spark.executor.extraLibraryPath").map { p => + Utils.libraryPathEnvPrefix(Seq(p)) + }.getOrElse("") environment.addVariables( Environment.Variable.newBuilder() .setName("SPARK_EXECUTOR_OPTS") - .setValue(extraOpts) + .setValue(extraJavaOpts) .build()) sc.executorEnvs.foreach { case (key, value) => @@ -150,17 +153,18 @@ private[spark] class CoarseMesosSchedulerBackend( if (uri == null) { val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath command.setValue( - "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d".format( - runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores)) + "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s".format( + prefixEnv, runScript, driverUrl, offer.getSlaveId.getValue, + offer.getHostname, numCores, appId)) } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". 
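The "grab everything to the first '.'" trick in the comment above (repeated for the fine-grained backend below) looks lossy at first glance, so a concrete example helps. With a hypothetical spark.executor.uri of http://host/dist/spark-1.3.0-bin-hadoop2.4.tgz:

    val uri = "http://host/dist/spark-1.3.0-bin-hadoop2.4.tgz"  // illustrative URI
    val basename = uri.split('/').last.split('.').head
    // basename == "spark-1": the filename up to the first '.', so the launch command
    // "cd spark-1*; ..." globs into whatever directory the archive unpacked to
    // (e.g. spark-1.3.0-bin-hadoop2.4/) without needing to know the exact suffix.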
val basename = uri.split('/').last.split('.').head command.setValue( - ("cd %s*; " + - "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d") - .format(basename, driverUrl, offer.getSlaveId.getValue, - offer.getHostname, numCores)) + ("cd %s*; %s " + + "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s") + .format(basename, prefixEnv, driverUrl, offer.getSlaveId.getValue, + offer.getHostname, numCores, appId)) command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) } command.build() @@ -238,8 +242,7 @@ private[spark] class CoarseMesosSchedulerBackend( for (r <- res if r.getName == name) { return r.getScalar.getValue } - // If we reached here, no resource with the required name was present - throw new IllegalArgumentException("No resource called " + name + " in " + res) + 0 } /** Build a Mesos resource protobuf object */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index b11786368e661..10e6886c16a4f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -72,7 +72,7 @@ private[spark] class MesosSchedulerBackend( setDaemon(true) override def run() { val scheduler = MesosSchedulerBackend.this - val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build() + val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build() driver = new MesosSchedulerDriver(scheduler, fwInfo, master) try { val ret = driver.run() @@ -98,15 +98,16 @@ private[spark] class MesosSchedulerBackend( environment.addVariables( Environment.Variable.newBuilder().setName("SPARK_CLASSPATH").setValue(cp).build()) } - val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions") - val extraLibraryPath = sc.conf.getOption("spark.executor.extraLibraryPath").map { lp => - s"-Djava.library.path=$lp" - } - val extraOpts = Seq(extraJavaOpts, extraLibraryPath).flatten.mkString(" ") + val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").getOrElse("") + + val prefixEnv = sc.conf.getOption("spark.executor.extraLibraryPath").map { p => + Utils.libraryPathEnvPrefix(Seq(p)) + }.getOrElse("") + environment.addVariables( Environment.Variable.newBuilder() .setName("SPARK_EXECUTOR_OPTS") - .setValue(extraOpts) + .setValue(extraJavaOpts) .build()) sc.executorEnvs.foreach { case (key, value) => environment.addVariables(Environment.Variable.newBuilder() @@ -118,12 +119,13 @@ private[spark] class MesosSchedulerBackend( .setEnvironment(environment) val uri = sc.conf.get("spark.executor.uri", null) if (uri == null) { - command.setValue(new File(executorSparkHome, "/sbin/spark-executor").getCanonicalPath) + val executorPath = new File(executorSparkHome, "/sbin/spark-executor").getCanonicalPath + command.setValue("%s %s".format(prefixEnv, executorPath)) } else { // Grab everything to the first '.'. We'll use that and '*' to // glob the directory "correctly". 
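Stepping back from the URI handling for a moment: both Mesos backends above now build prefixEnv from spark.executor.extraLibraryPath via Utils.libraryPathEnvPrefix instead of passing -Djava.library.path, so the configured directories are prepended to the OS library-path environment variable as part of the launched command. The helper is not shown in this diff; a hedged sketch of the kind of prefix it plausibly produces (the exact quoting and the per-OS variable name in Spark's Utils may differ):

    object LibraryPathSketch {
      // e.g. prefix(Seq("/opt/native")) == LD_LIBRARY_PATH="/opt/native:$LD_LIBRARY_PATH"
      def prefix(paths: Seq[String], envVar: String = "LD_LIBRARY_PATH"): String =
        envVar + "=\"" + paths.mkString(":") + ":$" + envVar + "\""
    }

Prepended to the executor command line this yields, for example, LD_LIBRARY_PATH="/opt/native:$LD_LIBRARY_PATH" ./bin/spark-class ..., which extends rather than replaces any value already set on the worker.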
val basename = uri.split('/').last.split('.').head - command.setValue("cd %s*; ./sbin/spark-executor".format(basename)) + command.setValue("cd %s*; %s ./sbin/spark-executor".format(basename, prefixEnv)) command.addUris(CommandInfo.URI.newBuilder().setValue(uri)) } val cpus = Resource.newBuilder() @@ -164,29 +166,16 @@ private[spark] class MesosSchedulerBackend( execArgs } - private def setClassLoader(): ClassLoader = { - val oldClassLoader = Thread.currentThread.getContextClassLoader - Thread.currentThread.setContextClassLoader(classLoader) - oldClassLoader - } - - private def restoreClassLoader(oldClassLoader: ClassLoader) { - Thread.currentThread.setContextClassLoader(oldClassLoader) - } - override def offerRescinded(d: SchedulerDriver, o: OfferID) {} override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { - val oldClassLoader = setClassLoader() - try { + inClassLoader() { appId = frameworkId.getValue logInfo("Registered as framework ID " + appId) registeredLock.synchronized { isRegistered = true registeredLock.notifyAll() } - } finally { - restoreClassLoader(oldClassLoader) } } @@ -198,6 +187,16 @@ private[spark] class MesosSchedulerBackend( } } + private def inClassLoader()(fun: => Unit) = { + val oldClassLoader = Thread.currentThread.getContextClassLoader + Thread.currentThread.setContextClassLoader(classLoader) + try { + fun + } finally { + Thread.currentThread.setContextClassLoader(oldClassLoader) + } + } + override def disconnected(d: SchedulerDriver) {} override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {} @@ -208,66 +207,70 @@ private[spark] class MesosSchedulerBackend( * tasks are balanced across the cluster. */ override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) { - val oldClassLoader = setClassLoader() - try { - synchronized { - // Build a big list of the offerable workers, and remember their indices so that we can - // figure out which Offer to reply to for each worker - val offerableWorkers = new ArrayBuffer[WorkerOffer] - val offerableIndices = new HashMap[String, Int] - - def sufficientOffer(o: Offer) = { - val mem = getResource(o.getResourcesList, "mem") - val cpus = getResource(o.getResourcesList, "cpus") - val slaveId = o.getSlaveId.getValue - (mem >= MemoryUtils.calculateTotalMemory(sc) && - // need at least 1 for executor, 1 for task - cpus >= 2 * scheduler.CPUS_PER_TASK) || - (slaveIdsWithExecutors.contains(slaveId) && - cpus >= scheduler.CPUS_PER_TASK) - } + inClassLoader() { + // Fail-fast on offers we know will be rejected + val (usableOffers, unUsableOffers) = offers.partition { o => + val mem = getResource(o.getResourcesList, "mem") + val cpus = getResource(o.getResourcesList, "cpus") + val slaveId = o.getSlaveId.getValue + // TODO(pwendell): Should below be 1 + scheduler.CPUS_PER_TASK? 
+ (mem >= MemoryUtils.calculateTotalMemory(sc) && + // need at least 1 for executor, 1 for task + cpus >= 2 * scheduler.CPUS_PER_TASK) || + (slaveIdsWithExecutors.contains(slaveId) && + cpus >= scheduler.CPUS_PER_TASK) + } - for ((offer, index) <- offers.zipWithIndex if sufficientOffer(offer)) { - val slaveId = offer.getSlaveId.getValue - offerableIndices.put(slaveId, index) - val cpus = if (slaveIdsWithExecutors.contains(slaveId)) { - getResource(offer.getResourcesList, "cpus").toInt - } else { - // If the executor doesn't exist yet, subtract CPU for executor - getResource(offer.getResourcesList, "cpus").toInt - - scheduler.CPUS_PER_TASK - } - offerableWorkers += new WorkerOffer( - offer.getSlaveId.getValue, - offer.getHostname, - cpus) + val workerOffers = usableOffers.map { o => + val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) { + getResource(o.getResourcesList, "cpus").toInt + } else { + // If the executor doesn't exist yet, subtract CPU for executor + // TODO(pwendell): Should below just subtract "1"? + getResource(o.getResourcesList, "cpus").toInt - + scheduler.CPUS_PER_TASK } + new WorkerOffer( + o.getSlaveId.getValue, + o.getHostname, + cpus) + } + + val slaveIdToOffer = usableOffers.map(o => o.getSlaveId.getValue -> o).toMap + + val mesosTasks = new HashMap[String, JArrayList[MesosTaskInfo]] - // Call into the TaskSchedulerImpl - val taskLists = scheduler.resourceOffers(offerableWorkers) - - // Build a list of Mesos tasks for each slave - val mesosTasks = offers.map(o => new JArrayList[MesosTaskInfo]()) - for ((taskList, index) <- taskLists.zipWithIndex) { - if (!taskList.isEmpty) { - for (taskDesc <- taskList) { - val slaveId = taskDesc.executorId - val offerNum = offerableIndices(slaveId) - slaveIdsWithExecutors += slaveId - taskIdToSlaveId(taskDesc.taskId) = slaveId - mesosTasks(offerNum).add(createMesosTask(taskDesc, slaveId)) - } + val slavesIdsOfAcceptedOffers = HashSet[String]() + + // Call into the TaskSchedulerImpl + val acceptedOffers = scheduler.resourceOffers(workerOffers).filter(!_.isEmpty) + acceptedOffers + .foreach { offer => + offer.foreach { taskDesc => + val slaveId = taskDesc.executorId + slaveIdsWithExecutors += slaveId + slavesIdsOfAcceptedOffers += slaveId + taskIdToSlaveId(taskDesc.taskId) = slaveId + mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo]) + .add(createMesosTask(taskDesc, slaveId)) } } - // Reply to the offers - val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout? - for (i <- 0 until offers.size) { - d.launchTasks(Collections.singleton(offers(i).getId), mesosTasks(i), filters) - } + // Reply to the offers + val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout? 
+ + mesosTasks.foreach { case (slaveId, tasks) => + d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters) } - } finally { - restoreClassLoader(oldClassLoader) + + // Decline offers that weren't used + // NOTE: This logic assumes that we only get a single offer for each host in a given batch + for (o <- usableOffers if !slavesIdsOfAcceptedOffers.contains(o.getSlaveId.getValue)) { + d.declineOffer(o.getId) + } + + // Decline offers we ruled out immediately + unUsableOffers.foreach(o => d.declineOffer(o.getId)) } } @@ -276,8 +279,7 @@ private[spark] class MesosSchedulerBackend( for (r <- res if r.getName == name) { return r.getScalar.getValue } - // If we reached here, no resource with the required name was present - throw new IllegalArgumentException("No resource called " + name + " in " + res) + 0 } /** Turn a Spark TaskDescription into a Mesos task */ @@ -307,8 +309,7 @@ private[spark] class MesosSchedulerBackend( } override def statusUpdate(d: SchedulerDriver, status: TaskStatus) { - val oldClassLoader = setClassLoader() - try { + inClassLoader() { val tid = status.getTaskId.getValue.toLong val state = TaskState.fromMesos(status.getState) synchronized { @@ -321,18 +322,13 @@ private[spark] class MesosSchedulerBackend( } } scheduler.statusUpdate(tid, state, status.getData.asReadOnlyByteBuffer) - } finally { - restoreClassLoader(oldClassLoader) } } override def error(d: SchedulerDriver, message: String) { - val oldClassLoader = setClassLoader() - try { + inClassLoader() { logError("Mesos error: " + message) scheduler.error(message) - } finally { - restoreClassLoader(oldClassLoader) } } @@ -349,15 +345,12 @@ private[spark] class MesosSchedulerBackend( override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {} private def recordSlaveLost(d: SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) { - val oldClassLoader = setClassLoader() - try { + inClassLoader() { logInfo("Mesos slave lost: " + slaveId.getValue) synchronized { slaveIdsWithExecutors -= slaveId.getValue } scheduler.executorLost(slaveId.getValue, reason) - } finally { - restoreClassLoader(oldClassLoader) } } @@ -372,6 +365,13 @@ private[spark] class MesosSchedulerBackend( recordSlaveLost(d, slaveId, ExecutorExited(status)) } + override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { + driver.killTask( + TaskID.newBuilder() + .setValue(taskId.toString).build() + ) + } + // TODO: query Mesos for number of cores override def defaultParallelism() = sc.conf.getInt("spark.default.parallelism", 8) diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 58b78f041cd85..a2f1f14264a99 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import akka.actor.{Actor, ActorRef, Props} -import org.apache.spark.{Logging, SparkEnv, TaskState} +import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} @@ -47,11 +47,11 @@ private[spark] class LocalActor( private var freeCores = totalCores - private val localExecutorId = "localhost" + private val localExecutorId = 
SparkContext.DRIVER_IDENTIFIER private val localExecutorHostname = "localhost" val executor = new Executor( - localExecutorId, localExecutorHostname, scheduler.conf.getAll, isLocal = true) + localExecutorId, localExecutorHostname, scheduler.conf.getAll, totalCores, isLocal = true) override def receiveWithLogging = { case ReviveOffers => diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 554a33ce7f1a6..662a7b91248aa 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -117,11 +117,11 @@ class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable { new JavaSerializerInstance(counterReset, classLoader) } - override def writeExternal(out: ObjectOutput) { + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { out.writeInt(counterReset) } - override def readExternal(in: ObjectInput) { + override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { counterReset = in.readInt() } } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index d6386f8c06fff..621a951c27d07 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -53,7 +53,18 @@ class KryoSerializer(conf: SparkConf) private val maxBufferSize = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) * 1024 * 1024 private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true) private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false) - private val registrator = conf.getOption("spark.kryo.registrator") + private val userRegistrator = conf.getOption("spark.kryo.registrator") + private val classesToRegister = conf.get("spark.kryo.classesToRegister", "") + .split(',') + .filter(!_.isEmpty) + .map { className => + try { + Class.forName(className) + } catch { + case e: Exception => + throw new SparkException("Failed to load class to register with Kryo", e) + } + } def newKryoOutput() = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize)) @@ -80,22 +91,20 @@ class KryoSerializer(conf: SparkConf) kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) - // Allow the user to register their own classes by setting spark.kryo.registrator - for (regCls <- registrator) { - logDebug("Running user registrator: " + regCls) - try { - val reg = Class.forName(regCls, true, classLoader).newInstance() - .asInstanceOf[KryoRegistrator] - - // Use the default classloader when calling the user registrator. - Thread.currentThread.setContextClassLoader(classLoader) - reg.registerClasses(kryo) - } catch { - case e: Exception => - throw new SparkException(s"Failed to invoke $regCls", e) - } finally { - Thread.currentThread.setContextClassLoader(oldClassLoader) - } + try { + // Use the default classloader when calling the user registrator. + Thread.currentThread.setContextClassLoader(classLoader) + // Register classes given through spark.kryo.classesToRegister. + classesToRegister.foreach { clazz => kryo.register(clazz) } + // Allow the user to register their own classes by setting spark.kryo.registrator. 
+ userRegistrator + .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]) + .foreach { reg => reg.registerClasses(kryo) } + } catch { + case e: Exception => + throw new SparkException(s"Failed to register classes with Kryo", e) + } finally { + Thread.currentThread.setContextClassLoader(oldClassLoader) } // Register Chill's classes; we do this after our ranges and the user's own classes to let diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index a9144cdd97b8c..ca6e971d227fb 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -17,14 +17,14 @@ package org.apache.spark.serializer -import java.io.{ByteArrayOutputStream, EOFException, InputStream, OutputStream} +import java.io._ import java.nio.ByteBuffer import scala.reflect.ClassTag -import org.apache.spark.SparkEnv +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.util.{ByteBufferInputStream, NextIterator} +import org.apache.spark.util.{Utils, ByteBufferInputStream, NextIterator} /** * :: DeveloperApi :: diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index 71c08e9d5a8c3..be184464e0ae9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -19,6 +19,7 @@ package org.apache.spark.shuffle import org.apache.spark.storage.BlockManagerId import org.apache.spark.{FetchFailed, TaskEndReason} +import org.apache.spark.util.Utils /** * Failed to fetch a shuffle block. 
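The KryoSerializer change above adds a declarative alternative to writing a KryoRegistrator: a comma-separated class list in spark.kryo.classesToRegister, registered before any user registrator runs. A usage sketch (the application classes are illustrative):

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      // Register simple classes by name; spark.kryo.registrator can still be set for
      // anything that needs custom serializers, and runs after these registrations.
      .set("spark.kryo.classesToRegister", "com.example.MyKey,com.example.MyRecord")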
The executor catches this exception and propagates it @@ -30,13 +31,22 @@ private[spark] class FetchFailedException( bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, - reduceId: Int) - extends Exception { + reduceId: Int, + message: String, + cause: Throwable = null) + extends Exception(message, cause) { - override def getMessage: String = - "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId) + def this( + bmAddress: BlockManagerId, + shuffleId: Int, + mapId: Int, + reduceId: Int, + cause: Throwable) { + this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause) + } - def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId) + def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, + Utils.exceptionString(this)) } /** @@ -46,7 +56,4 @@ private[spark] class MetadataFetchFailedException( shuffleId: Int, reduceId: Int, message: String) - extends FetchFailedException(null, shuffleId, -1, reduceId) { - - override def getMessage: String = message -} + extends FetchFailedException(null, shuffleId, -1, reduceId, message) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 439981d232349..7de2f9cbb2866 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -24,9 +24,10 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConversions._ -import org.apache.spark.{SparkEnv, SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf, SparkEnv} import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup import org.apache.spark.storage._ @@ -62,11 +63,14 @@ private[spark] trait ShuffleWriterGroup { * each block stored in each file. In order to find the location of a shuffle block, we search the * files within a ShuffleFileGroups associated with the block's reducer. */ - +// Note: Changes to the format in this file should be kept in sync with +// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getHashBasedShuffleBlockData(). private[spark] class FileShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager with Logging { + private val transportConf = SparkTransportConf.fromSparkConf(conf) + private lazy val blockManager = SparkEnv.get.blockManager // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. 
@@ -181,13 +185,14 @@ class FileShuffleBlockManager(conf: SparkConf) val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId) if (segmentOpt.isDefined) { val segment = segmentOpt.get - return new FileSegmentManagedBuffer(segment.file, segment.offset, segment.length) + return new FileSegmentManagedBuffer( + transportConf, segment.file, segment.offset, segment.length) } } throw new IllegalStateException("Failed to find shuffle block: " + blockId) } else { val file = blockManager.diskBlockManager.getFile(blockId) - new FileSegmentManagedBuffer(file, 0, file.length) + new FileSegmentManagedBuffer(transportConf, file, 0, file.length) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala index 4ab34336d3f01..b292587d37028 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala @@ -20,8 +20,11 @@ package org.apache.spark.shuffle import java.io._ import java.nio.ByteBuffer -import org.apache.spark.SparkEnv -import org.apache.spark.network.{ManagedBuffer, FileSegmentManagedBuffer} +import com.google.common.io.ByteStreams + +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.storage._ /** @@ -33,11 +36,15 @@ import org.apache.spark.storage._ * as the filename postfix for data file, and ".index" as the filename postfix for index file. * */ +// Note: Changes to the format in this file should be kept in sync with +// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getSortBasedShuffleBlockData(). private[spark] -class IndexShuffleBlockManager extends ShuffleBlockManager { +class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager { private lazy val blockManager = SparkEnv.get.blockManager + private val transportConf = SparkTransportConf.fromSparkConf(conf) + /** * Mapping to a single shuffleBlockId with reduce ID 0. 
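The IndexShuffleBlockManager hunk that follows reads the sort-based shuffle's index file, which appears to be a flat sequence of 8-byte offsets into the data file (one per reduce partition, followed by a final end offset; the writer is not shown in this diff), so the block for reduce r spans [offset(r), offset(r+1)). That is why the reader skips r * 8 bytes and then reads two longs. A standalone sketch of the same arithmetic over an in-memory offset table (hypothetical helper, not Spark's reader):

    object IndexLookupSketch {
      /** offsets(i) is where reduce block i starts; offsets.last is the end of the data file. */
      def segmentFor(offsets: Array[Long], reduceId: Int): (Long, Long) = {
        val offset = offsets(reduceId)          // what readLong() returns after skipping reduceId * 8
        val nextOffset = offsets(reduceId + 1)  // the following long
        (offset, nextOffset - offset)           // (start in data file, segment length)
      }
    }

For example, offsets of (0, 10, 10, 25) mean reduce block 1 is empty (length 0) while block 2 starts at byte 10 and is 15 bytes long.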
* */ @@ -101,10 +108,11 @@ class IndexShuffleBlockManager extends ShuffleBlockManager { val in = new DataInputStream(new FileInputStream(indexFile)) try { - in.skip(blockId.reduceId * 8) + ByteStreams.skipFully(in, blockId.reduceId * 8) val offset = in.readLong() val nextOffset = in.readLong() new FileSegmentManagedBuffer( + transportConf, getDataFile(blockId.shuffleId, blockId.mapId), offset, nextOffset - offset) diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala index 63863cc0250a3..b521f0c7fc77e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala @@ -18,8 +18,7 @@ package org.apache.spark.shuffle import java.nio.ByteBuffer - -import org.apache.spark.network.ManagedBuffer +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.storage.ShuffleBlockId private[spark] diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala index b30e366d06006..292e48314ee10 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala @@ -24,6 +24,10 @@ private[spark] trait ShuffleReader[K, C] { /** Read the combined key-values for this reduce task */ def read(): Iterator[Product2[K, C]] - /** Close this reader */ - def stop(): Unit + /** + * Close this reader. + * TODO: Add this back when we make the ShuffleReader a developer API that others can implement + * (at which point this will likely be necessary). + */ + // def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 6cf9305977a3c..e3e7434df45b0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -19,6 +19,7 @@ package org.apache.spark.shuffle.hash import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap +import scala.util.{Failure, Success, Try} import org.apache.spark._ import org.apache.spark.serializer.Serializer @@ -52,21 +53,21 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) } - def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = { + def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = { val blockId = blockPair._1 val blockOption = blockPair._2 blockOption match { - case Some(block) => { + case Success(block) => { block.asInstanceOf[Iterator[T]] } - case None => { + case Failure(e) => { blockId match { case ShuffleBlockId(shufId, mapId, _) => val address = statuses(mapId.toInt)._1 - throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId) + throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) case _ => throw new SparkException( - "Failed to get block " + blockId + ", which is not a shuffle block") + "Failed to get block " + blockId + ", which is not a shuffle block", e) } } } @@ -74,7 +75,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { val blockFetcherItr = new ShuffleBlockFetcherIterator( context, - SparkEnv.get.blockTransferService, + 
SparkEnv.get.blockManager.shuffleClient, blockManager, blocksByAddress, serializer, diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 88a5f1e5ddf58..5baf45db45c17 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -66,7 +66,4 @@ private[spark] class HashShuffleReader[K, C]( aggregatedIter } } - - /** Close this reader */ - override def stop(): Unit = ??? } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 746ed33b54c00..183a30373b28c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -107,7 +107,7 @@ private[spark] class HashShuffleWriter[K, V]( writer.commitAndClose() writer.fileSegment().length } - MapStatus(blockManager.blockManagerId, sizes) + MapStatus(blockManager.shuffleServerId, sizes) } private def revertWrites(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index b727438ae7e47..bda30a56d808e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -25,7 +25,7 @@ import org.apache.spark.shuffle.hash.HashShuffleReader private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager { - private val indexShuffleBlockManager = new IndexShuffleBlockManager() + private val indexShuffleBlockManager = new IndexShuffleBlockManager(conf) private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]() /** diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 927481b72cf4f..d75f9d7311fad 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -70,7 +70,7 @@ private[spark] class SortShuffleWriter[K, V, C]( val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile) shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths) - mapStatus = MapStatus(blockManager.blockManagerId, partitionLengths) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) } /** Close this writer, passing along whether the map completed */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala b/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala deleted file mode 100644 index 5b6d086630834..0000000000000 --- a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.storage - -import java.nio.ByteBuffer - - -/** - * An interface for providing data for blocks. - * - * getBlockData returns either a FileSegment (for zero-copy send), or a ByteBuffer. - * - * Aside from unit tests, [[BlockManager]] is the main class that implements this. - */ -private[spark] trait BlockDataProvider { - def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] -} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index a83a3f468ae5f..1f012941c85ab 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -53,6 +53,8 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId { def name = "rdd_" + rddId + "_" + splitIndex } +// Format of the shuffle block ids (including data and index) should be kept in sync with +// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getBlockData(). @DeveloperApi case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId @@ -83,9 +85,14 @@ case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId { def name = "input-" + streamId + "-" + uniqueId } -/** Id associated with temporary data managed as blocks. Not serializable. */ -private[spark] case class TempBlockId(id: UUID) extends BlockId { - def name = "temp_" + id +/** Id associated with temporary local data managed as blocks. Not serializable. */ +private[spark] case class TempLocalBlockId(id: UUID) extends BlockId { + def name = "temp_local_" + id +} + +/** Id associated with temporary shuffle data managed as blocks. Not serializable. 
*/ +private[spark] case class TempShuffleBlockId(id: UUID) extends BlockId { + def name = "temp_shuffle_" + id } // Intended only for testing purposes diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 3f5d06e1aeee7..308c59eda594d 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -17,14 +17,12 @@ package org.apache.spark.storage -import java.io.{File, InputStream, OutputStream, BufferedOutputStream, ByteArrayOutputStream} +import java.io.{BufferedOutputStream, ByteArrayOutputStream, File, InputStream, OutputStream} import java.nio.{ByteBuffer, MappedByteBuffer} -import scala.concurrent.ExecutionContext.Implicits.global - -import scala.collection.mutable import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.concurrent.{Await, Future} +import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.util.Random @@ -35,11 +33,16 @@ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.netty.{SparkTransportConf, NettyBlockTransferService} +import org.apache.spark.network.shuffle.ExternalShuffleClient +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo +import org.apache.spark.network.util.{ConfigProvider, TransportConf} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager +import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.util._ - private[spark] sealed trait BlockValues private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues @@ -54,6 +57,12 @@ private[spark] class BlockResult( inputMetrics.bytesRead = bytes } +/** + * Manager running on every node (driver and executors) which provides interfaces for putting and + * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap). + * + * Note that #initialize() must be called before the BlockManager is usable. + */ private[spark] class BlockManager( executorId: String, actorSystem: ActorSystem, @@ -63,11 +72,11 @@ private[spark] class BlockManager( val conf: SparkConf, mapOutputTracker: MapOutputTracker, shuffleManager: ShuffleManager, - blockTransferService: BlockTransferService) + blockTransferService: BlockTransferService, + securityManager: SecurityManager, + numUsableCores: Int) extends BlockDataManager with Logging { - blockTransferService.init(this) - val diskBlockManager = new DiskBlockManager(this, conf) private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] @@ -87,8 +96,37 @@ private[spark] class BlockManager( new TachyonStore(this, tachyonBlockManager) } - val blockManagerId = BlockManagerId( - executorId, blockTransferService.hostName, blockTransferService.port) + private[spark] + val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) + + // Port used by the external shuffle service. In Yarn mode, this may be already be + // set through the Hadoop configuration as the server is launched in the Yarn NM. 
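The new BlockManager fields above are driven by two configuration keys. A minimal sketch, assuming only spark-core on the classpath, of how those keys resolve to the defaults used in this patch (in YARN mode the port may instead come from the Hadoop configuration, per the comment above):

```scala
import org.apache.spark.SparkConf

// Illustrative only: the keys and defaults read by the external shuffle service code path.
object ShuffleServiceConfDemo extends App {
  val conf = new SparkConf(false)
    .set("spark.shuffle.service.enabled", "true")

  val enabled = conf.getBoolean("spark.shuffle.service.enabled", false) // default: false
  val port = conf.get("spark.shuffle.service.port", "7337").toInt       // default: 7337
  println(s"external shuffle service enabled=$enabled, port=$port")
}
```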
+ private val externalShuffleServicePort = + Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt + + // Check that we're not using external shuffle service with consolidated shuffle files. + if (externalShuffleServiceEnabled + && conf.getBoolean("spark.shuffle.consolidateFiles", false) + && shuffleManager.isInstanceOf[HashShuffleManager]) { + throw new UnsupportedOperationException("Cannot use external shuffle service with consolidated" + + " shuffle files in hash-based shuffle. Please disable spark.shuffle.consolidateFiles or " + + " switch to sort-based shuffle.") + } + + var blockManagerId: BlockManagerId = _ + + // Address of the server that serves this executor's shuffle files. This is either an external + // service, or just our own Executor's BlockManager. + private[spark] var shuffleServerId: BlockManagerId = _ + + // Client to read other executors' shuffle files. This is either an external service, or just the + // standard BlockTranserService to directly connect to other Executors. + private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { + val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores) + new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled()) + } else { + blockTransferService + } // Whether to compress broadcast variables that are stored private val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true) @@ -118,8 +156,6 @@ private[spark] class BlockManager( private val peerFetchLock = new Object private var lastPeerFetchTime = 0L - initialize() - /* The compression codec to use. Note that the "lazy" val is necessary because we want to delay * the initialization of the compression codec until it is first used. The reason is that a Spark * program could be using a user-defined codec in a third party jar, which is loaded in @@ -138,17 +174,66 @@ private[spark] class BlockManager( conf: SparkConf, mapOutputTracker: MapOutputTracker, shuffleManager: ShuffleManager, - blockTransferService: BlockTransferService) = { + blockTransferService: BlockTransferService, + securityManager: SecurityManager, + numUsableCores: Int) = { this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), - conf, mapOutputTracker, shuffleManager, blockTransferService) + conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) } /** - * Initialize the BlockManager. Register to the BlockManagerMaster, and start the - * BlockManagerWorker actor. + * Initializes the BlockManager with the given appId. This is not performed in the constructor as + * the appId may not be known at BlockManager instantiation time (in particular for the driver, + * where it is only learned after registration with the TaskScheduler). + * + * This method initializes the BlockTransferService and ShuffleClient, registers with the + * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle + * service if configured. 
*/ - private def initialize(): Unit = { + def initialize(appId: String): Unit = { + blockTransferService.init(this) + shuffleClient.init(appId) + + blockManagerId = BlockManagerId( + executorId, blockTransferService.hostName, blockTransferService.port) + + shuffleServerId = if (externalShuffleServiceEnabled) { + BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) + } else { + blockManagerId + } + master.registerBlockManager(blockManagerId, maxMemory, slaveActor) + + // Register Executors' configuration with the local shuffle service, if one should exist. + if (externalShuffleServiceEnabled && !blockManagerId.isDriver) { + registerWithExternalShuffleServer() + } + } + + private def registerWithExternalShuffleServer() { + logInfo("Registering executor with local external shuffle service.") + val shuffleConfig = new ExecutorShuffleInfo( + diskBlockManager.localDirs.map(_.toString), + diskBlockManager.subDirsPerLocalDir, + shuffleManager.getClass.getName) + + val MAX_ATTEMPTS = 3 + val SLEEP_TIME_SECS = 5 + + for (i <- 1 to MAX_ATTEMPTS) { + try { + // Synchronous and will throw an exception if we cannot connect. + shuffleClient.asInstanceOf[ExternalShuffleClient].registerWithShuffleServer( + shuffleServerId.host, shuffleServerId.port, shuffleServerId.executorId, shuffleConfig) + return + } catch { + case e: Exception if i < MAX_ATTEMPTS => + logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}}" + + s" more times after waiting $SLEEP_TIME_SECS seconds...", e) + Thread.sleep(SLEEP_TIME_SECS * 1000) + } + } } /** @@ -212,21 +297,20 @@ private[spark] class BlockManager( } /** - * Interface to get local block data. - * - * @return Some(buffer) if the block exists locally, and None if it doesn't. + * Interface to get local block data. Throws an exception if the block cannot be found or + * cannot be read successfully. */ - override def getBlockData(blockId: String): Option[ManagedBuffer] = { - val bid = BlockId(blockId) - if (bid.isShuffle) { - Some(shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId])) + override def getBlockData(blockId: BlockId): ManagedBuffer = { + if (blockId.isShuffle) { + shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) } else { - val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] + val blockBytesOpt = doGetLocal(blockId, asBlockResult = false) + .asInstanceOf[Option[ByteBuffer]] if (blockBytesOpt.isDefined) { val buffer = blockBytesOpt.get - Some(new NioByteBufferManagedBuffer(buffer)) + new NioManagedBuffer(buffer) } else { - None + throw new BlockNotFoundException(blockId.toString) } } } @@ -234,8 +318,8 @@ private[spark] class BlockManager( /** * Put the block locally, using the given storage level. */ - override def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit = { - putBytes(BlockId(blockId), data.nioByteBuffer(), level) + override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit = { + putBytes(blockId, data.nioByteBuffer(), level) } /** @@ -340,17 +424,6 @@ private[spark] class BlockManager( locations } - /** - * A short-circuited method to get blocks directly from disk. This is used for getting - * shuffle blocks. It is safe to do so without a lock on block info since disk store - * never deletes (recent) items. 
- */ - def getLocalShuffleFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = { - val buf = shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) - val is = wrapForCompression(blockId, buf.inputStream()) - Some(serializer.newInstance().deserializeStream(is).asIterator) - } - /** * Get block from local block manager. */ @@ -520,7 +593,7 @@ private[spark] class BlockManager( for (loc <- locations) { logDebug(s"Getting remote block $blockId from $loc") val data = blockTransferService.fetchBlockSync( - loc.host, loc.port, blockId.toString).nioByteBuffer() + loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer() if (data != null) { if (asBlockResult) { @@ -869,9 +942,9 @@ private[spark] class BlockManager( data.rewind() logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer") blockTransferService.uploadBlockSync( - peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel) - logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %f ms" - .format((System.currentTimeMillis - onePeerStartTime))) + peer.host, peer.port, peer.executorId, blockId, new NioManagedBuffer(data), tLevel) + logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %s ms" + .format(System.currentTimeMillis - onePeerStartTime)) peersReplicatedTo += peer peersForReplication -= peer replicationFailed = false @@ -1071,7 +1144,8 @@ private[spark] class BlockManager( case _: ShuffleBlockId => compressShuffle case _: BroadcastBlockId => compressBroadcast case _: RDDBlockId => compressRdds - case _: TempBlockId => compressShuffleSpill + case _: TempLocalBlockId => compressShuffleSpill + case _: TempShuffleBlockId => compressShuffle case _ => false } } @@ -1125,7 +1199,11 @@ private[spark] class BlockManager( } def stop(): Unit = { - blockTransferService.stop() + blockTransferService.close() + if (shuffleClient ne blockTransferService) { + // Closing should be idempotent, but maybe not for the NioBlockTransferService. 
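The registration with the external shuffle server added to BlockManager above uses a small bounded-retry loop (3 attempts, 5 seconds apart). A generic sketch of that pattern, with hypothetical names:

```scala
// Retry an operation a fixed number of times, sleeping between attempts; the last
// failure propagates to the caller, mirroring the registration loop above.
def retry[T](attempts: Int, sleepMillis: Long)(op: => T): T =
  try op catch {
    case e: Exception if attempts > 1 =>
      Thread.sleep(sleepMillis)
      retry(attempts - 1, sleepMillis)(op)
  }

// e.g. retry(attempts = 3, sleepMillis = 5000) { registerWithShuffleServer() } // hypothetical call
```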
+ shuffleClient.close() + } diskBlockManager.stop() actorSystem.stop(slaveActor) blockInfo.clear() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index 142285094342c..b177a59c721df 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -20,6 +20,7 @@ package org.apache.spark.storage import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.util.concurrent.ConcurrentHashMap +import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.Utils @@ -59,15 +60,15 @@ class BlockManagerId private ( def port: Int = port_ - def isDriver: Boolean = (executorId == "") + def isDriver: Boolean = { executorId == SparkContext.DRIVER_IDENTIFIER } - override def writeExternal(out: ObjectOutput) { + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { out.writeUTF(executorId_) out.writeUTF(host_) out.writeInt(port_) } - override def readExternal(in: ObjectInput) { + override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { executorId_ = in.readUTF() host_ = in.readUTF() port_ = in.readInt() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index d08e1419e3e41..b63c7f191155c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -88,6 +88,10 @@ class BlockManagerMaster( askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId)) } + def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { + askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId)) + } + /** * Remove a block from the slaves that have it. This can only be used to remove * blocks that the driver knows about. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 6a06257ed0c08..685b2e11440fb 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -86,6 +86,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case GetPeers(blockManagerId) => sender ! getPeers(blockManagerId) + case GetActorSystemHostPortForExecutor(executorId) => + sender ! getActorSystemHostPortForExecutor(executorId) + case GetMemoryStatus => sender ! memoryStatus @@ -203,6 +206,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus } } listenerBus.post(SparkListenerBlockManagerRemoved(System.currentTimeMillis(), blockManagerId)) + logInfo(s"Removing block manager $blockManagerId") } private def expireDeadHosts() { @@ -327,20 +331,20 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus val time = System.currentTimeMillis() if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { - case Some(manager) => - // A block manager of the same executor already exists. - // This should never happen. Let's just quit. 
- logError("Got two different block manager registrations on " + id.executorId) - System.exit(1) + case Some(oldId) => + // A block manager of the same executor already exists, so remove it (assumed dead) + logError("Got two different block manager registrations on same executor - " + + s" will replace old one $oldId with new one $id") + removeExecutor(id.executorId) case None => - blockManagerIdByExecutor(id.executorId) = id } - - logInfo("Registering block manager %s with %s RAM".format( - id.hostPort, Utils.bytesToString(maxMemSize))) - - blockManagerInfo(id) = - new BlockManagerInfo(id, time, maxMemSize, slaveActor) + logInfo("Registering block manager %s with %s RAM, %s".format( + id.hostPort, Utils.bytesToString(maxMemSize), id)) + + blockManagerIdByExecutor(id.executorId) = id + + blockManagerInfo(id) = new BlockManagerInfo( + id, System.currentTimeMillis(), maxMemSize, slaveActor) } listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize)) } @@ -411,6 +415,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus Seq.empty } } + + /** + * Returns the hostname and port of an executor's actor system, based on the Akka address of its + * BlockManagerSlaveActor. + */ + private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { + for ( + blockManagerId <- blockManagerIdByExecutor.get(executorId); + info <- blockManagerInfo.get(blockManagerId); + host <- info.slaveActor.path.address.host; + port <- info.slaveActor.path.address.port + ) yield { + (host, port) + } + } } @DeveloperApi @@ -457,16 +476,18 @@ private[spark] class BlockManagerInfo( if (_blocks.containsKey(blockId)) { // The block exists on the slave already. - val originalLevel: StorageLevel = _blocks.get(blockId).storageLevel + val blockStatus: BlockStatus = _blocks.get(blockId) + val originalLevel: StorageLevel = blockStatus.storageLevel + val originalMemSize: Long = blockStatus.memSize if (originalLevel.useMemory) { - _remainingMem += memSize + _remainingMem += originalMemSize } } if (storageLevel.isValid) { /* isValid means it is either stored in-memory, on-disk or on-Tachyon. - * But the memSize here indicates the data size in or dropped from memory, + * The memSize here indicates the data size in or dropped from memory, * tachyonSize here indicates the data size in or dropped from Tachyon, * and the diskSize here indicates the data size in or dropped to disk. * They can be both larger than 0, when a block is dropped from memory to disk. 
@@ -493,7 +514,6 @@ private[spark] class BlockManagerInfo( val blockStatus: BlockStatus = _blocks.get(blockId) _blocks.remove(blockId) if (blockStatus.storageLevel.useMemory) { - _remainingMem += blockStatus.memSize logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize), Utils.bytesToString(_remainingMem))) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 3db5dd9774ae8..3f32099d08cc9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -21,6 +21,8 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import akka.actor.ActorRef +import org.apache.spark.util.Utils + private[spark] object BlockManagerMessages { ////////////////////////////////////////////////////////////////////////////////// // Messages from the master to slaves. @@ -65,7 +67,7 @@ private[spark] object BlockManagerMessages { def this() = this(null, null, null, 0, 0, 0) // For deserialization only - override def writeExternal(out: ObjectOutput) { + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { blockManagerId.writeExternal(out) out.writeUTF(blockId.name) storageLevel.writeExternal(out) @@ -74,7 +76,7 @@ private[spark] object BlockManagerMessages { out.writeLong(tachyonSize) } - override def readExternal(in: ObjectInput) { + override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { blockManagerId = BlockManagerId(in) blockId = BlockId(in.readUTF()) storageLevel = StorageLevel(in) @@ -90,6 +92,8 @@ private[spark] object BlockManagerMessages { case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster + case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster + case class RemoveExecutor(execId: String) extends ToBlockManagerMaster case object StopBlockManagerMaster extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala index 9ef453605f4f1..81f5f2d31dbd8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala @@ -17,5 +17,4 @@ package org.apache.spark.storage - class BlockNotFoundException(blockId: String) extends Exception(s"Block $blockId not found") diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index a715594f198c2..58fba54710510 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -38,12 +38,13 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon extends Logging { private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - private val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64) + private[spark] + val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64) /* Create one local directory for each path mentioned in spark.local.dir; then, inside this * directory, create multiple subdirectories that we will hash files into, in order to avoid * having really large inodes at 
the top level. */ - val localDirs: Array[File] = createLocalDirs(conf) + private[spark] val localDirs: Array[File] = createLocalDirs(conf) if (localDirs.isEmpty) { logError("Failed to create any local dir.") System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) @@ -52,6 +53,9 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon addShutdownHook() + /** Looks up a file by hashing it into one of our local subdirectories. */ + // This method should be kept in sync with + // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getFile(). def getFile(filename: String): File = { // Figure out which local directory it hashes to, and which subdirectory in that val hash = Utils.nonNegativeHash(filename) @@ -98,11 +102,20 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon getAllFiles().map(f => BlockId(f.getName)) } - /** Produces a unique block id and File suitable for intermediate results. */ - def createTempBlock(): (TempBlockId, File) = { - var blockId = new TempBlockId(UUID.randomUUID()) + /** Produces a unique block id and File suitable for storing local intermediate results. */ + def createTempLocalBlock(): (TempLocalBlockId, File) = { + var blockId = new TempLocalBlockId(UUID.randomUUID()) while (getFile(blockId).exists()) { - blockId = new TempBlockId(UUID.randomUUID()) + blockId = new TempLocalBlockId(UUID.randomUUID()) + } + (blockId, getFile(blockId)) + } + + /** Produces a unique block id and File suitable for storing shuffled intermediate results. */ + def createTempShuffleBlock(): (TempShuffleBlockId, File) = { + var blockId = new TempShuffleBlockId(UUID.randomUUID()) + while (getFile(blockId).exists()) { + blockId = new TempShuffleBlockId(UUID.randomUUID()) } (blockId, getFile(blockId)) } @@ -140,7 +153,6 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon } private def addShutdownHook() { - localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir)) Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") { override def run(): Unit = Utils.logUncaughtExceptions { logDebug("Shutdown hook called") @@ -151,13 +163,16 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon /** Cleanup local dirs and stop shuffle sender. */ private[spark] def stop() { - localDirs.foreach { localDir => - if (localDir.isDirectory() && localDir.exists()) { - try { - if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) - } catch { - case e: Exception => - logError(s"Exception while deleting local spark dir: $localDir", e) + // Only perform cleanup if an external service is not serving our shuffle files. 
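The "kept in sync" note on getFile above refers to a two-level hashing scheme: a filename hashes to one of the configured local directories and then to one of spark.diskStore.subDirectories sub-directories. A standalone sketch of that arithmetic (the plain non-negative hash and the "%02x" sub-directory names are assumptions standing in for Spark's helpers):

```scala
// Map a block file name to rootDir/subDir deterministically, so both the executor and an
// external shuffle service can locate it given only the dir list and the sub-dir count.
def hashedPath(filename: String, localDirs: Array[String], subDirsPerLocalDir: Int): String = {
  val hash = filename.hashCode & Integer.MAX_VALUE           // assumed non-negative hash
  val dirId = hash % localDirs.length
  val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
  s"${localDirs(dirId)}/${"%02x".format(subDirId)}/$filename"
}

// hashedPath("shuffle_0_1_0.data", Array("/tmp/spark-local-a", "/tmp/spark-local-b"), 64)
```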
+ if (!blockManager.externalShuffleServiceEnabled) { + localDirs.foreach { localDir => + if (localDir.isDirectory() && localDir.exists()) { + try { + if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir) + } catch { + case e: Exception => + logError(s"Exception while deleting local spark dir: $localDir", e) + } } } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index e9304f6bb45d0..8dadf6794039e 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.{File, FileOutputStream, RandomAccessFile} +import java.io.{IOException, File, FileOutputStream, RandomAccessFile} import java.nio.ByteBuffer import java.nio.channels.FileChannel.MapMode @@ -73,7 +73,21 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc val startTime = System.currentTimeMillis val file = diskManager.getFile(blockId) val outputStream = new FileOutputStream(file) - blockManager.dataSerializeStream(blockId, outputStream, values) + try { + try { + blockManager.dataSerializeStream(blockId, outputStream, values) + } finally { + // Close outputStream here because it should be closed before file is deleted. + outputStream.close() + } + } catch { + case e: Throwable => + if (file.exists()) { + file.delete() + } + throw e + } + val length = file.length val timeTaken = System.currentTimeMillis - startTime @@ -96,7 +110,13 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc // For small files, directly read rather than memory map if (length < minMemoryMapBytes) { val buf = ByteBuffer.allocate(length.toInt) - channel.read(buf, offset) + channel.position(offset) + while (buf.remaining() != 0) { + if (channel.read(buf) == -1) { + throw new IOException("Reached EOF before filling buffer\n" + + s"offset=$offset\nfile=${file.getAbsolutePath}\nbuf.remaining=${buf.remaining}") + } + } buf.flip() Some(buf) } else { diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala index 0a09c24d61879..71305a46bf570 100644 --- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala @@ -56,6 +56,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) (maxMemory * unrollFraction).toLong } + // Initial memory to request before unrolling any block + private val unrollMemoryThreshold: Long = + conf.getLong("spark.storage.unrollMemoryThreshold", 1024 * 1024) + + if (maxMemory < unrollMemoryThreshold) { + logWarning(s"Max memory ${Utils.bytesToString(maxMemory)} is less than the initial memory " + + s"threshold ${Utils.bytesToString(unrollMemoryThreshold)} needed to store a block in " + + s"memory. Please configure Spark with more memory.") + } + logInfo("MemoryStore started with capacity %s".format(Utils.bytesToString(maxMemory))) /** Free memory not occupied by existing blocks. Note that this does not include unroll memory. */ @@ -132,8 +142,6 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) PutResult(res.size, res.data, droppedBlocks) case Right(iteratorValues) => // Not enough space to unroll this block; drop to disk if applicable - logWarning(s"Not enough space to store block $blockId in memory! 
" + - s"Free memory is $freeMemory bytes.") if (level.useDisk && allowPersistToDisk) { logWarning(s"Persisting block $blockId to disk instead.") val res = blockManager.diskStore.putIterator(blockId, iteratorValues, level, returnValues) @@ -215,7 +223,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Whether there is still enough memory for us to continue unrolling this block var keepUnrolling = true // Initial per-thread memory to request for unrolling blocks (bytes). Exposed for testing. - val initialMemoryThreshold = conf.getLong("spark.storage.unrollMemoryThreshold", 1024 * 1024) + val initialMemoryThreshold = unrollMemoryThreshold // How often to check whether we need to request more memory val memoryCheckPeriod = 16 // Memory currently reserved by this thread for this particular unrolling operation @@ -230,6 +238,11 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) // Request enough memory to begin unrolling keepUnrolling = reserveUnrollMemoryForThisThread(initialMemoryThreshold) + if (!keepUnrolling) { + logWarning(s"Failed to reserve initial memory threshold of " + + s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.") + } + // Unroll this block safely, checking whether we have exceeded our threshold periodically try { while (values.hasNext && keepUnrolling) { @@ -265,6 +278,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) Left(vector.toArray) } else { // We ran out of space while unrolling the values for this block + logUnrollFailureMessage(blockId, vector.estimateSize()) Right(vector.iterator ++ values) } @@ -424,7 +438,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) * Reserve additional memory for unrolling blocks used by this thread. * Return whether the request is granted. */ - private[spark] def reserveUnrollMemoryForThisThread(memory: Long): Boolean = { + def reserveUnrollMemoryForThisThread(memory: Long): Boolean = { accountingLock.synchronized { val granted = freeMemory > currentUnrollMemory + memory if (granted) { @@ -439,7 +453,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) * Release memory used by this thread for unrolling blocks. * If the amount is not specified, remove the current thread's allocation altogether. */ - private[spark] def releaseUnrollMemoryForThisThread(memory: Long = -1L): Unit = { + def releaseUnrollMemoryForThisThread(memory: Long = -1L): Unit = { val threadId = Thread.currentThread().getId accountingLock.synchronized { if (memory < 0) { @@ -457,16 +471,50 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long) /** * Return the amount of memory currently occupied for unrolling blocks across all threads. */ - private[spark] def currentUnrollMemory: Long = accountingLock.synchronized { + def currentUnrollMemory: Long = accountingLock.synchronized { unrollMemoryMap.values.sum } /** * Return the amount of memory currently occupied for unrolling blocks by this thread. */ - private[spark] def currentUnrollMemoryForThisThread: Long = accountingLock.synchronized { + def currentUnrollMemoryForThisThread: Long = accountingLock.synchronized { unrollMemoryMap.getOrElse(Thread.currentThread().getId, 0L) } + + /** + * Return the number of threads currently unrolling blocks. + */ + def numThreadsUnrolling: Int = accountingLock.synchronized { unrollMemoryMap.keys.size } + + /** + * Log information about current memory usage. 
+ */ + def logMemoryUsage(): Unit = { + val blocksMemory = currentMemory + val unrollMemory = currentUnrollMemory + val totalMemory = blocksMemory + unrollMemory + logInfo( + s"Memory use = ${Utils.bytesToString(blocksMemory)} (blocks) + " + + s"${Utils.bytesToString(unrollMemory)} (scratch space shared across " + + s"$numThreadsUnrolling thread(s)) = ${Utils.bytesToString(totalMemory)}. " + + s"Storage limit = ${Utils.bytesToString(maxMemory)}." + ) + } + + /** + * Log a warning for failing to unroll a block. + * + * @param blockId ID of the block we are trying to unroll. + * @param finalVectorSize Final size of the vector before unrolling failed. + */ + def logUnrollFailureMessage(blockId: BlockId, finalVectorSize: Long): Unit = { + logWarning( + s"Not enough space to cache $blockId in memory! " + + s"(computed ${Utils.bytesToString(finalVectorSize)} so far)" + ) + logMemoryUsage() + } } private[spark] case class ResultWithDroppedBlocks( diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 71b276b5f18e4..83170f7c5a4ab 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -19,15 +19,15 @@ package org.apache.spark.storage import java.util.concurrent.LinkedBlockingQueue -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashSet -import scala.collection.mutable.Queue +import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} +import scala.util.{Failure, Success, Try} -import org.apache.spark.{TaskContext, Logging} -import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener, BlockTransferService} +import org.apache.spark.{Logging, TaskContext} +import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.serializer.Serializer -import org.apache.spark.util.Utils - +import org.apache.spark.util.{CompletionIterator, Utils} /** * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block @@ -40,8 +40,8 @@ import org.apache.spark.util.Utils * using too much memory. * * @param context [[TaskContext]], used for metrics update - * @param blockTransferService [[BlockTransferService]] for fetching remote blocks - * @param blockManager [[BlockManager]] for reading local blocks + * @param shuffleClient [[ShuffleClient]] for fetching remote blocks + * @param blockManager [[BlockManager]] for reading local blocks * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in * order to throttle the memory usage. 
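The DiskStore change a few hunks above replaces a single channel.read call with a loop, because FileChannel.read may return fewer bytes than requested. A generic, self-contained sketch of that read-until-full loop:

```scala
import java.io.{IOException, RandomAccessFile}
import java.nio.ByteBuffer

// Read exactly `length` bytes starting at `offset`, failing loudly on a premature EOF
// instead of silently returning a partially filled buffer.
def readFully(path: String, offset: Long, length: Int): ByteBuffer = {
  val channel = new RandomAccessFile(path, "r").getChannel
  try {
    val buf = ByteBuffer.allocate(length)
    channel.position(offset)
    while (buf.remaining() != 0) {
      if (channel.read(buf) == -1) {
        throw new IOException(s"Reached EOF before reading $length bytes at offset $offset of $path")
      }
    }
    buf.flip()
    buf
  } finally {
    channel.close()
  }
}
```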
@@ -51,12 +51,12 @@ import org.apache.spark.util.Utils private[spark] final class ShuffleBlockFetcherIterator( context: TaskContext, - blockTransferService: BlockTransferService, + shuffleClient: ShuffleClient, blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], serializer: Serializer, maxBytesInFlight: Long) - extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging { + extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging { import ShuffleBlockFetcherIterator._ @@ -88,17 +88,53 @@ final class ShuffleBlockFetcherIterator( */ private[this] val results = new LinkedBlockingQueue[FetchResult] - // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that - // the number of bytes in flight is limited to maxBytesInFlight + /** + * Current [[FetchResult]] being processed. We track this so we can release the current buffer + * in case of a runtime exception when processing the current buffer. + */ + @volatile private[this] var currentResult: FetchResult = null + + /** + * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that + * the number of bytes in flight is limited to maxBytesInFlight. + */ private[this] val fetchRequests = new Queue[FetchRequest] - // Current bytes in flight from our requests + /** Current bytes in flight from our requests */ private[this] var bytesInFlight = 0L private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + /** + * Whether the iterator is still active. If isZombie is true, the callback interface will no + * longer place fetched blocks into [[results]]. + */ + @volatile private[this] var isZombie = false + initialize() + /** + * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. + */ + private[this] def cleanup() { + isZombie = true + // Release the current buffer if necessary + currentResult match { + case SuccessFetchResult(_, _, buf) => buf.release() + case _ => + } + + // Release buffers in the results queue + val iter = results.iterator() + while (iter.hasNext) { + val result = iter.next() + result match { + case SuccessFetchResult(_, _, buf) => buf.release() + case _ => + } + } + } + private[this] def sendRequest(req: FetchRequest) { logDebug("Sending request for %d blocks (%s) from %s".format( req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) @@ -108,26 +144,26 @@ final class ShuffleBlockFetcherIterator( val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap val blockIds = req.blocks.map(_._1.toString) - blockTransferService.fetchBlocks(req.address.host, req.address.port, blockIds, + val address = req.address + shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray, new BlockFetchingListener { - override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - results.put(new FetchResult(BlockId(blockId), sizeMap(blockId), - () => serializer.newInstance().deserializeStream( - blockManager.wrapForCompression(BlockId(blockId), data.inputStream())).asIterator - )) - shuffleMetrics.remoteBytesRead += data.size - shuffleMetrics.remoteBlocksFetched += 1 - logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) + override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = { + // Only add the buffer to results queue if the iterator is not zombie, + // i.e. cleanup() has not been called yet. 
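The retain/release comments in this iterator describe a reference-counting discipline: a buffer produced on the fetch callback thread is retained before it is queued, and released either by the consumer or by cleanup() if the iterator is abandoned. A minimal sketch of that discipline (hypothetical class, not Spark's ManagedBuffer):

```scala
import java.util.concurrent.atomic.AtomicInteger

// A buffer that frees its resources only when the last holder releases it.
final class RefCountedBuffer(val data: Array[Byte]) {
  private val refCount = new AtomicInteger(1)
  def retain(): this.type = { refCount.incrementAndGet(); this }
  def release(): Unit =
    if (refCount.decrementAndGet() == 0) {
      // last reference dropped: free the underlying memory / file handle here
    }
}

// producer thread:  queue.put(buf.retain())
// consumer thread:  use buf, then buf.release()   (cleanup() releases anything left queued)
```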
+ if (!isZombie) { + // Increment the ref count because we need to pass this to a different thread. + // This needs to be released after use. + buf.retain() + results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf)) + shuffleMetrics.remoteBytesRead += buf.size + shuffleMetrics.remoteBlocksFetched += 1 + } + logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime)) } - override def onBlockFetchFailure(e: Throwable): Unit = { + override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) - // Note that there is a chance that some blocks have been fetched successfully, but we - // still add them to the failed queue. This is fine because when the caller see a - // FetchFailedException, it is going to fail the entire task anyway. - for ((blockId, size) <- req.blocks) { - results.put(new FetchResult(blockId, -1, null)) - } + results.put(new FailureFetchResult(BlockId(blockId), e)) } } ) @@ -138,7 +174,7 @@ final class ShuffleBlockFetcherIterator( // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 // nodes, rather than blocking on reading output from one node. val targetRequestSize = math.max(maxBytesInFlight / 5, 1L) - logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize) + logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize) // Split local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. @@ -148,7 +184,7 @@ final class ShuffleBlockFetcherIterator( var totalBlocks = 0 for ((address, blockInfos) <- blocksByAddress) { totalBlocks += blockInfos.size - if (address == blockManager.blockManagerId) { + if (address.executorId == blockManager.blockManagerId.executorId) { // Filter out zero-sized blocks localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) numBlocksToFetch += localBlocks.size @@ -185,26 +221,34 @@ final class ShuffleBlockFetcherIterator( remoteRequests } + /** + * Fetch the local blocks while we are fetching remote blocks. This is ok because + * [[ManagedBuffer]]'s memory is allocated lazily when we create the input stream, so all we + * track in-memory are the ManagedBuffer references themselves. + */ private[this] def fetchLocalBlocks() { - // Get the local blocks while remote blocks are being fetched. Note that it's okay to do - // these all at once because they will just memory-map some files, so they won't consume - // any memory that might exceed our maxBytesInFlight - for (id <- localBlocks) { + val iter = localBlocks.iterator + while (iter.hasNext) { + val blockId = iter.next() try { + val buf = blockManager.getBlockData(blockId) shuffleMetrics.localBlocksFetched += 1 - results.put(new FetchResult( - id, 0, () => blockManager.getLocalShuffleFromDisk(id, serializer).get)) - logDebug("Got local block " + id) + buf.retain() + results.put(new SuccessFetchResult(blockId, 0, buf)) } catch { case e: Exception => + // If we see an exception, stop immediately. logError(s"Error occurred while fetching local blocks", e) - results.put(new FetchResult(id, -1, null)) + results.put(new FailureFetchResult(blockId, e)) return } } } private[this] def initialize(): Unit = { + // Add a task completion callback (called in both success case and failure case) to cleanup. 
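This iterator now hands each block to the caller as a Try rather than an Option (see unpackBlock in BlockStoreShuffleFetcher earlier in this patch), so the original exception can be rethrown with its cause intact. A small sketch of that pattern:

```scala
import scala.util.{Failure, Success, Try}

// A failed fetch carries the real exception; the caller rethrows it wrapped, keeping the cause.
def unpack(block: (String, Try[Iterator[Int]])): Iterator[Int] = block match {
  case (_, Success(values)) => values
  case (blockId, Failure(e)) =>
    throw new RuntimeException(s"Failed to fetch $blockId", e)
}

// unpack(("shuffle_0_1_2", Success(Iterator(1, 2, 3)))).sum  // => 6
```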
+ context.addTaskCompletionListener(_ => cleanup()) + // Split local and remote blocks. val remoteRequests = splitLocalRemoteBlocks() // Add the remote requests into our queue in a random order @@ -221,26 +265,44 @@ final class ShuffleBlockFetcherIterator( // Get Local Blocks fetchLocalBlocks() - logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms") + logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime)) } override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch - override def next(): (BlockId, Option[Iterator[Any]]) = { + override def next(): (BlockId, Try[Iterator[Any]]) = { numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() - val result = results.take() + currentResult = results.take() + val result = currentResult val stopFetchWait = System.currentTimeMillis() shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait) - if (!result.failed) { - bytesInFlight -= result.size + + result match { + case SuccessFetchResult(_, size, _) => bytesInFlight -= size + case _ => } // Send fetch requests up to maxBytesInFlight while (fetchRequests.nonEmpty && (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) { sendRequest(fetchRequests.dequeue()) } - (result.blockId, if (result.failed) None else Some(result.deserialize())) + + val iteratorTry: Try[Iterator[Any]] = result match { + case FailureFetchResult(_, e) => Failure(e) + case SuccessFetchResult(blockId, _, buf) => { + val is = blockManager.wrapForCompression(blockId, buf.createInputStream()) + val iter = serializer.newInstance().deserializeStream(is).asIterator + Success(CompletionIterator[Any, Iterator[Any]](iter, { + // Once the iterator is exhausted, release the buffer and set currentResult to null + // so we don't release it again in cleanup. + currentResult = null + buf.release() + })) + } + } + + (result.blockId, iteratorTry) } } @@ -254,18 +316,35 @@ object ShuffleBlockFetcherIterator { * @param blocks Sequence of tuple, where the first element is the block id, * and the second element is the estimated size, used to calculate bytesInFlight. */ - class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) { + case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) { val size = blocks.map(_._2).sum } /** - * Result of a fetch from a remote block. A failure is represented as size == -1. + * Result of a fetch from a remote block. + */ + private[storage] sealed trait FetchResult { + val blockId: BlockId + } + + /** + * Result of a fetch from a remote block successfully. * @param blockId block id * @param size estimated size of the block, used to calculate bytesInFlight. * Note that this is NOT the exact bytes. - * @param deserialize closure to return the result in the form of an Iterator. + * @param buf [[ManagedBuffer]] for the content. */ - class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) { - def failed: Boolean = size == -1 + private[storage] case class SuccessFetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer) + extends FetchResult { + require(buf != null) + require(size >= 0) } + + /** + * Result of a fetch from a remote block unsuccessfully. 
+ * @param blockId block id + * @param e the failure exception + */ + private[storage] case class FailureFetchResult(blockId: BlockId, e: Throwable) + extends FetchResult } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index 1e35abaab5353..56edc4fe2e4ad 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -20,6 +20,7 @@ package org.apache.spark.storage import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -97,12 +98,12 @@ class StorageLevel private( ret } - override def writeExternal(out: ObjectOutput) { + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { out.writeByte(toInt) out.writeByte(_replication) } - override def readExternal(in: ObjectInput) { + override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { val flags = in.readByte() _useDisk = (flags & 8) != 0 _useMemory = (flags & 4) != 0 diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index d9066f766476e..def49e80a3605 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import scala.collection.mutable +import org.apache.spark.SparkContext import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ @@ -59,10 +60,9 @@ class StorageStatusListener extends SparkListener { val info = taskEnd.taskInfo val metrics = taskEnd.taskMetrics if (info != null && metrics != null) { - val execId = formatExecutorId(info.executorId) val updatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]()) if (updatedBlocks.length > 0) { - updateStorageStatus(execId, updatedBlocks) + updateStorageStatus(info.executorId, updatedBlocks) } } } @@ -88,13 +88,4 @@ class StorageStatusListener extends SparkListener { } } - /** - * In the local mode, there is a discrepancy between the executor ID according to the - * task ("localhost") and that according to SparkEnv (""). In the UI, this - * results in duplicate rows for the same executor. Thus, in this mode, we aggregate - * these two rows and use the executor ID of "" to be consistent. 
- */ - def formatExecutorId(execId: String): String = { - if (execId == "localhost") "" else execId - } } diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index 6908a59a79e60..af873034215a9 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -148,6 +148,7 @@ private[spark] class TachyonBlockManager( logError("Exception while deleting tachyon spark dir: " + tachyonDir, e) } } + client.close() } }) } diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala index 932b5616043b4..233d1e2b7c616 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala @@ -20,6 +20,7 @@ package org.apache.spark.storage import java.io.IOException import java.nio.ByteBuffer +import com.google.common.io.ByteStreams import tachyon.client.{ReadType, WriteType} import org.apache.spark.Logging @@ -105,25 +106,19 @@ private[spark] class TachyonStore( return None } val is = file.getInStream(ReadType.CACHE) - var buffer: ByteBuffer = null + assert (is != null) try { - if (is != null) { - val size = file.length - val bs = new Array[Byte](size.asInstanceOf[Int]) - val fetchSize = is.read(bs, 0, size.asInstanceOf[Int]) - buffer = ByteBuffer.wrap(bs) - if (fetchSize != size) { - logWarning(s"Failed to fetch the block $blockId from Tachyon: Size $size " + - s"is not equal to fetched size $fetchSize") - return None - } - } + val size = file.length + val bs = new Array[Byte](size.asInstanceOf[Int]) + ByteStreams.readFully(is, bs) + Some(ByteBuffer.wrap(bs)) } catch { case ioe: IOException => logWarning(s"Failed to fetch the block $blockId from Tachyon", ioe) - return None + None + } finally { + is.close() } - Some(buffer) } override def contains(blockId: BlockId): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala new file mode 100644 index 0000000000000..27ba9e18237b5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import java.util.{Timer, TimerTask} + +import org.apache.spark._ + +/** + * ConsoleProgressBar shows the progress of stages in the next line of the console. It polls the + * status of active stages from `sc.statusTracker` periodically; the progress bar shows + * up after a stage has run for at least 500ms.
If multiple stages run in the same time, the status + * of them will be combined together, showed in one line. + */ +private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging { + + // Carrige return + val CR = '\r' + // Update period of progress bar, in milliseconds + val UPDATE_PERIOD = 200L + // Delay to show up a progress bar, in milliseconds + val FIRST_DELAY = 500L + + // The width of terminal + val TerminalWidth = if (!sys.env.getOrElse("COLUMNS", "").isEmpty) { + sys.env.get("COLUMNS").get.toInt + } else { + 80 + } + + var lastFinishTime = 0L + var lastUpdateTime = 0L + var lastProgressBar = "" + + // Schedule a refresh thread to run periodically + private val timer = new Timer("refresh progress", true) + timer.schedule(new TimerTask{ + override def run() { + refresh() + } + }, FIRST_DELAY, UPDATE_PERIOD) + + /** + * Try to refresh the progress bar in every cycle + */ + private def refresh(): Unit = synchronized { + val now = System.currentTimeMillis() + if (now - lastFinishTime < FIRST_DELAY) { + return + } + val stageIds = sc.statusTracker.getActiveStageIds() + val stages = stageIds.map(sc.statusTracker.getStageInfo).flatten.filter(_.numTasks() > 1) + .filter(now - _.submissionTime() > FIRST_DELAY).sortBy(_.stageId()) + if (stages.size > 0) { + show(now, stages.take(3)) // display at most 3 stages in same time + } + } + + /** + * Show progress bar in console. The progress bar is displayed in the next line + * after your last output, keeps overwriting itself to hold in one line. The logging will follow + * the progress bar, then progress bar will be showed in next line without overwrite logs. + */ + private def show(now: Long, stages: Seq[SparkStageInfo]) { + val width = TerminalWidth / stages.size + val bar = stages.map { s => + val total = s.numTasks() + val header = s"[Stage ${s.stageId()}:" + val tailer = s"(${s.numCompletedTasks()} + ${s.numActiveTasks()}) / $total]" + val w = width - header.size - tailer.size + val bar = if (w > 0) { + val percent = w * s.numCompletedTasks() / total + (0 until w).map { i => + if (i < percent) "=" else if (i == percent) ">" else " " + }.mkString("") + } else { + "" + } + header + bar + tailer + }.mkString("") + + // only refresh if it's changed of after 1 minute (or the ssh connection will be closed + // after idle some time) + if (bar != lastProgressBar || now - lastUpdateTime > 60 * 1000L) { + System.err.print(CR + bar) + lastUpdateTime = now + } + lastProgressBar = bar + } + + /** + * Clear the progress bar if showed. + */ + private def clear() { + if (!lastProgressBar.isEmpty) { + System.err.printf(CR + " " * TerminalWidth + CR) + lastProgressBar = "" + } + } + + /** + * Mark all the stages as finished, clear the progress bar if showed, then the progress will not + * interweave with output of jobs. 
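// Standalone sketch of the bar layout computed in ConsoleProgressBar.show() above: a fixed-width
// line built from a header, an "=====>" fill proportional to completed/total, and a tail carrying
// the task counts. Parameter names here are illustrative, not the class's actual fields.
def renderBarSketch(stageId: Int, completed: Int, active: Int, total: Int, width: Int): String = {
  val header = s"[Stage $stageId:"
  val tailer = s"($completed + $active) / $total]"
  val w = width - header.length - tailer.length
  val fill = if (w > 0) {
    val percent = w * completed / total
    (0 until w).map { i =>
      if (i < percent) "=" else if (i == percent) ">" else " "
    }.mkString("")
  } else {
    ""
  }
  header + fill + tailer
}
// e.g. renderBarSketch(3, completed = 5, active = 2, total = 10, width = 40)
// produces "[Stage 3:=========>        (5 + 2) / 10]", which the class keeps rewriting in place
// by printing a carriage return ('\r') before each refresh.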
+ */ + def finishAll(): Unit = synchronized { + clear() + lastFinishTime = System.currentTimeMillis() + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index cccd59d122a92..176907dffa46a 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -21,60 +21,46 @@ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.scheduler._ import org.apache.spark.storage.StorageStatusListener import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.ui.env.EnvironmentTab -import org.apache.spark.ui.exec.ExecutorsTab -import org.apache.spark.ui.jobs.JobProgressTab -import org.apache.spark.ui.storage.StorageTab +import org.apache.spark.ui.env.{EnvironmentListener, EnvironmentTab} +import org.apache.spark.ui.exec.{ExecutorsListener, ExecutorsTab} +import org.apache.spark.ui.jobs.{JobsTab, JobProgressListener, StagesTab} +import org.apache.spark.ui.storage.{StorageListener, StorageTab} /** * Top level user interface for a Spark application. */ -private[spark] class SparkUI( - val sc: SparkContext, +private[spark] class SparkUI private ( + val sc: Option[SparkContext], val conf: SparkConf, val securityManager: SecurityManager, - val listenerBus: SparkListenerBus, + val environmentListener: EnvironmentListener, + val storageStatusListener: StorageStatusListener, + val executorsListener: ExecutorsListener, + val jobProgressListener: JobProgressListener, + val storageListener: StorageListener, var appName: String, - val basePath: String = "") + val basePath: String) extends WebUI(securityManager, SparkUI.getUIPort(conf), conf, basePath, "SparkUI") with Logging { - def this(sc: SparkContext) = this(sc, sc.conf, sc.env.securityManager, sc.listenerBus, sc.appName) - def this(conf: SparkConf, listenerBus: SparkListenerBus, appName: String, basePath: String) = - this(null, conf, new SecurityManager(conf), listenerBus, appName, basePath) - - def this( - conf: SparkConf, - securityManager: SecurityManager, - listenerBus: SparkListenerBus, - appName: String, - basePath: String) = - this(null, conf, securityManager, listenerBus, appName, basePath) - - // If SparkContext is not provided, assume the associated application is not live - val live = sc != null - - // Maintain executor storage status through Spark events - val storageStatusListener = new StorageStatusListener - - initialize() + val killEnabled = sc.map(_.conf.getBoolean("spark.ui.killEnabled", true)).getOrElse(false) /** Initialize all components of the server. 
*/ def initialize() { - listenerBus.addListener(storageStatusListener) - val jobProgressTab = new JobProgressTab(this) - attachTab(jobProgressTab) + attachTab(new JobsTab(this)) + val stagesTab = new StagesTab(this) + attachTab(stagesTab) attachTab(new StorageTab(this)) attachTab(new EnvironmentTab(this)) attachTab(new ExecutorsTab(this)) attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) - attachHandler(createRedirectHandler("/", "/stages", basePath = basePath)) + attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath)) attachHandler( - createRedirectHandler("/stages/stage/kill", "/stages", jobProgressTab.handleKillRequest)) - if (live) { - sc.env.metricsSystem.getServletHandlers.foreach(attachHandler) - } + createRedirectHandler("/stages/stage/kill", "/stages", stagesTab.handleKillRequest)) + // If the UI is live, then serve + sc.foreach { _.env.metricsSystem.getServletHandlers.foreach(attachHandler) } } + initialize() def getAppName = appName @@ -83,11 +69,6 @@ private[spark] class SparkUI( appName = name } - /** Register the given listener with the listener bus. */ - def registerListener(listener: SparkListener) { - listenerBus.addListener(listener) - } - /** Stop the server behind this web interface. Only valid after bind(). */ override def stop() { super.stop() @@ -116,4 +97,60 @@ private[spark] object SparkUI { def getUIPort(conf: SparkConf): Int = { conf.getInt("spark.ui.port", SparkUI.DEFAULT_PORT) } + + def createLiveUI( + sc: SparkContext, + conf: SparkConf, + listenerBus: SparkListenerBus, + jobProgressListener: JobProgressListener, + securityManager: SecurityManager, + appName: String): SparkUI = { + create(Some(sc), conf, listenerBus, securityManager, appName, + jobProgressListener = Some(jobProgressListener)) + } + + def createHistoryUI( + conf: SparkConf, + listenerBus: SparkListenerBus, + securityManager: SecurityManager, + appName: String, + basePath: String): SparkUI = { + create(None, conf, listenerBus, securityManager, appName, basePath) + } + + /** + * Create a new Spark UI. + * + * @param sc optional SparkContext; this can be None when reconstituting a UI from event logs. + * @param jobProgressListener if supplied, this JobProgressListener will be used; otherwise, the + * web UI will create and register its own JobProgressListener. 
+ */ + private def create( + sc: Option[SparkContext], + conf: SparkConf, + listenerBus: SparkListenerBus, + securityManager: SecurityManager, + appName: String, + basePath: String = "", + jobProgressListener: Option[JobProgressListener] = None): SparkUI = { + + val _jobProgressListener: JobProgressListener = jobProgressListener.getOrElse { + val listener = new JobProgressListener(conf) + listenerBus.addListener(listener) + listener + } + + val environmentListener = new EnvironmentListener + val storageStatusListener = new StorageStatusListener + val executorsListener = new ExecutorsListener(storageStatusListener) + val storageListener = new StorageListener(storageStatusListener) + + listenerBus.addListener(environmentListener) + listenerBus.addListener(storageStatusListener) + listenerBus.addListener(executorsListener) + listenerBus.addListener(storageListener) + + new SparkUI(sc, conf, securityManager, environmentListener, storageStatusListener, + executorsListener, _jobProgressListener, storageListener, appName, basePath) + } } diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 9ced9b8107ebf..6f446c5a95a0a 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -24,11 +24,28 @@ private[spark] object ToolTips { scheduler delay is large, consider decreasing the size of tasks or decreasing the size of task results.""" + val TASK_DESERIALIZATION_TIME = + """Time spent deserializating the task closure on the executor.""" + val INPUT = "Bytes read from Hadoop or from Spark storage." + val OUTPUT = "Bytes written to Hadoop." + val SHUFFLE_WRITE = "Bytes written to disk in order to be read by a shuffle in a future stage." val SHUFFLE_READ = """Bytes read from remote executors. Typically less than shuffle write bytes because this does not include shuffle data read locally.""" + + val GETTING_RESULT_TIME = + """Time that the driver spends fetching task results from workers. If this is large, consider + decreasing the amount of data returned from each task.""" + + val RESULT_SERIALIZATION_TIME = + """Time spent serializing the task result on the executor before sending it back to the + driver.""" + + val GC_TIME = + """Time that the executor spent paused for Java garbage collection while the task was + running.""" } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index f0006b42aee4f..09079bbd43f6f 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -20,12 +20,14 @@ package org.apache.spark.ui import java.text.SimpleDateFormat import java.util.{Locale, Date} -import scala.xml.Node +import scala.xml.{Node, Text} + import org.apache.spark.Logging /** Utility functions for generating XML pages with spark content. */ private[spark] object UIUtils extends Logging { - val TABLE_CLASS = "table table-bordered table-striped table-condensed sortable" + val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed sortable" + val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped" // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. 
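// Sketch of the thread-local date-format pattern referenced in the comment above: each thread
// lazily gets its own SimpleDateFormat instance, since SimpleDateFormat itself is not thread-safe.
// This is the standard shape of such a ThreadLocal; the pattern string here is illustrative.
import java.text.SimpleDateFormat
import java.util.Date

val dateFormatSketch = new ThreadLocal[SimpleDateFormat]() {
  override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
}

def formatDateSketch(timestamp: Long): String = dateFormatSketch.get.format(new Date(timestamp))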
private val dateFormat = new ThreadLocal[SimpleDateFormat]() { @@ -159,6 +161,8 @@ private[spark] object UIUtils extends Logging { + + } /** Returns a spark page with correctly formatted headers */ @@ -166,14 +170,19 @@ private[spark] object UIUtils extends Logging { title: String, content: => Seq[Node], activeTab: SparkUITab, - refreshInterval: Option[Int] = None): Seq[Node] = { + refreshInterval: Option[Int] = None, + helpText: Option[String] = None): Seq[Node] = { val appName = activeTab.appName + val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..." val header = activeTab.headerTabs.map { tab =>
  • - {tab.name} + {tab.name}
  • } + val helpButton: Seq[Node] = helpText.map { helpText => + (?) + }.getOrElse(Seq.empty) @@ -187,7 +196,9 @@ private[spark] object UIUtils extends Logging { - +
    @@ -195,6 +206,7 @@ private[spark] object UIUtils extends Logging {

    {title} + {helpButton}

    @@ -216,8 +228,10 @@ private[spark] object UIUtils extends Logging {

    - + + + {title}

    @@ -233,35 +247,64 @@ private[spark] object UIUtils extends Logging { headers: Seq[String], generateDataRow: T => Seq[Node], data: Iterable[T], - fixedWidth: Boolean = false): Seq[Node] = { + fixedWidth: Boolean = false, + id: Option[String] = None, + headerClasses: Seq[String] = Seq.empty, + stripeRowsWithCss: Boolean = true): Seq[Node] = { - var listingTableClass = TABLE_CLASS - if (fixedWidth) { - listingTableClass += " table-fixed" - } + val listingTableClass = if (stripeRowsWithCss) TABLE_CLASS_STRIPED else TABLE_CLASS_NOT_STRIPED val colWidth = 100.toDouble / headers.size val colWidthAttr = if (fixedWidth) colWidth + "%" else "" - val headerRow: Seq[Node] = { - // if none of the headers have "\n" in them - if (headers.forall(!_.contains("\n"))) { - // represent header as simple text - headers.map(h => {h}) + + def getClass(index: Int): String = { + if (index < headerClasses.size) { + headerClasses(index) } else { - // represent header text as list while respecting "\n" - headers.map { case h => - -
      - { h.split("\n").map { case t =>
    • {t}
    • } } -
    - - } + "" + } + } + + val newlinesInHeader = headers.exists(_.contains("\n")) + def getHeaderContent(header: String): Seq[Node] = { + if (newlinesInHeader) { +
      + { header.split("\n").map { case t =>
    • {t}
    • } } +
    + } else { + Text(header) + } + } + + val headerRow: Seq[Node] = { + headers.view.zipWithIndex.map { x => + {getHeaderContent(x._1)} } } - +
    {headerRow} {data.map(r => generateDataRow(r))}
    } + + def makeProgressBar( + started: Int, + completed: Int, + failed: Int, + skipped:Int, + total: Int): Seq[Node] = { + val completeWidth = "width: %s%%".format((completed.toDouble/total)*100) + val startWidth = "width: %s%%".format((started.toDouble/total)*100) + +
    + + {completed}/{total} + { if (failed > 0) s"($failed failed)" } + { if (skipped > 0) s"($skipped skipped)" } + +
    +
    +
    + } } diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 5d88ca403a674..9be65a4a39a09 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -82,7 +82,7 @@ private[spark] abstract class WebUI( } /** Detach a handler from this UI. */ - def detachHandler(handler: ServletContextHandler) { + protected def detachHandler(handler: ServletContextHandler) { handlers -= handler serverInfo.foreach { info => info.rootHandler.removeHandler(handler) diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala index 0d158fbe638d3..f62260c6f6e1d 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala @@ -22,10 +22,8 @@ import org.apache.spark.scheduler._ import org.apache.spark.ui._ private[ui] class EnvironmentTab(parent: SparkUI) extends SparkUITab(parent, "environment") { - val listener = new EnvironmentListener - + val listener = parent.environmentListener attachPage(new EnvironmentPage(this)) - parent.registerListener(listener) } /** diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala new file mode 100644 index 0000000000000..c82730f524eb7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.exec + +import java.net.URLDecoder +import javax.servlet.http.HttpServletRequest + +import scala.util.Try +import scala.xml.{Text, Node} + +import org.apache.spark.ui.{UIUtils, WebUIPage} + +private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage("threadDump") { + + private val sc = parent.sc + + def render(request: HttpServletRequest): Seq[Node] = { + val executorId = Option(request.getParameter("executorId")).map { + executorId => + // Due to YARN-2844, "" in the url will be encoded to "%25253Cdriver%25253E" when + // running in yarn-cluster mode. `request.getParameter("executorId")` will return + // "%253Cdriver%253E". Therefore we need to decode it until we get the real id. 
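// Illustrative reduction of the double-decoding workaround described above (YARN-2844 causes the
// executor id to arrive percent-encoded more than once): keep URL-decoding until the value stops
// changing, i.e. until a fixpoint is reached. The patch does this with a while loop; the recursive
// form below is extracted here only to show the idea.
import java.net.URLDecoder

@annotation.tailrec
def decodeFully(s: String): String = {
  val decoded = URLDecoder.decode(s, "UTF-8")
  if (decoded == s) s else decodeFully(decoded)
}
// decodeFully("%25253Cdriver%25253E") == "<driver>"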
+ var id = executorId + var decodedId = URLDecoder.decode(id, "UTF-8") + while (id != decodedId) { + id = decodedId + decodedId = URLDecoder.decode(id, "UTF-8") + } + id + }.getOrElse { + return Text(s"Missing executorId parameter") + } + val time = System.currentTimeMillis() + val maybeThreadDump = sc.get.getExecutorThreadDump(executorId) + + val content = maybeThreadDump.map { threadDump => + val dumpRows = threadDump.map { thread => + + } + +
    +

    Updated at {UIUtils.formatDate(time)}

    + { + // scalastyle:off +

    + Expand All +

    +

    + // scalastyle:on + } +
    {dumpRows}
    +
    + }.getOrElse(Text("Error fetching thread dump")) + UIUtils.headerSparkPage(s"Thread dump for executor $executorId", content, parent) + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index b0e3bb3b552fd..363cb96de7998 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -17,6 +17,7 @@ package org.apache.spark.ui.exec +import java.net.URLEncoder import javax.servlet.http.HttpServletRequest import scala.xml.Node @@ -41,7 +42,10 @@ private case class ExecutorSummaryInfo( totalShuffleWrite: Long, maxMemory: Long) -private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { +private[ui] class ExecutorsPage( + parent: ExecutorsTab, + threadDumpEnabled: Boolean) + extends WebUIPage("") { private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { @@ -53,7 +57,7 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { val execInfoSorted = execInfo.sortBy(_.id) val execTable = - +
    @@ -75,6 +79,7 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { Shuffle Write + {if (threadDumpEnabled) else Seq.empty} {execInfoSorted.map(execRow)} @@ -133,6 +138,16 @@ private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { + { + if (threadDumpEnabled) { + val encodedId = URLEncoder.encode(info.id, "UTF-8") + + } else { + Seq.empty + } + } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 61eb111cd9100..dd1c2b78c4094 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -26,10 +26,15 @@ import org.apache.spark.storage.StorageStatusListener import org.apache.spark.ui.{SparkUI, SparkUITab} private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") { - val listener = new ExecutorsListener(parent.storageStatusListener) + val listener = parent.executorsListener + val sc = parent.sc + val threadDumpEnabled = + sc.isDefined && parent.conf.getBoolean("spark.ui.threadDumpsEnabled", true) - attachPage(new ExecutorsPage(this)) - parent.registerListener(listener) + attachPage(new ExecutorsPage(this, threadDumpEnabled)) + if (threadDumpEnabled) { + attachPage(new ExecutorThreadDumpPage(this)) + } } /** @@ -43,20 +48,21 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp val executorToTasksFailed = HashMap[String, Int]() val executorToDuration = HashMap[String, Long]() val executorToInputBytes = HashMap[String, Long]() + val executorToOutputBytes = HashMap[String, Long]() val executorToShuffleRead = HashMap[String, Long]() val executorToShuffleWrite = HashMap[String, Long]() def storageStatusList = storageStatusListener.storageStatusList override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { - val eid = formatExecutorId(taskStart.taskInfo.executorId) + val eid = taskStart.taskInfo.executorId executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 0) + 1 } override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { val info = taskEnd.taskInfo if (info != null) { - val eid = formatExecutorId(info.executorId) + val eid = info.executorId executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 1) - 1 executorToDuration(eid) = executorToDuration.getOrElse(eid, 0L) + info.duration taskEnd.reason match { @@ -73,6 +79,10 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp executorToInputBytes(eid) = executorToInputBytes.getOrElse(eid, 0L) + inputMetrics.bytesRead } + metrics.outputMetrics.foreach { outputMetrics => + executorToOutputBytes(eid) = + executorToOutputBytes.getOrElse(eid, 0L) + outputMetrics.bytesWritten + } metrics.shuffleReadMetrics.foreach { shuffleRead => executorToShuffleRead(eid) = executorToShuffleRead.getOrElse(eid, 0L) + shuffleRead.remoteBytesRead @@ -85,6 +95,4 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp } } - // This addresses executor ID inconsistencies in the local mode - private def formatExecutorId(execId: String) = storageStatusListener.formatExecutorId(execId) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala new file mode 100644 index 0000000000000..ea2d187a0e8e4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -0,0 
+1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +import scala.xml.{Node, NodeSeq} + +import javax.servlet.http.HttpServletRequest + +import org.apache.spark.JobExecutionStatus +import org.apache.spark.ui.{WebUIPage, UIUtils} +import org.apache.spark.ui.jobs.UIData.JobUIData + +/** Page showing list of all ongoing and recently finished jobs */ +private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { + private val startTime: Option[Long] = parent.sc.map(_.startTime) + private val listener = parent.listener + + private def jobsTable(jobs: Seq[JobUIData]): Seq[Node] = { + val someJobHasJobGroup = jobs.exists(_.jobGroup.isDefined) + + val columns: Seq[Node] = { + + + + + + + } + + def makeRow(job: JobUIData): Seq[Node] = { + val lastStageInfo = listener.stageIdToInfo.get(job.stageIds.max) + val lastStageData = lastStageInfo.flatMap { s => + listener.stageIdToData.get((s.stageId, s.attemptId)) + } + val isComplete = job.status == JobExecutionStatus.SUCCEEDED + val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)") + val lastStageDescription = lastStageData.flatMap(_.description).getOrElse("") + val duration: Option[Long] = { + job.startTime.map { start => + val end = job.endTime.getOrElse(System.currentTimeMillis()) + end - start + } + } + val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") + val formattedSubmissionTime = job.startTime.map(UIUtils.formatDate).getOrElse("Unknown") + val detailUrl = + "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), job.jobId) + + + + + + + + + } + +
    Executor ID Address Thread Dump
    {Utils.bytesToString(info.totalShuffleWrite)} + Thread Dump +
    {if (someJobHasJobGroup) "Job Id (Job Group)" else "Job Id"}
    Description
    Submitted
    Duration
    Stages: Succeeded/Total
    Tasks (for all stages): Succeeded/Total
    + {job.jobId} {job.jobGroup.map(id => s"($id)").getOrElse("")} + +
    {lastStageDescription}
    + {lastStageName} +
    + {formattedSubmissionTime} + {formattedDuration} + {job.completedStageIndices.size}/{job.stageIds.size - job.numSkippedStages} + {if (job.numFailedStages > 0) s"(${job.numFailedStages} failed)"} + {if (job.numSkippedStages > 0) s"(${job.numSkippedStages} skipped)"} + + {UIUtils.makeProgressBar(started = job.numActiveTasks, completed = job.numCompletedTasks, + failed = job.numFailedTasks, skipped = job.numSkippedTasks, + total = job.numTasks - job.numSkippedTasks)} +
    + {columns} + + {jobs.map(makeRow)} + +
    + } + + def render(request: HttpServletRequest): Seq[Node] = { + listener.synchronized { + val activeJobs = listener.activeJobs.values.toSeq + val completedJobs = listener.completedJobs.reverse.toSeq + val failedJobs = listener.failedJobs.reverse.toSeq + val now = System.currentTimeMillis + + val activeJobsTable = + jobsTable(activeJobs.sortBy(_.startTime.getOrElse(-1L)).reverse) + val completedJobsTable = + jobsTable(completedJobs.sortBy(_.endTime.getOrElse(-1L)).reverse) + val failedJobsTable = + jobsTable(failedJobs.sortBy(_.endTime.getOrElse(-1L)).reverse) + + val summary: NodeSeq = +
    +
      + {if (startTime.isDefined) { + // Total duration is not meaningful unless the UI is live +
    • + Total Duration: + {UIUtils.formatDuration(now - startTime.get)} +
    • + }} +
    • + Scheduling Mode: + {listener.schedulingMode.map(_.toString).getOrElse("Unknown")} +
    • +
    • + Active Jobs: + {activeJobs.size} +
    • +
    • + Completed Jobs: + {completedJobs.size} +
    • +
    • + Failed Jobs: + {failedJobs.size} +
    • +
    +
    + + val content = summary ++ +

    Active Jobs ({activeJobs.size})

    ++ activeJobsTable ++ +

    Completed Jobs ({completedJobs.size})

    ++ completedJobsTable ++ +

    Failed Jobs ({failedJobs.size})

    ++ failedJobsTable + + val helpText = """A job is triggered by a action, like "count()" or "saveAsTextFile()".""" + + " Click on a job's title to see information about the stages of tasks associated with" + + " the job." + + UIUtils.headerSparkPage("Spark Jobs", content, parent, helpText = Some(helpText)) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala new file mode 100644 index 0000000000000..b0f8ca2ab0d3f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +import javax.servlet.http.HttpServletRequest + +import scala.xml.{Node, NodeSeq} + +import org.apache.spark.scheduler.Schedulable +import org.apache.spark.ui.{WebUIPage, UIUtils} + +/** Page showing list of all ongoing and recently finished stages and pools */ +private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { + private val sc = parent.sc + private val listener = parent.listener + private def isFairScheduler = parent.isFairScheduler + + def render(request: HttpServletRequest): Seq[Node] = { + listener.synchronized { + val activeStages = listener.activeStages.values.toSeq + val completedStages = listener.completedStages.reverse.toSeq + val numCompletedStages = listener.numCompletedStages + val failedStages = listener.failedStages.reverse.toSeq + val numFailedStages = listener.numFailedStages + val now = System.currentTimeMillis + + val activeStagesTable = + new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, + parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler, + killEnabled = parent.killEnabled) + val completedStagesTable = + new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent.basePath, + parent.listener, isFairScheduler = parent.isFairScheduler, killEnabled = false) + val failedStagesTable = + new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent.basePath, + parent.listener, isFairScheduler = parent.isFairScheduler) + + // For now, pool information is only accessible in live UIs + val pools = sc.map(_.getAllPools).getOrElse(Seq.empty[Schedulable]) + val poolTable = new PoolTable(pools, parent) + + val summary: NodeSeq = +
    +
      + {if (sc.isDefined) { + // Total duration is not meaningful unless the UI is live +
    • + Total Duration: + {UIUtils.formatDuration(now - sc.get.startTime)} +
    • + }} +
    • + Scheduling Mode: + {listener.schedulingMode.map(_.toString).getOrElse("Unknown")} +
    • +
    • + Active Stages: + {activeStages.size} +
    • +
    • + Completed Stages: + {numCompletedStages} +
    • +
    • + Failed Stages: + {numFailedStages} +
    • +
    +
    + + val content = summary ++ + {if (sc.isDefined && isFairScheduler) { +

    {pools.size} Fair Scheduler Pools

    ++ poolTable.toNodeSeq + } else { + Seq[Node]() + }} ++ +

    Active Stages ({activeStages.size})

    ++ + activeStagesTable.toNodeSeq ++ +

    Completed Stages ({numCompletedStages})

    ++ + completedStagesTable.toNodeSeq ++ +

    Failed Stages ({numFailedStages})

    ++ + failedStagesTable.toNodeSeq + + UIUtils.headerSparkPage("Spark Stages (for all jobs)", content, parent) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 2987dc04494a5..9836d11a6d85f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -25,7 +25,7 @@ import org.apache.spark.ui.jobs.UIData.StageUIData import org.apache.spark.util.Utils /** Stage summary grouped by executors. */ -private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobProgressTab) { +private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: StagesTab) { private val listener = parent.listener def toNodeSeq: Seq[Node] = { @@ -36,7 +36,7 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobPr /** Special table which merges two header cells. */ private def executorTable[T](): Seq[Node] = { - +
    @@ -45,6 +45,7 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobPr + @@ -71,19 +72,21 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobPr - + - - + - - - } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala new file mode 100644 index 0000000000000..77d36209c6048 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +import scala.collection.mutable +import scala.xml.{NodeSeq, Node} + +import javax.servlet.http.HttpServletRequest + +import org.apache.spark.JobExecutionStatus +import org.apache.spark.scheduler.StageInfo +import org.apache.spark.ui.{UIUtils, WebUIPage} + +/** Page showing statistics and stage list for a given job */ +private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { + private val listener = parent.listener + + def render(request: HttpServletRequest): Seq[Node] = { + listener.synchronized { + val jobId = request.getParameter("id").toInt + val jobDataOption = listener.jobIdToData.get(jobId) + if (jobDataOption.isEmpty) { + val content = +
    +

    No information to display for job {jobId}

    +
    + return UIUtils.headerSparkPage( + s"Details for Job $jobId", content, parent) + } + val jobData = jobDataOption.get + val isComplete = jobData.status != JobExecutionStatus.RUNNING + val stages = jobData.stageIds.map { stageId => + // This could be empty if the JobProgressListener hasn't received information about the + // stage or if the stage information has been garbage collected + listener.stageIdToInfo.getOrElse(stageId, + new StageInfo(stageId, 0, "Unknown", 0, Seq.empty, "Unknown")) + } + + val activeStages = mutable.Buffer[StageInfo]() + val completedStages = mutable.Buffer[StageInfo]() + // If the job is completed, then any pending stages are displayed as "skipped": + val pendingOrSkippedStages = mutable.Buffer[StageInfo]() + val failedStages = mutable.Buffer[StageInfo]() + for (stage <- stages) { + if (stage.submissionTime.isEmpty) { + pendingOrSkippedStages += stage + } else if (stage.completionTime.isDefined) { + if (stage.failureReason.isDefined) { + failedStages += stage + } else { + completedStages += stage + } + } else { + activeStages += stage + } + } + + val activeStagesTable = + new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, + parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler, + killEnabled = parent.killEnabled) + val pendingOrSkippedStagesTable = + new StageTableBase(pendingOrSkippedStages.sortBy(_.stageId).reverse, + parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler, + killEnabled = false) + val completedStagesTable = + new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent.basePath, + parent.listener, isFairScheduler = parent.isFairScheduler, killEnabled = false) + val failedStagesTable = + new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent.basePath, + parent.listener, isFairScheduler = parent.isFairScheduler) + + val shouldShowActiveStages = activeStages.nonEmpty + val shouldShowPendingStages = !isComplete && pendingOrSkippedStages.nonEmpty + val shouldShowCompletedStages = completedStages.nonEmpty + val shouldShowSkippedStages = isComplete && pendingOrSkippedStages.nonEmpty + val shouldShowFailedStages = failedStages.nonEmpty + + val summary: NodeSeq = +
    +
      +
    • + Status: + {jobData.status} +
    • + { + if (jobData.jobGroup.isDefined) { +
    • + Job Group: + {jobData.jobGroup.get} +
    • + } + } + { + if (shouldShowActiveStages) { +
    • + Active Stages: + {activeStages.size} +
    • + } + } + { + if (shouldShowPendingStages) { +
    • + + Pending Stages: + {pendingOrSkippedStages.size} +
    • + } + } + { + if (shouldShowCompletedStages) { +
    • + Completed Stages: + {completedStages.size} +
    • + } + } + { + if (shouldShowSkippedStages) { +
    • + Skipped Stages: + {pendingOrSkippedStages.size} +
    • + } + } + { + if (shouldShowFailedStages) { +
    • + Failed Stages: + {failedStages.size} +
    • + } + } +
    +
    + + var content = summary + if (shouldShowActiveStages) { + content ++=

    Active Stages ({activeStages.size})

    ++ + activeStagesTable.toNodeSeq + } + if (shouldShowPendingStages) { + content ++=

    Pending Stages ({pendingOrSkippedStages.size})

    ++ + pendingOrSkippedStagesTable.toNodeSeq + } + if (shouldShowCompletedStages) { + content ++=

    Completed Stages ({completedStages.size})

    ++ + completedStagesTable.toNodeSeq + } + if (shouldShowSkippedStages) { + content ++=

    Skipped Stages ({pendingOrSkippedStages.size})

    ++ + pendingOrSkippedStagesTable.toNodeSeq + } + if (shouldShowFailedStages) { + content ++=

    Failed Stages ({failedStages.size})

    ++ + failedStagesTable.toNodeSeq + } + UIUtils.headerSparkPage(s"Details for Job $jobId", content, parent) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index eaeb861f59e5a..72935beb3a34a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -17,7 +17,7 @@ package org.apache.spark.ui.jobs -import scala.collection.mutable.{HashMap, ListBuffer} +import scala.collection.mutable.{HashMap, HashSet, ListBuffer} import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi @@ -40,29 +40,182 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { import JobProgressListener._ - // How many stages to remember - val retainedStages = conf.getInt("spark.ui.retainedStages", DEFAULT_RETAINED_STAGES) + // Define a handful of type aliases so that data structures' types can serve as documentation. + // These type aliases are public because they're used in the types of public fields: - // Map from stageId to StageInfo - val activeStages = new HashMap[Int, StageInfo] + type JobId = Int + type StageId = Int + type StageAttemptId = Int + type PoolName = String + type ExecutorId = String - // Map from (stageId, attemptId) to StageUIData - val stageIdToData = new HashMap[(Int, Int), StageUIData] + // Jobs: + val activeJobs = new HashMap[JobId, JobUIData] + val completedJobs = ListBuffer[JobUIData]() + val failedJobs = ListBuffer[JobUIData]() + val jobIdToData = new HashMap[JobId, JobUIData] + // Stages: + val activeStages = new HashMap[StageId, StageInfo] val completedStages = ListBuffer[StageInfo]() + val skippedStages = ListBuffer[StageInfo]() val failedStages = ListBuffer[StageInfo]() + val stageIdToData = new HashMap[(StageId, StageAttemptId), StageUIData] + val stageIdToInfo = new HashMap[StageId, StageInfo] + val stageIdToActiveJobIds = new HashMap[StageId, HashSet[JobId]] + val poolToActiveStages = HashMap[PoolName, HashMap[StageId, StageInfo]]() + // Total of completed and failed stages that have ever been run. These may be greater than + // `completedStages.size` and `failedStages.size` if we have run more stages or jobs than + // JobProgressListener's retention limits. + var numCompletedStages = 0 + var numFailedStages = 0 + + // Misc: + val executorIdToBlockManagerId = HashMap[ExecutorId, BlockManagerId]() + def blockManagerIds = executorIdToBlockManagerId.values.toSeq - // Map from pool name to a hash map (map from stage id to StageInfo). - val poolToActiveStages = HashMap[String, HashMap[Int, StageInfo]]() + var schedulingMode: Option[SchedulingMode] = None - val executorIdToBlockManagerId = HashMap[String, BlockManagerId]() + // To limit the total memory usage of JobProgressListener, we only track information for a fixed + // number of non-active jobs and stages (there is no limit for active jobs and stages): - var schedulingMode: Option[SchedulingMode] = None + val retainedStages = conf.getInt("spark.ui.retainedStages", DEFAULT_RETAINED_STAGES) + val retainedJobs = conf.getInt("spark.ui.retainedJobs", DEFAULT_RETAINED_JOBS) + + // We can test for memory leaks by ensuring that collections that track non-active jobs and + // stages do not grow without bound and that collections for active jobs/stages eventually become + // empty once Spark is idle. 
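// Small sketch of the type-alias convention adopted above: JobId, StageId and friends are simple
// aliases for Int or String at runtime, but they let the collection types read as documentation.
// Illustrative only.
import scala.collection.mutable

object AliasSketch {
  type JobId = Int
  type StageId = Int
  // Reads as "stage -> jobs that still depend on it" rather than HashMap[Int, HashSet[Int]]:
  val stageIdToActiveJobIds = new mutable.HashMap[StageId, mutable.HashSet[JobId]]()
}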
Let's partition our collections into ones that should be empty + // once Spark is idle and ones that should have a hard- or soft-limited sizes. + // These methods are used by unit tests, but they're defined here so that people don't forget to + // update the tests when adding new collections. Some collections have multiple levels of + // nesting, etc, so this lets us customize our notion of "size" for each structure: + + // These collections should all be empty once Spark is idle (no active stages / jobs): + private[spark] def getSizesOfActiveStateTrackingCollections: Map[String, Int] = { + Map( + "activeStages" -> activeStages.size, + "activeJobs" -> activeJobs.size, + "poolToActiveStages" -> poolToActiveStages.values.map(_.size).sum, + "stageIdToActiveJobIds" -> stageIdToActiveJobIds.values.map(_.size).sum + ) + } - def blockManagerIds = executorIdToBlockManagerId.values.toSeq + // These collections should stop growing once we have run at least `spark.ui.retainedStages` + // stages and `spark.ui.retainedJobs` jobs: + private[spark] def getSizesOfHardSizeLimitedCollections: Map[String, Int] = { + Map( + "completedJobs" -> completedJobs.size, + "failedJobs" -> failedJobs.size, + "completedStages" -> completedStages.size, + "skippedStages" -> skippedStages.size, + "failedStages" -> failedStages.size + ) + } + + // These collections may grow arbitrarily, but once Spark becomes idle they should shrink back to + // some bound based on the `spark.ui.retainedStages` and `spark.ui.retainedJobs` settings: + private[spark] def getSizesOfSoftSizeLimitedCollections: Map[String, Int] = { + Map( + "jobIdToData" -> jobIdToData.size, + "stageIdToData" -> stageIdToData.size, + "stageIdToStageInfo" -> stageIdToInfo.size + ) + } + + /** If stages is too large, remove and garbage collect old stages */ + private def trimStagesIfNecessary(stages: ListBuffer[StageInfo]) = synchronized { + if (stages.size > retainedStages) { + val toRemove = math.max(retainedStages / 10, 1) + stages.take(toRemove).foreach { s => + stageIdToData.remove((s.stageId, s.attemptId)) + stageIdToInfo.remove(s.stageId) + } + stages.trimStart(toRemove) + } + } + + /** If jobs is too large, remove and garbage collect old jobs */ + private def trimJobsIfNecessary(jobs: ListBuffer[JobUIData]) = synchronized { + if (jobs.size > retainedJobs) { + val toRemove = math.max(retainedJobs / 10, 1) + jobs.take(toRemove).foreach { job => + jobIdToData.remove(job.jobId) + } + jobs.trimStart(toRemove) + } + } + + override def onJobStart(jobStart: SparkListenerJobStart) = synchronized { + val jobGroup = for ( + props <- Option(jobStart.properties); + group <- Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) + ) yield group + val jobData: JobUIData = + new JobUIData( + jobId = jobStart.jobId, + startTime = Some(System.currentTimeMillis), + endTime = None, + stageIds = jobStart.stageIds, + jobGroup = jobGroup, + status = JobExecutionStatus.RUNNING) + // Compute (a potential underestimate of) the number of tasks that will be run by this job. + // This may be an underestimate because the job start event references all of the result + // stages's transitive stage dependencies, but some of these stages might be skipped if their + // output is available from earlier runs. + // See https://github.com/apache/spark/pull/3009 for a more extensive discussion. 
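// Generic sketch of the retention trimming used by trimStagesIfNecessary / trimJobsIfNecessary
// above: once a completed/failed buffer exceeds its retention limit, drop a batch of roughly one
// tenth of that limit in one go (amortizing the cost) and clean up any per-entry lookup maps
// alongside. Names below are illustrative.
import scala.collection.mutable.ListBuffer

def trimIfNecessarySketch[T](buf: ListBuffer[T], limit: Int)(cleanup: T => Unit): Unit = {
  if (buf.size > limit) {
    val toRemove = math.max(limit / 10, 1)   // trim in batches rather than one element at a time
    buf.take(toRemove).foreach(cleanup)      // e.g. remove the entry's UI data from side maps
    buf.trimStart(toRemove)
  }
}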
+ jobData.numTasks = { + val allStages = jobStart.stageInfos + val missingStages = allStages.filter(_.completionTime.isEmpty) + missingStages.map(_.numTasks).sum + } + jobIdToData(jobStart.jobId) = jobData + activeJobs(jobStart.jobId) = jobData + for (stageId <- jobStart.stageIds) { + stageIdToActiveJobIds.getOrElseUpdate(stageId, new HashSet[StageId]).add(jobStart.jobId) + } + // If there's no information for a stage, store the StageInfo received from the scheduler + // so that we can display stage descriptions for pending stages: + for (stageInfo <- jobStart.stageInfos) { + stageIdToInfo.getOrElseUpdate(stageInfo.stageId, stageInfo) + stageIdToData.getOrElseUpdate((stageInfo.stageId, stageInfo.attemptId), new StageUIData) + } + } + + override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized { + val jobData = activeJobs.remove(jobEnd.jobId).getOrElse { + logWarning(s"Job completed for unknown job ${jobEnd.jobId}") + new JobUIData(jobId = jobEnd.jobId) + } + jobData.endTime = Some(System.currentTimeMillis()) + jobEnd.jobResult match { + case JobSucceeded => + completedJobs += jobData + trimJobsIfNecessary(completedJobs) + jobData.status = JobExecutionStatus.SUCCEEDED + case JobFailed(exception) => + failedJobs += jobData + trimJobsIfNecessary(failedJobs) + jobData.status = JobExecutionStatus.FAILED + } + for (stageId <- jobData.stageIds) { + stageIdToActiveJobIds.get(stageId).foreach { jobsUsingStage => + jobsUsingStage.remove(jobEnd.jobId) + stageIdToInfo.get(stageId).foreach { stageInfo => + if (stageInfo.submissionTime.isEmpty) { + // if this stage is pending, it won't complete, so mark it as "skipped": + skippedStages += stageInfo + trimStagesIfNecessary(skippedStages) + jobData.numSkippedStages += 1 + jobData.numSkippedTasks += stageInfo.numTasks + } + } + } + } + } override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized { val stage = stageCompleted.stageInfo + stageIdToInfo(stage.stageId) = stage val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptId), { logWarning("Stage completed for unknown stage " + stage.stageId) new StageUIData @@ -78,19 +231,25 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { activeStages.remove(stage.stageId) if (stage.failureReason.isEmpty) { completedStages += stage - trimIfNecessary(completedStages) + numCompletedStages += 1 + trimStagesIfNecessary(completedStages) } else { failedStages += stage - trimIfNecessary(failedStages) + numFailedStages += 1 + trimStagesIfNecessary(failedStages) } - } - /** If stages is too large, remove and garbage collect old stages */ - private def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized { - if (stages.size > retainedStages) { - val toRemove = math.max(retainedStages / 10, 1) - stages.take(toRemove).foreach { s => stageIdToData.remove((s.stageId, s.attemptId)) } - stages.trimStart(toRemove) + for ( + activeJobsDependentOnStage <- stageIdToActiveJobIds.get(stage.stageId); + jobId <- activeJobsDependentOnStage; + jobData <- jobIdToData.get(jobId) + ) { + jobData.numActiveStages -= 1 + if (stage.failureReason.isEmpty) { + jobData.completedStageIndices.add(stage.stageId) + } else { + jobData.numFailedStages += 1 + } } } @@ -103,6 +262,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { p => p.getProperty("spark.scheduler.pool", DEFAULT_POOL_NAME) }.getOrElse(DEFAULT_POOL_NAME) + stageIdToInfo(stage.stageId) = stage val stageData = stageIdToData.getOrElseUpdate((stage.stageId, 
stage.attemptId), new StageUIData) stageData.schedulingPool = poolName @@ -112,6 +272,14 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashMap[Int, StageInfo]) stages(stage.stageId) = stage + + for ( + activeJobsDependentOnStage <- stageIdToActiveJobIds.get(stage.stageId); + jobId <- activeJobsDependentOnStage; + jobData <- jobIdToData.get(jobId) + ) { + jobData.numActiveStages += 1 + } } override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { @@ -124,6 +292,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData.numActiveTasks += 1 stageData.taskData.put(taskInfo.taskId, new TaskUIData(taskInfo)) } + for ( + activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskStart.stageId); + jobId <- activeJobsDependentOnStage; + jobData <- jobIdToData.get(jobId) + ) { + jobData.numActiveTasks += 1 + } } override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { @@ -181,6 +356,20 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { taskData.taskInfo = info taskData.taskMetrics = metrics taskData.errorMessage = errorMessage + + for ( + activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskEnd.stageId); + jobId <- activeJobsDependentOnStage; + jobData <- jobIdToData.get(jobId) + ) { + jobData.numActiveTasks -= 1 + taskEnd.reason match { + case Success => + jobData.numCompletedTasks += 1 + case _ => + jobData.numFailedTasks += 1 + } + } } } @@ -214,6 +403,12 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { stageData.inputBytes += inputBytesDelta execSummary.inputBytes += inputBytesDelta + val outputBytesDelta = + (taskMetrics.outputMetrics.map(_.bytesWritten).getOrElse(0L) + - oldMetrics.flatMap(_.outputMetrics).map(_.bytesWritten).getOrElse(0L)) + stageData.outputBytes += outputBytesDelta + execSummary.outputBytes += outputBytesDelta + val diskSpillDelta = taskMetrics.diskBytesSpilled - oldMetrics.map(_.diskBytesSpilled).getOrElse(0L) stageData.diskBytesSpilled += diskSpillDelta @@ -277,4 +472,5 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { private object JobProgressListener { val DEFAULT_POOL_NAME = "default" val DEFAULT_RETAINED_STAGES = 1000 + val DEFAULT_RETAINED_JOBS = 1000 } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala deleted file mode 100644 index a82f71ed08475..0000000000000 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.ui.jobs - -import javax.servlet.http.HttpServletRequest - -import scala.xml.{Node, NodeSeq} - -import org.apache.spark.scheduler.Schedulable -import org.apache.spark.ui.{WebUIPage, UIUtils} - -/** Page showing list of all ongoing and recently finished stages and pools */ -private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") { - private val live = parent.live - private val sc = parent.sc - private val listener = parent.listener - private lazy val isFairScheduler = parent.isFairScheduler - - def render(request: HttpServletRequest): Seq[Node] = { - listener.synchronized { - val activeStages = listener.activeStages.values.toSeq - val completedStages = listener.completedStages.reverse.toSeq - val failedStages = listener.failedStages.reverse.toSeq - val now = System.currentTimeMillis - - val activeStagesTable = - new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, - parent, parent.killEnabled) - val completedStagesTable = - new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent) - val failedStagesTable = - new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent) - - // For now, pool information is only accessible in live UIs - val pools = if (live) sc.getAllPools else Seq[Schedulable]() - val poolTable = new PoolTable(pools, parent) - - val summary: NodeSeq = -
    -
      - {if (live) { - // Total duration is not meaningful unless the UI is live -
    • - Total Duration: - {UIUtils.formatDuration(now - sc.startTime)} -
    • - }} -
    • - Scheduling Mode: - {listener.schedulingMode.map(_.toString).getOrElse("Unknown")} -
    • -
    • - Active Stages: - {activeStages.size} -
    • -
    • - Completed Stages: - {completedStages.size} -
    • -
    • - Failed Stages: - {failedStages.size} -
    • -
    -
    - - val content = summary ++ - {if (live && isFairScheduler) { -

    {pools.size} Fair Scheduler Pools

    ++ poolTable.toNodeSeq - } else { - Seq[Node]() - }} ++ -

    Active Stages ({activeStages.size})

    ++ - activeStagesTable.toNodeSeq ++ -

    Completed Stages ({completedStages.size})

    ++ - completedStagesTable.toNodeSeq ++ -

    Failed Stages ({failedStages.size})

    ++ - failedStagesTable.toNodeSeq - - UIUtils.headerSparkPage("Spark Stages", content, parent) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala deleted file mode 100644 index c16542c9db30f..0000000000000 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ui.jobs - -import javax.servlet.http.HttpServletRequest - -import org.apache.spark.SparkConf -import org.apache.spark.scheduler.SchedulingMode -import org.apache.spark.ui.{SparkUI, SparkUITab} - -/** Web UI showing progress status of all jobs in the given SparkContext. */ -private[ui] class JobProgressTab(parent: SparkUI) extends SparkUITab(parent, "stages") { - val live = parent.live - val sc = parent.sc - val conf = if (live) sc.conf else new SparkConf - val killEnabled = conf.getBoolean("spark.ui.killEnabled", true) - val listener = new JobProgressListener(conf) - - attachPage(new JobProgressPage(this)) - attachPage(new StagePage(this)) - attachPage(new PoolPage(this)) - parent.registerListener(listener) - - def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) - - def handleKillRequest(request: HttpServletRequest) = { - if ((killEnabled) && (parent.securityManager.checkModifyPermissions(request.getRemoteUser))) { - val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean - val stageId = Option(request.getParameter("id")).getOrElse("-1").toInt - if (stageId >= 0 && killFlag && listener.activeStages.contains(stageId)) { - sc.cancelStage(stageId) - } - // Do a quick pause here to give Spark time to kill the stage so it shows up as - // killed after the refresh. Note that this will block the serving thread so the - // time should be limited in duration. - Thread.sleep(100) - } - } - -} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala new file mode 100644 index 0000000000000..b2bbfdee56946 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +import org.apache.spark.scheduler.SchedulingMode +import org.apache.spark.ui.{SparkUI, SparkUITab} + +/** Web UI showing progress status of all jobs in the given SparkContext. */ +private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") { + val sc = parent.sc + val killEnabled = parent.killEnabled + def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) + val listener = parent.jobProgressListener + + attachPage(new AllJobsPage(this)) + attachPage(new JobPage(this)) +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index 7a6c7d1a497ed..5fc6cc7533150 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -25,8 +25,7 @@ import org.apache.spark.scheduler.{Schedulable, StageInfo} import org.apache.spark.ui.{WebUIPage, UIUtils} /** Page showing specific pool details */ -private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") { - private val live = parent.live +private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { private val sc = parent.sc private val listener = parent.listener @@ -38,11 +37,12 @@ private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") { case Some(s) => s.values.toSeq case None => Seq[StageInfo]() } - val activeStagesTable = - new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, parent) + val activeStagesTable = new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, + parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler, + killEnabled = parent.killEnabled) // For now, pool information is only accessible in live UIs - val pools = if (live) Seq(sc.getPoolForName(poolName).get) else Seq[Schedulable]() + val pools = sc.map(_.getPoolForName(poolName).get).toSeq val poolTable = new PoolTable(pools, parent) val content = diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala index 64178e1e33d41..df1899e7a9b84 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala @@ -24,7 +24,7 @@ import org.apache.spark.scheduler.{Schedulable, StageInfo} import org.apache.spark.ui.UIUtils /** Table showing list of pools */ -private[ui] class PoolTable(pools: Seq[Schedulable], parent: JobProgressTab) { +private[ui] class PoolTable(pools: Seq[Schedulable], parent: StagesTab) { private val listener = parent.listener def toNodeSeq: Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index db01be596e073..bfa54f8492068 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -22,13 +22,16 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, Unparsed} +import 
org.apache.commons.lang3.StringEscapeUtils + +import org.apache.spark.executor.TaskMetrics import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils} import org.apache.spark.ui.jobs.UIData._ import org.apache.spark.util.{Utils, Distribution} -import org.apache.spark.scheduler.AccumulableInfo +import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} /** Page showing statistics and task list for a given stage */ -private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { +private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { @@ -52,12 +55,13 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { val numCompleted = tasks.count(_.taskInfo.finished) val accumulables = listener.stageIdToData((stageId, stageAttemptId)).accumulables + val hasAccumulators = accumulables.size > 0 val hasInput = stageData.inputBytes > 0 + val hasOutput = stageData.outputBytes > 0 val hasShuffleRead = stageData.shuffleReadBytes > 0 val hasShuffleWrite = stageData.shuffleWriteBytes > 0 val hasBytesSpilled = stageData.memoryBytesSpilled > 0 && stageData.diskBytesSpilled > 0 - // scalastyle:off val summary =
      @@ -65,55 +69,125 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { Total task time across all tasks: {UIUtils.formatDuration(stageData.executorRunTime)} - {if (hasInput) + {if (hasInput) {
    • Input: {Utils.bytesToString(stageData.inputBytes)}
    • - } - {if (hasShuffleRead) + }} + {if (hasOutput) { +
    • + Output: + {Utils.bytesToString(stageData.outputBytes)} +
    • + }} + {if (hasShuffleRead) {
    • Shuffle read: {Utils.bytesToString(stageData.shuffleReadBytes)}
    • - } - {if (hasShuffleWrite) + }} + {if (hasShuffleWrite) {
    • Shuffle write: {Utils.bytesToString(stageData.shuffleWriteBytes)}
    • - } - {if (hasBytesSpilled) -
    • - Shuffle spill (memory): - {Utils.bytesToString(stageData.memoryBytesSpilled)} -
    • -
    • - Shuffle spill (disk): - {Utils.bytesToString(stageData.diskBytesSpilled)} -
    • - } + }} + {if (hasBytesSpilled) { +
    • + Shuffle spill (memory): + {Utils.bytesToString(stageData.memoryBytesSpilled)} +
    • +
    • + Shuffle spill (disk): + {Utils.bytesToString(stageData.diskBytesSpilled)} +
    • + }}
    - // scalastyle:on + + val showAdditionalMetrics = +
    + + + Show additional metrics + + +
    + val accumulableHeaders: Seq[String] = Seq("Accumulable", "Value") def accumulableRow(acc: AccumulableInfo) = val accumulableTable = UIUtils.listingTable(accumulableHeaders, accumulableRow, accumulables.values.toSeq) - val taskHeaders: Seq[String] = + val taskHeadersAndCssClasses: Seq[(String, String)] = Seq( - "Index", "ID", "Attempt", "Status", "Locality Level", "Executor", - "Launch Time", "Duration", "GC Time", "Accumulators") ++ - {if (hasInput) Seq("Input") else Nil} ++ - {if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++ - {if (hasShuffleWrite) Seq("Write Time", "Shuffle Write") else Nil} ++ - {if (hasBytesSpilled) Seq("Shuffle Spill (Memory)", "Shuffle Spill (Disk)") else Nil} ++ - Seq("Errors") + ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""), + ("Executor ID / Host", ""), ("Launch Time", ""), ("Duration", ""), + ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY), + ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), + ("GC Time", TaskDetailsClassNames.GC_TIME), + ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), + ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ + {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ + {if (hasInput) Seq(("Input", "")) else Nil} ++ + {if (hasOutput) Seq(("Output", "")) else Nil} ++ + {if (hasShuffleRead) Seq(("Shuffle Read", "")) else Nil} ++ + {if (hasShuffleWrite) Seq(("Write Time", ""), ("Shuffle Write", "")) else Nil} ++ + {if (hasBytesSpilled) Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) + else Nil} ++ + Seq(("Errors", "")) + + val unzipped = taskHeadersAndCssClasses.unzip val taskTable = UIUtils.listingTable( - taskHeaders, taskRow(hasInput, hasShuffleRead, hasShuffleWrite, hasBytesSpilled), tasks) - + unzipped._1, + taskRow(hasAccumulators, hasInput, hasOutput, hasShuffleRead, hasShuffleWrite, + hasBytesSpilled), + tasks, + headerClasses = unzipped._2) // Excludes tasks which failed and have incomplete metrics val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.taskMetrics.isDefined) @@ -122,18 +196,48 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { None } else { - val serializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => - metrics.get.resultSerializationTime.toDouble + def getFormattedTimeQuantiles(times: Seq[Double]): Seq[Node] = { + Distribution(times).get.getQuantiles().map { millis => + + } } - val serializationQuantiles = - +: Distribution(serializationTimes). 
- get.getQuantiles().map(ms => ) + + val deserializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.executorDeserializeTime.toDouble + } + val deserializationQuantiles = + +: getFormattedTimeQuantiles(deserializationTimes) val serviceTimes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.executorRunTime.toDouble } - val serviceQuantiles = +: Distribution(serviceTimes).get.getQuantiles() - .map(ms => ) + val serviceQuantiles = +: getFormattedTimeQuantiles(serviceTimes) + + val gcTimes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.jvmGCTime.toDouble + } + val gcQuantiles = + +: getFormattedTimeQuantiles(gcTimes) + + val serializationTimes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.resultSerializationTime.toDouble + } + val serializationQuantiles = + +: getFormattedTimeQuantiles(serializationTimes) val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) => if (info.gettingResultTime > 0) { @@ -142,76 +246,91 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { 0.0 } } - val gettingResultQuantiles = +: - Distribution(gettingResultTimes).get.getQuantiles().map { millis => - - } + val gettingResultQuantiles = + +: + getFormattedTimeQuantiles(gettingResultTimes) // The scheduler delay includes the network delay to send the task to the worker // machine and to send back the result (but not the time to fetch the task result, // if it needed to be fetched from the block manager on the worker). val schedulerDelays = validTasks.map { case TaskUIData(info, metrics, _) => - val totalExecutionTime = { - if (info.gettingResultTime > 0) { - (info.gettingResultTime - info.launchTime).toDouble - } else { - (info.finishTime - info.launchTime).toDouble - } - } - totalExecutionTime - metrics.get.executorRunTime + getSchedulerDelay(info, metrics.get).toDouble } val schedulerDelayTitle = + title={ToolTips.SCHEDULER_DELAY} data-placement="right">Scheduler Delay val schedulerDelayQuantiles = schedulerDelayTitle +: - Distribution(schedulerDelays).get.getQuantiles().map { millis => - - } + getFormattedTimeQuantiles(schedulerDelays) - def getQuantileCols(data: Seq[Double]) = + def getFormattedSizeQuantiles(data: Seq[Double]) = Distribution(data).get.getQuantiles().map(d => ) val inputSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.inputMetrics.map(_.bytesRead).getOrElse(0L).toDouble } - val inputQuantiles = +: getQuantileCols(inputSizes) + val inputQuantiles = +: getFormattedSizeQuantiles(inputSizes) + + val outputSizes = validTasks.map { case TaskUIData(_, metrics, _) => + metrics.get.outputMetrics.map(_.bytesWritten).getOrElse(0L).toDouble + } + val outputQuantiles = +: getFormattedSizeQuantiles(outputSizes) val shuffleReadSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble } val shuffleReadQuantiles = +: - getQuantileCols(shuffleReadSizes) + getFormattedSizeQuantiles(shuffleReadSizes) val shuffleWriteSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.shuffleWriteMetrics.map(_.shuffleBytesWritten).getOrElse(0L).toDouble } - val shuffleWriteQuantiles = +: getQuantileCols(shuffleWriteSizes) + val shuffleWriteQuantiles = +: + getFormattedSizeQuantiles(shuffleWriteSizes) val memoryBytesSpilledSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.memoryBytesSpilled.toDouble } val memoryBytesSpilledQuantiles = +: - 
getQuantileCols(memoryBytesSpilledSizes) + getFormattedSizeQuantiles(memoryBytesSpilledSizes) val diskBytesSpilledSizes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.diskBytesSpilled.toDouble } val diskBytesSpilledQuantiles = +: - getQuantileCols(diskBytesSpilledSizes) + getFormattedSizeQuantiles(diskBytesSpilledSizes) val listings: Seq[Seq[Node]] = Seq( - serializationQuantiles, - serviceQuantiles, - gettingResultQuantiles, - schedulerDelayQuantiles, - if (hasInput) inputQuantiles else Nil, - if (hasShuffleRead) shuffleReadQuantiles else Nil, - if (hasShuffleWrite) shuffleWriteQuantiles else Nil, - if (hasBytesSpilled) memoryBytesSpilledQuantiles else Nil, - if (hasBytesSpilled) diskBytesSpilledQuantiles else Nil) + {serviceQuantiles}, + {schedulerDelayQuantiles}, + + {deserializationQuantiles} + + {gcQuantiles}, + + {serializationQuantiles} + , + {gettingResultQuantiles}, + if (hasInput) {inputQuantiles} else Nil, + if (hasOutput) {outputQuantiles} else Nil, + if (hasShuffleRead) {shuffleReadQuantiles} else Nil, + if (hasShuffleWrite) {shuffleWriteQuantiles} else Nil, + if (hasBytesSpilled) {memoryBytesSpilledQuantiles} else Nil, + if (hasBytesSpilled) {diskBytesSpilledQuantiles} else Nil) val quantileHeaders = Seq("Metric", "Min", "25th percentile", "Median", "75th percentile", "Max") - def quantileRow(data: Seq[Node]): Seq[Node] = {data} - Some(UIUtils.listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true)) + // The summary table does not use CSS to stripe rows, which doesn't work with hidden + // rows (instead, JavaScript in table.js is used to stripe the non-hidden rows). + Some(UIUtils.listingTable( + quantileHeaders, + identity[Seq[Node]], + listings, + fixedWidth = true, + id = Some("task-summary-table"), + stripeRowsWithCss = false)) } val executorTable = new ExecutorTable(stageId, stageAttemptId, parent) @@ -221,6 +340,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { val content = summary ++ + showAdditionalMetrics ++

      <h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++
      <div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++
      <h4>Aggregated Metrics by Executor</h4>

    ++ executorTable.toNodeSeq ++ @@ -232,7 +352,9 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { } def taskRow( + hasAccumulators: Boolean, hasInput: Boolean, + hasOutput: Boolean, hasShuffleRead: Boolean, hasShuffleWrite: Boolean, hasBytesSpilled: Boolean)(taskData: TaskUIData): Seq[Node] = { @@ -241,8 +363,14 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { else metrics.map(_.executorRunTime).getOrElse(1L) val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration) else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("") + val schedulerDelay = metrics.map(getSchedulerDelay(info, _)).getOrElse(0L) val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) + val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) + val gettingResultTime = info.gettingResultTime + + val maybeAccumulators = info.accumulables + val accumulatorsReadable = maybeAccumulators.map{acc => s"${acc.name}: ${acc.update.get}"} val maybeInput = metrics.flatMap(_.inputMetrics) val inputSortable = maybeInput.map(_.bytesRead.toString).getOrElse("") @@ -250,6 +378,12 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { .map(m => s"${Utils.bytesToString(m.bytesRead)} (${m.readMethod.toString.toLowerCase()})") .getOrElse("") + val maybeOutput = metrics.flatMap(_.outputMetrics) + val outputSortable = maybeOutput.map(_.bytesWritten.toString).getOrElse("") + val outputReadable = maybeOutput + .map(m => s"${Utils.bytesToString(m.bytesWritten)}") + .getOrElse("") + val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead) val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("") val shuffleReadReadable = maybeShuffleRead.map(Utils.bytesToString).getOrElse("") @@ -282,30 +416,45 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { } - + - + + - - + {if (hasAccumulators) { + + }} {if (hasInput) { }} + {if (hasOutput) { + + }} {if (hasShuffleRead) { }} - + {errorMessageCell(errorMessage)} } } + + private def errorMessageCell(errorMessage: Option[String]): Seq[Node] = { + val error = errorMessage.getOrElse("") + val isMultiline = error.indexOf('\n') >= 0 + // Display the first line by default + val errorSummary = StringEscapeUtils.escapeHtml4( + if (isMultiline) { + error.substring(0, error.indexOf('\n')) + } else { + error + }) + val details = if (isMultiline) { + // scalastyle:off + + +details + ++ + + // scalastyle:on + } else { + "" + } + + } + + private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = { + val totalExecutionTime = { + if (info.gettingResultTime > 0) { + (info.gettingResultTime - info.launchTime) + } else { + (info.finishTime - info.launchTime) + } + } + val executorOverhead = (metrics.executorDeserializeTime + + metrics.resultSerializationTime) + totalExecutionTime - metrics.executorRunTime - executorOverhead + } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index 2e67310594784..e7d6244dcd679 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -22,6 +22,8 @@ import scala.xml.Text import java.util.Date +import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.scheduler.StageInfo 
import org.apache.spark.ui.{ToolTips, UIUtils} import org.apache.spark.util.Utils @@ -29,11 +31,10 @@ import org.apache.spark.util.Utils /** Page showing list of all ongoing and recently finished stages */ private[ui] class StageTableBase( stages: Seq[StageInfo], - parent: JobProgressTab, - killEnabled: Boolean = false) { - - private val listener = parent.listener - protected def isFairScheduler = parent.isFairScheduler + basePath: String, + listener: JobProgressListener, + isFairScheduler: Boolean, + killEnabled: Boolean) { protected def columns: Seq[Node] = { ++ @@ -43,6 +44,7 @@ private[ui] class StageTableBase( + - +
    Executor ID AddressFailed Tasks Succeeded Tasks InputOutput Shuffle Read Shuffle Write Shuffle Spill (Memory)
    {k} {executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")}{UIUtils.formatDuration(v.taskTime)}{UIUtils.formatDuration(v.taskTime)} {v.failedTasks + v.succeededTasks} {v.failedTasks} {v.succeededTasks} + {Utils.bytesToString(v.inputBytes)} + + {Utils.bytesToString(v.outputBytes)} {Utils.bytesToString(v.shuffleRead)} + {Utils.bytesToString(v.shuffleWrite)} + {Utils.bytesToString(v.memoryBytesSpilled)} + {Utils.bytesToString(v.diskBytesSpilled)}
    {acc.name}{acc.value}
    {UIUtils.formatDuration(millis.toLong)}Result serialization time{UIUtils.formatDuration(ms.toLong)} + + Task Deserialization Time + + Duration{UIUtils.formatDuration(ms.toLong)}Duration + GC Time + + + + Result Serialization Time + + Time spent fetching task results{UIUtils.formatDuration(millis.toLong)} + + Getting Result Time + + Scheduler delay{UIUtils.formatDuration(millis.toLong)}{Utils.bytesToString(d.toLong)}InputInputOutputShuffle Read (Remote)Shuffle WriteShuffle WriteShuffle spill (memory)Shuffle spill (disk)
    {info.status} {info.taskLocality}{info.host}{info.executorId} / {info.host} {UIUtils.formatDate(new Date(info.launchTime))} {formatDuration} + + {UIUtils.formatDuration(schedulerDelay.toLong)} + + {UIUtils.formatDuration(taskDeserializationTime.toLong)} + {if (gcTime > 0) UIUtils.formatDuration(gcTime) else ""} - {Unparsed( - info.accumulables.map{acc => s"${acc.name}: ${acc.update.get}"}.mkString("
    ") - )} +
    + {UIUtils.formatDuration(serializationTime)} + {Unparsed(accumulatorsReadable.mkString("
    "))} +
    {inputReadable} + {outputReadable} + {shuffleReadReadable} @@ -327,10 +476,47 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { {diskBytesSpilledReadable} - {errorMessage.map { e =>
    {e}
    }.getOrElse("")} -
    {errorSummary}{details}Stage IdDuration Tasks: Succeeded/Total InputOutput Shuffle Read - - - - - - - - - - - - - - - build.dir - ${user.dir}/build - - - - build.dir.hive - ${build.dir}/hive - - - - hadoop.tmp.dir - ${build.dir.hive}/test/hadoop-${user.name} - A base for other temporary directories. - - - - - - hive.exec.scratchdir - ${build.dir}/scratchdir - Scratch space for Hive jobs - - - - hive.exec.local.scratchdir - ${build.dir}/localscratchdir/ - Local scratch space for Hive jobs - - - - javax.jdo.option.ConnectionURL - - jdbc:derby:;databaseName=../build/test/junit_metastore_db;create=true - - - - javax.jdo.option.ConnectionDriverName - org.apache.derby.jdbc.EmbeddedDriver - - - - javax.jdo.option.ConnectionUserName - APP - - - - javax.jdo.option.ConnectionPassword - mine - - - - - hive.metastore.warehouse.dir - ${test.warehouse.dir} - - - - - hive.metastore.metadb.dir - ${build.dir}/test/data/metadb/ - - Required by metastore server or if the uris argument below is not supplied - - - - - test.log.dir - ${build.dir}/test/logs - - - - - test.src.dir - ${build.dir}/src/test - - - - - - - hive.jar.path - ${build.dir.hive}/ql/hive-exec-${version}.jar - - - - - hive.metastore.rawstore.impl - org.apache.hadoop.hive.metastore.ObjectStore - Name of the class that implements org.apache.hadoop.hive.metastore.rawstore interface. This class is used to store and retrieval of raw metadata objects such as table, database - - - - hive.querylog.location - ${build.dir}/tmp - Location of the structured hive logs - - - - - - hive.task.progress - false - Track progress of a task - - - - hive.support.concurrency - false - Whether hive supports concurrency or not. A zookeeper instance must be up and running for the default hive lock manager to support read-write locks. - - - - fs.pfile.impl - org.apache.hadoop.fs.ProxyLocalFileSystem - A proxy for local file system used for cross file system testing - - - - hive.exec.mode.local.auto - false - - Let hive determine whether to run in local mode automatically - Disabling this for tests so that minimr is not affected - - - - - hive.auto.convert.join - false - Whether Hive enable the optimization about converting common join into mapjoin based on the input file size - - - - hive.ignore.mapjoin.hint - false - Whether Hive ignores the mapjoin hint - - - - hive.input.format - org.apache.hadoop.hive.ql.io.CombineHiveInputFormat - The default input format, if it is not specified, the system assigns it. It is set to HiveInputFormat for hadoop versions 17, 18 and 19, whereas it is set to CombineHiveInputFormat for hadoop 20. The user can always overwrite it - if there is a bug in CombineHiveInputFormat, it can always be manually set to HiveInputFormat. - - - - hive.default.rcfile.serde - org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe - The default SerDe hive will use for the rcfile format - - - diff --git a/dev/change-version-to-2.10.sh b/dev/change-version-to-2.10.sh new file mode 100755 index 0000000000000..7473c20d28e09 --- /dev/null +++ b/dev/change-version-to-2.10.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +find . -name 'pom.xml' | grep -v target \ + | xargs -I {} sed -i -e 's|\(artifactId.*\)_2.11|\1_2.10|g' {} diff --git a/dev/change-version-to-2.11.sh b/dev/change-version-to-2.11.sh new file mode 100755 index 0000000000000..3957a9f3ba258 --- /dev/null +++ b/dev/change-version-to-2.11.sh @@ -0,0 +1,21 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +find . -name 'pom.xml' | grep -v target \ + | xargs -I {} sed -i -e 's|\(artifactId.*\)_2.10|\1_2.11|g' {} diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 281e8d4de6d71..e0aca467ac949 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -27,13 +27,20 @@ # Would be nice to add: # - Send output to stderr and have useful logging in stdout -GIT_USERNAME=${GIT_USERNAME:-pwendell} -GIT_PASSWORD=${GIT_PASSWORD:-XXX} +# Note: The following variables must be set before use! +ASF_USERNAME=${ASF_USERNAME:-pwendell} +ASF_PASSWORD=${ASF_PASSWORD:-XXX} GPG_PASSPHRASE=${GPG_PASSPHRASE:-XXX} GIT_BRANCH=${GIT_BRANCH:-branch-1.0} -RELEASE_VERSION=${RELEASE_VERSION:-1.0.0} +RELEASE_VERSION=${RELEASE_VERSION:-1.2.0} +NEXT_VERSION=${NEXT_VERSION:-1.2.1} RC_NAME=${RC_NAME:-rc2} -USER_NAME=${USER_NAME:-pwendell} + +M2_REPO=~/.m2/repository +SPARK_REPO=$M2_REPO/org/apache/spark +NEXUS_ROOT=https://repository.apache.org/service/local/staging +NEXUS_UPLOAD=$NEXUS_ROOT/deploy/maven2 +NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads if [ -z "$JAVA_HOME" ]; then echo "Error: JAVA_HOME is not set, cannot proceed." @@ -46,31 +53,90 @@ set -e GIT_TAG=v$RELEASE_VERSION-$RC_NAME if [[ ! 
"$@" =~ --package-only ]]; then - echo "Creating and publishing release" + echo "Creating release commit and publishing to Apache repository" # Artifact publishing - git clone https://git-wip-us.apache.org/repos/asf/spark.git -b $GIT_BRANCH - cd spark + git clone https://$ASF_USERNAME:$ASF_PASSWORD@git-wip-us.apache.org/repos/asf/spark.git \ + -b $GIT_BRANCH + pushd spark export MAVEN_OPTS="-Xmx3g -XX:MaxPermSize=1g -XX:ReservedCodeCacheSize=1g" - mvn -Pyarn release:clean - - mvn -DskipTests \ - -Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \ - -Dusername=$GIT_USERNAME -Dpassword=$GIT_PASSWORD \ - -Dmaven.javadoc.skip=true \ - -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ - -Dtag=$GIT_TAG -DautoVersionSubmodules=true \ - -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ - --batch-mode release:prepare - - mvn -DskipTests \ - -Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \ - -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ - -Dmaven.javadoc.skip=true \ + # Create release commits and push them to github + # NOTE: This is done "eagerly" i.e. we don't check if we can succesfully build + # or before we coin the release commit. This helps avoid races where + # other people add commits to this branch while we are in the middle of building. + old=" ${RELEASE_VERSION}-SNAPSHOT<\/version>" + new=" ${RELEASE_VERSION}<\/version>" + find . -name pom.xml -o -name package.scala | grep -v dev | xargs -I {} sed -i \ + -e "s/$old/$new/" {} + git commit -a -m "Preparing Spark release $GIT_TAG" + echo "Creating tag $GIT_TAG at the head of $GIT_BRANCH" + git tag $GIT_TAG + + old=" ${RELEASE_VERSION}<\/version>" + new=" ${NEXT_VERSION}-SNAPSHOT<\/version>" + find . -name pom.xml -o -name package.scala | grep -v dev | xargs -I {} sed -i \ + -e "s/$old/$new/" {} + git commit -a -m "Preparing development version ${NEXT_VERSION}-SNAPSHOT" + git push origin $GIT_TAG + git push origin HEAD:$GIT_BRANCH + git checkout -f $GIT_TAG + + # Using Nexus API documented here: + # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API + echo "Creating Nexus staging repository" + repo_request="Apache Spark $GIT_TAG" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start) + staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/") + echo "Created Nexus staging repository: $staged_repo_id" + + rm -rf $SPARK_REPO + + mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ - release:perform + clean install - cd .. + ./dev/change-version-to-2.11.sh + + mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ + -Dscala-2.11 -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ + clean install + + ./dev/change-version-to-2.10.sh + + pushd $SPARK_REPO + + # Remove any extra files generated during install + find . -type f |grep -v \.jar |grep -v \.pom | xargs rm + + echo "Creating hash and signature files" + for file in $(find . -type f) + do + echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --output $file.asc --detach-sig --armour $file; + gpg --print-md MD5 $file > $file.md5; + gpg --print-md SHA1 $file > $file.sha1 + done + + echo "Uplading files to $NEXUS_UPLOAD" + for file in $(find . 
-type f) + do + # strip leading ./ + file_short=$(echo $file | sed -e "s/\.\///") + dest_url="$NEXUS_UPLOAD/org/apache/spark/$file_short" + echo " Uploading $file_short" + curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url + done + + echo "Closing nexus staging repository" + repo_request="$staged_repo_idApache Spark $GIT_TAG" + out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \ + -H "Content-Type:application/xml" -v \ + $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish) + echo "Closed Nexus staging repository: $staged_repo_id" + + popd + popd rm -rf spark fi @@ -101,7 +167,13 @@ make_binary_release() { cp -r spark spark-$RELEASE_VERSION-bin-$NAME cd spark-$RELEASE_VERSION-bin-$NAME - ./make-distribution.sh --name $NAME --tgz $FLAGS + + # TODO There should probably be a flag to make-distribution to allow 2.11 support + if [[ $FLAGS == *scala-2.11* ]]; then + ./dev/change-version-to-2.11.sh + fi + + ./make-distribution.sh --name $NAME --tgz $FLAGS 2>&1 | tee ../binary-release-$NAME.log cd .. cp spark-$RELEASE_VERSION-bin-$NAME/spark-$RELEASE_VERSION-bin-$NAME.tgz . rm -rf spark-$RELEASE_VERSION-bin-$NAME @@ -117,22 +189,24 @@ make_binary_release() { spark-$RELEASE_VERSION-bin-$NAME.tgz.sha } -make_binary_release "hadoop1" "-Phive -Dhadoop.version=1.0.4" & -make_binary_release "cdh4" "-Phive -Dhadoop.version=2.0.0-mr1-cdh4.2.0" & -make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Pyarn" & -make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Pyarn" & + +make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" & +make_binary_release "hadoop1-scala2.11" "-Phive -Dscala-2.11" & +make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" & +make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" & +make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" & +make_binary_release "mapr3" "-Pmapr3 -Phive -Phive-thriftserver" & +make_binary_release "mapr4" "-Pmapr4 -Pyarn -Phive -Phive-thriftserver" & make_binary_release "hadoop2.4-without-hive" "-Phadoop-2.4 -Pyarn" & -make_binary_release "mapr3" "-Pmapr3 -Phive" & -make_binary_release "mapr4" "-Pmapr4 -Pyarn -Phive" & wait # Copy data echo "Copying release tarballs" rc_folder=spark-$RELEASE_VERSION-$RC_NAME -ssh $USER_NAME@people.apache.org \ - mkdir /home/$USER_NAME/public_html/$rc_folder +ssh $ASF_USERNAME@people.apache.org \ + mkdir /home/$ASF_USERNAME/public_html/$rc_folder scp spark-* \ - $USER_NAME@people.apache.org:/home/$USER_NAME/public_html/$rc_folder/ + $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_folder/ # Docs cd spark @@ -142,12 +216,12 @@ cd docs JAVA_HOME=$JAVA_7_HOME PRODUCTION=1 jekyll build echo "Copying release documentation" rc_docs_folder=${rc_folder}-docs -ssh $USER_NAME@people.apache.org \ - mkdir /home/$USER_NAME/public_html/$rc_docs_folder -rsync -r _site/* $USER_NAME@people.apache.org:/home/$USER_NAME/public_html/$rc_docs_folder +ssh $ASF_USERNAME@people.apache.org \ + mkdir /home/$ASF_USERNAME/public_html/$rc_docs_folder +rsync -r _site/* $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_docs_folder echo "Release $RELEASE_VERSION completed:" echo "Git tag:\t $GIT_TAG" echo "Release commit:\t $release_hash" -echo "Binary location:\t http://people.apache.org/~$USER_NAME/$rc_folder" -echo "Doc location:\t http://people.apache.org/~$USER_NAME/$rc_docs_folder" +echo "Binary location:\t http://people.apache.org/~$ASF_USERNAME/$rc_folder" +echo 
"Doc location:\t http://people.apache.org/~$ASF_USERNAME/$rc_docs_folder" diff --git a/dev/create-release/generate-contributors.py b/dev/create-release/generate-contributors.py new file mode 100755 index 0000000000000..f4bf734081583 --- /dev/null +++ b/dev/create-release/generate-contributors.py @@ -0,0 +1,206 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This script automates the process of creating release notes. + +import os +import re +import sys + +from releaseutils import * + +# You must set the following before use! +JIRA_API_BASE = os.environ.get("JIRA_API_BASE", "https://issues.apache.org/jira") +START_COMMIT = os.environ.get("START_COMMIT", "37b100") +END_COMMIT = os.environ.get("END_COMMIT", "3693ae") + +try: + from jira.client import JIRA +except ImportError: + print "This tool requires the jira-python library" + print "Install using 'sudo pip install jira-python'" + sys.exit(-1) + +try: + import unidecode +except ImportError: + print "This tool requires the unidecode library to decode obscure github usernames" + print "Install using 'sudo pip install unidecode'" + sys.exit(-1) + +# If commit range is not specified, prompt the user to provide it +if not START_COMMIT or not END_COMMIT: + print "A commit range is required to proceed." + if not START_COMMIT: + START_COMMIT = raw_input("Please specify starting commit hash (inclusive): ") + if not END_COMMIT: + END_COMMIT = raw_input("Please specify ending commit hash (non-inclusive): ") + +# Verify provided arguments +start_commit_line = get_one_line(START_COMMIT) +end_commit_line = get_one_line(END_COMMIT) +num_commits = num_commits_in_range(START_COMMIT, END_COMMIT) +if not start_commit_line: sys.exit("Start commit %s not found!" % START_COMMIT) +if not end_commit_line: sys.exit("End commit %s not found!" % END_COMMIT) +if num_commits == 0: + sys.exit("There are no commits in the provided range [%s, %s)" % (START_COMMIT, END_COMMIT)) +print "\n==================================================================================" +print "JIRA server: %s" % JIRA_API_BASE +print "Start commit (inclusive): %s" % start_commit_line +print "End commit (non-inclusive): %s" % end_commit_line +print "Number of commits in this range: %s" % num_commits +print +response = raw_input("Is this correct? 
[Y/n] ") +if response.lower() != "y" and response: + sys.exit("Ok, exiting") +print "==================================================================================\n" + +# Find all commits within this range +print "Gathering commits within range [%s..%s)" % (START_COMMIT, END_COMMIT) +commits = get_one_line_commits(START_COMMIT, END_COMMIT) +if not commits: sys.exit("Error: No commits found within this range!") +commits = commits.split("\n") + +# Filter out special commits +releases = [] +reverts = [] +nojiras = [] +filtered_commits = [] +def is_release(commit): + return re.findall("\[release\]", commit.lower()) or\ + "maven-release-plugin" in commit or "CHANGES.txt" in commit +def has_no_jira(commit): + return not re.findall("SPARK-[0-9]+", commit.upper()) +def is_revert(commit): + return "revert" in commit.lower() +def is_docs(commit): + return re.findall("docs*", commit.lower()) or "programming guide" in commit.lower() +for c in commits: + if not c: continue + elif is_release(c): releases.append(c) + elif is_revert(c): reverts.append(c) + elif is_docs(c): filtered_commits.append(c) # docs may not have JIRA numbers + elif has_no_jira(c): nojiras.append(c) + else: filtered_commits.append(c) + +# Warn against ignored commits +def print_indented(_list): + for x in _list: print " %s" % x +if releases or reverts or nojiras: + print "\n==================================================================================" + if releases: print "Releases (%d)" % len(releases); print_indented(releases) + if reverts: print "Reverts (%d)" % len(reverts); print_indented(reverts) + if nojiras: print "No JIRA (%d)" % len(nojiras); print_indented(nojiras) + print "==================== Warning: the above commits will be ignored ==================\n" +response = raw_input("%d commits left to process. Ok to proceed? 
[y/N] " % len(filtered_commits)) +if response.lower() != "y": + sys.exit("Ok, exiting.") + +# Keep track of warnings to tell the user at the end +warnings = [] + +# Populate a map that groups issues and components by author +# It takes the form: Author name -> { Contribution type -> Spark components } +# For instance, +# { +# 'Andrew Or': { +# 'bug fixes': ['windows', 'core', 'web ui'], +# 'improvements': ['core'] +# }, +# 'Tathagata Das' : { +# 'bug fixes': ['streaming'] +# 'new feature': ['streaming'] +# } +# } +# +author_info = {} +jira_options = { "server": JIRA_API_BASE } +jira = JIRA(jira_options) +print "\n=========================== Compiling contributor list ===========================" +for commit in filtered_commits: + commit_hash = re.findall("^[a-z0-9]+", commit)[0] + issues = re.findall("SPARK-[0-9]+", commit.upper()) + author = get_author(commit_hash) + author = unidecode.unidecode(unicode(author, "UTF-8")) # guard against special characters + date = get_date(commit_hash) + # Parse components from the commit message, if any + commit_components = find_components(commit, commit_hash) + # Populate or merge an issue into author_info[author] + def populate(issue_type, components): + components = components or [CORE_COMPONENT] # assume core if no components provided + if author not in author_info: + author_info[author] = {} + if issue_type not in author_info[author]: + author_info[author][issue_type] = set() + for component in all_components: + author_info[author][issue_type].add(component) + # Find issues and components associated with this commit + for issue in issues: + jira_issue = jira.issue(issue) + jira_type = jira_issue.fields.issuetype.name + jira_type = translate_issue_type(jira_type, issue, warnings) + jira_components = [translate_component(c.name, commit_hash, warnings)\ + for c in jira_issue.fields.components] + all_components = set(jira_components + commit_components) + populate(jira_type, all_components) + # For docs without an associated JIRA, manually add it ourselves + if is_docs(commit) and not issues: + populate("documentation", commit_components) + print " Processed commit %s authored by %s on %s" % (commit_hash, author, date) +print "==================================================================================\n" + +# Write to contributors file ordered by author names +# Each line takes the format "Author name - semi-colon delimited contributions" +# e.g. Andrew Or - Bug fixes in Windows, Core, and Web UI; improvements in Core +# e.g. Tathagata Das - Bug fixes and new features in Streaming +contributors_file_name = "contributors.txt" +contributors_file = open(contributors_file_name, "w") +authors = author_info.keys() +authors.sort() +for author in authors: + contribution = "" + components = set() + issue_types = set() + for issue_type, comps in author_info[author].items(): + components.update(comps) + issue_types.add(issue_type) + # If there is only one component, mention it only once + # e.g. Bug fixes, improvements in MLlib + if len(components) == 1: + contribution = "%s in %s" % (nice_join(issue_types), next(iter(components))) + # Otherwise, group contributions by issue types instead of modules + # e.g. 
Bug fixes in MLlib, Core, and Streaming; documentation in YARN + else: + contributions = ["%s in %s" % (issue_type, nice_join(comps)) \ + for issue_type, comps in author_info[author].items()] + contribution = "; ".join(contributions) + # Do not use python's capitalize() on the whole string to preserve case + assert contribution + contribution = contribution[0].capitalize() + contribution[1:] + line = "%s - %s" % (author, contribution) + contributors_file.write(line + "\n") +contributors_file.close() +print "Contributors list is successfully written to %s!" % contributors_file_name + +# Log any warnings encountered in the process +if warnings: + print "\n============ Warnings encountered while creating the contributor list ============" + for w in warnings: print w + print "Please correct these in the final contributors list at %s." % contributors_file_name + print "==================================================================================\n" + diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py new file mode 100755 index 0000000000000..e56d7fa58fa2c --- /dev/null +++ b/dev/create-release/releaseutils.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This file contains helper methods used in creating a release. 
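For orientation, here is a minimal Python 2 sketch (matching the scripts in this patch) of how the author_info map built by generate-contributors.py above turns into one summary line per contributor. The nice_join below is a simplified stand-in for the helper defined later in this file, and the sample entries are invented purely for illustration.

    # Sketch only: mirrors the documented author_info layout
    # (author name -> contribution type -> Spark components); the data is made up.
    author_info = {
        "Andrew Or": {"bug fixes": set(["Windows", "Core", "Web UI"]),
                      "improvements": set(["Core"])},
        "Tathagata Das": {"bug fixes": set(["Streaming"]),
                          "new features": set(["Streaming"])},
    }

    def nice_join(items):
        # Simplified stand-in for releaseutils.nice_join:
        # "Juice" / "Juice and baby" / "Juice, baby, and moon"
        items = sorted(items)
        if not items:
            return ""
        elif len(items) == 1:
            return items[0]
        elif len(items) == 2:
            return " and ".join(items)
        else:
            return ", ".join(items[:-1]) + ", and " + items[-1]

    for author in sorted(author_info.keys()):
        contributions = author_info[author]
        components = set()
        for comps in contributions.values():
            components.update(comps)
        if len(components) == 1:
            # Single component: mention it once, e.g. "Bug fixes and new features in Streaming"
            line = "%s in %s" % (nice_join(contributions.keys()), next(iter(components)))
        else:
            # Otherwise group by contribution type, e.g. "Bug fixes in Core and Windows; ..."
            line = "; ".join("%s in %s" % (t, nice_join(c))
                             for t, c in contributions.items())
        print "%s - %s" % (author, line[0].capitalize() + line[1:])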
+ +import re +from subprocess import Popen, PIPE + +# Utility functions run git commands (written with Git 1.8.5) +def run_cmd(cmd): return Popen(cmd, stdout=PIPE).communicate()[0] +def get_author(commit_hash): + return run_cmd(["git", "show", "--quiet", "--pretty=format:%an", commit_hash]) +def get_date(commit_hash): + return run_cmd(["git", "show", "--quiet", "--pretty=format:%cd", commit_hash]) +def get_one_line(commit_hash): + return run_cmd(["git", "show", "--quiet", "--pretty=format:\"%h %cd %s\"", commit_hash]) +def get_one_line_commits(start_hash, end_hash): + return run_cmd(["git", "log", "--oneline", "%s..%s" % (start_hash, end_hash)]) +def num_commits_in_range(start_hash, end_hash): + output = run_cmd(["git", "log", "--oneline", "%s..%s" % (start_hash, end_hash)]) + lines = [line for line in output.split("\n") if line] # filter out empty lines + return len(lines) + +# Maintain a mapping for translating issue types to contributions in the release notes +# This serves an additional function of warning the user against unknown issue types +# Note: This list is partially derived from this link: +# https://issues.apache.org/jira/plugins/servlet/project-config/SPARK/issuetypes +# Keep these in lower case +known_issue_types = { + "bug": "bug fixes", + "build": "build fixes", + "improvement": "improvements", + "new feature": "new features", + "documentation": "documentation" +} + +# Maintain a mapping for translating component names when creating the release notes +# This serves an additional function of warning the user against unknown components +# Note: This list is largely derived from this link: +# https://issues.apache.org/jira/plugins/servlet/project-config/SPARK/components +CORE_COMPONENT = "Core" +known_components = { + "block manager": CORE_COMPONENT, + "build": CORE_COMPONENT, + "deploy": CORE_COMPONENT, + "documentation": CORE_COMPONENT, + "ec2": "EC2", + "examples": CORE_COMPONENT, + "graphx": "GraphX", + "input/output": CORE_COMPONENT, + "java api": "Java API", + "mesos": "Mesos", + "ml": "MLlib", + "mllib": "MLlib", + "project infra": "Project Infra", + "pyspark": "PySpark", + "shuffle": "Shuffle", + "spark core": CORE_COMPONENT, + "spark shell": CORE_COMPONENT, + "sql": "SQL", + "streaming": "Streaming", + "web ui": "Web UI", + "windows": "Windows", + "yarn": "YARN" +} + +# Translate issue types using a format appropriate for writing contributions +# If an unknown issue type is encountered, warn the user +def translate_issue_type(issue_type, issue_id, warnings): + issue_type = issue_type.lower() + if issue_type in known_issue_types: + return known_issue_types[issue_type] + else: + warnings.append("Unknown issue type \"%s\" (see %s)" % (issue_type, issue_id)) + return issue_type + +# Translate component names using a format appropriate for writing contributions +# If an unknown component is encountered, warn the user +def translate_component(component, commit_hash, warnings): + component = component.lower() + if component in known_components: + return known_components[component] + else: + warnings.append("Unknown component \"%s\" (see %s)" % (component, commit_hash)) + return component + +# Parse components in the commit message +# The returned components are already filtered and translated +def find_components(commit, commit_hash): + components = re.findall("\[\w*\]", commit.lower()) + components = [translate_component(c, commit_hash)\ + for c in components if c in known_components] + return components + +# Join a list of strings in a human-readable manner +# e.g. 
["Juice"] -> "Juice" +# e.g. ["Juice", "baby"] -> "Juice and baby" +# e.g. ["Juice", "baby", "moon"] -> "Juice, baby, and moon" +def nice_join(str_list): + str_list = list(str_list) # sometimes it's a set + if not str_list: + return "" + elif len(str_list) == 1: + return next(iter(str_list)) + elif len(str_list) == 2: + return " and ".join(str_list) + else: + return ", ".join(str_list[:-1]) + ", and " + str_list[-1] + diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index a8e92e36fe0d8..02ac20984add9 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -73,11 +73,10 @@ def fail(msg): def run_cmd(cmd): + print cmd if isinstance(cmd, list): - print " ".join(cmd) return subprocess.check_output(cmd) else: - print cmd return subprocess.check_output(cmd.split(" ")) diff --git a/dev/run-tests b/dev/run-tests index c3d8f49cdd993..328a73bd8b26d 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -24,6 +24,16 @@ cd "$FWDIR" # Remove work directory rm -rf ./work +source "$FWDIR/dev/run-tests-codes.sh" + +CURRENT_BLOCK=$BLOCK_GENERAL + +function handle_error () { + echo "[error] Got a return code of $? on line $1 of the run-tests script." + exit $CURRENT_BLOCK +} + + # Build against the right verison of Hadoop. { if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then @@ -32,7 +42,7 @@ rm -rf ./work elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.0" ]; then export SBT_MAVEN_PROFILES_ARGS="-Dhadoop.version=2.0.0-mr1-cdh4.1.1" elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.2" ]; then - export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Dhadoop.version=2.2.0" + export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0" elif [ "$AMPLAB_JENKINS_BUILD_PROFILE" = "hadoop2.3" ]; then export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" fi @@ -91,26 +101,34 @@ if [ -n "$AMPLAB_JENKINS" ]; then fi fi -# Fail fast -set -e set -o pipefail +trap 'handle_error $LINENO' ERR echo "" echo "=========================================================================" echo "Running Apache RAT checks" echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_RAT + ./dev/check-license echo "" echo "=========================================================================" echo "Running Scala style checks" echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_SCALA_STYLE + ./dev/lint-scala echo "" echo "=========================================================================" echo "Running Python style checks" echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_PYTHON_STYLE + ./dev/lint-python echo "" @@ -118,21 +136,29 @@ echo "=========================================================================" echo "Building Spark" echo "=========================================================================" -{ - # We always build with Hive because the PySpark Spark SQL tests need it. - BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive" +CURRENT_BLOCK=$BLOCK_BUILD - echo "[info] Building Spark with these arguments: $BUILD_MVN_PROFILE_ARGS" +{ # NOTE: echo "q" is needed because sbt on encountering a build file with failure #+ (either resolution or compilation) prompts the user for input either q, r, etc #+ to quit or retry. This echo is there to make it not block. 
- # NOTE: Do not quote $BUILD_MVN_PROFILE_ARGS or else it will be interpreted as a + # NOTE: Do not quote $BUILD_MVN_PROFILE_ARGS or else it will be interpreted as a #+ single argument! # QUESTION: Why doesn't 'yes "q"' work? # QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work? + # First build with 0.12 to ensure patches do not break the hive 12 build + HIVE_12_BUILD_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver -Phive-0.12.0" + echo "[info] Compile with hive 0.12" + echo -e "q\n" \ + | sbt/sbt $HIVE_12_BUILD_ARGS clean hive/compile hive-thriftserver/compile \ + | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" + + # Then build with default version(0.13.1) because tests are based on this version + echo "[info] Building Spark with these arguments: $SBT_MAVEN_PROFILES_ARGS"\ + " -Phive -Phive-thriftserver" echo -e "q\n" \ - | sbt/sbt $BUILD_MVN_PROFILE_ARGS clean package assembly/assembly \ + | sbt/sbt $SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver package assembly/assembly \ | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" } @@ -141,17 +167,19 @@ echo "=========================================================================" echo "Running Spark unit tests" echo "=========================================================================" +CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS + { # If the Spark SQL tests are enabled, run the tests with the Hive profiles enabled. # This must be a single argument, as it is. if [ -n "$_RUN_SQL_TESTS" ]; then - SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive" + SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver" fi if [ -n "$_SQL_TESTS_ONLY" ]; then # This must be an array of individual arguments. Otherwise, having one long string #+ will be interpreted as a single test, which doesn't work. - SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "hive-thriftserver/test") + SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "mllib/test") else SBT_MAVEN_TEST_ARGS=("test") fi @@ -175,10 +203,16 @@ echo "" echo "=========================================================================" echo "Running PySpark tests" echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS + ./python/run-tests echo "" echo "=========================================================================" echo "Detecting binary incompatibilites with MiMa" echo "=========================================================================" + +CURRENT_BLOCK=$BLOCK_MIMA + ./dev/mima diff --git a/dev/run-tests-codes.sh b/dev/run-tests-codes.sh new file mode 100644 index 0000000000000..1348e0609dda4 --- /dev/null +++ b/dev/run-tests-codes.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +readonly BLOCK_GENERAL=10 +readonly BLOCK_RAT=11 +readonly BLOCK_SCALA_STYLE=12 +readonly BLOCK_PYTHON_STYLE=13 +readonly BLOCK_BUILD=14 +readonly BLOCK_SPARK_UNIT_TESTS=15 +readonly BLOCK_PYSPARK_UNIT_TESTS=16 +readonly BLOCK_MIMA=17 diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins index 0b1e31b9413cf..6a849e4f77207 100755 --- a/dev/run-tests-jenkins +++ b/dev/run-tests-jenkins @@ -26,9 +26,23 @@ FWDIR="$(cd `dirname $0`/..; pwd)" cd "$FWDIR" +source "$FWDIR/dev/run-tests-codes.sh" + COMMENTS_URL="https://api.github.com/repos/apache/spark/issues/$ghprbPullId/comments" PULL_REQUEST_URL="https://github.com/apache/spark/pull/$ghprbPullId" +# Important Environment Variables +# --- +# $ghprbActualCommit +#+ This is the hash of the most recent commit in the PR. +#+ The merge-base of this and master is the commit from which the PR was branched. +# $sha1 +#+ If the patch merges cleanly, this is a reference to the merge commit hash +#+ (e.g. "origin/pr/2606/merge"). +#+ If the patch does not merge cleanly, it is equal to $ghprbActualCommit. +#+ The merge-base of this and master in the case of a clean merge is the most recent commit +#+ against master. + COMMIT_URL="https://github.com/apache/spark/commit/${ghprbActualCommit}" # GitHub doesn't auto-link short hashes when submitted via the API, unfortunately. :( SHORT_COMMIT_HASH="${ghprbActualCommit:0:7}" @@ -39,9 +53,9 @@ function post_message () { local message=$1 local data="{\"body\": \"$message\"}" local HTTP_CODE_HEADER="HTTP Response Code: " - + echo "Attempting to post to Github..." - + local curl_output=$( curl `#--dump-header -` \ --silent \ @@ -61,12 +75,12 @@ function post_message () { echo " > data: ${data}" >&2 # exit $curl_status fi - + local api_response=$( echo "${curl_output}" \ | grep -v -e "^${HTTP_CODE_HEADER}" ) - + local http_code=$( echo "${curl_output}" \ | grep -e "^${HTTP_CODE_HEADER}" \ @@ -78,60 +92,97 @@ function post_message () { echo " > api_response: ${api_response}" >&2 echo " > data: ${data}" >&2 fi - + if [ "$curl_status" -eq 0 ] && [ "$http_code" -eq "201" ]; then echo " > Post successful." fi } +function send_archived_logs () { + echo "Archiving unit tests logs..." + + local log_files=$( + find .\ + -name "unit-tests.log" -o\ + -path "./sql/hive/target/HiveCompatibilitySuite.failed" -o\ + -path "./sql/hive/target/HiveCompatibilitySuite.hiveFailed" -o\ + -path "./sql/hive/target/HiveCompatibilitySuite.wrong" + ) + + if [ -z "$log_files" ]; then + echo "> No log files found." >&2 + else + local log_archive="unit-tests-logs.tar.gz" + echo "$log_files" | xargs tar czf ${log_archive} + + local jenkins_build_dir=${JENKINS_HOME}/jobs/${JOB_NAME}/builds/${BUILD_NUMBER} + local scp_output=$(scp ${log_archive} amp-jenkins-master:${jenkins_build_dir}/${log_archive}) + local scp_status="$?" + + if [ "$scp_status" -ne 0 ]; then + echo "Failed to send archived unit tests logs to Jenkins master." >&2 + echo "> scp_status: ${scp_status}" >&2 + echo "> scp_output: ${scp_output}" >&2 + else + echo "> Send successful." + fi + + rm -f ${log_archive} + fi +} + + +# We diff master...$ghprbActualCommit because that gets us changes introduced in the PR +#+ and not anything else added to master since the PR was branched. + # check PR merge-ability and check for new public classes { if [ "$sha1" == "$ghprbActualCommit" ]; then - merge_note=" * This patch **does not** merge cleanly!" 
+ merge_note=" * This patch **does not merge cleanly**." else merge_note=" * This patch merges cleanly." + fi - source_files=$( - git diff master... --name-only `# diff patch against master from branch point` \ - | grep -v -e "\/test" `# ignore files in test directories` \ - | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \ - | tr "\n" " " - ) - new_public_classes=$( - git diff master... ${source_files} `# diff patch against master from branch point` \ - | grep "^\+" `# filter in only added lines` \ - | sed -r -e "s/^\+//g" `# remove the leading +` \ - | grep -e "trait " -e "class " `# filter in lines with these key words` \ - | grep -e "{" -e "(" `# filter in lines with these key words, too` \ - | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \ - | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \ - | sed -r -e "s/\{.*//g" `# remove from the { onwards` \ - | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \ - | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \ - | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \ - | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \ - | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \ - | tr -d "\n" `# remove actual LF characters` - ) - - if [ "$new_public_classes" == "" ]; then - public_classes_note=" * This patch adds no public classes." - else - public_classes_note=" * This patch adds the following public classes _(experimental)_:" - public_classes_note="${public_classes_note}\n${new_public_classes}" - fi + source_files=$( + git diff master...$ghprbActualCommit --name-only `# diff patch against master from branch point` \ + | grep -v -e "\/test" `# ignore files in test directories` \ + | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \ + | tr "\n" " " + ) + new_public_classes=$( + git diff master...$ghprbActualCommit ${source_files} `# diff patch against master from branch point` \ + | grep "^\+" `# filter in only added lines` \ + | sed -r -e "s/^\+//g" `# remove the leading +` \ + | grep -e "trait " -e "class " `# filter in lines with these key words` \ + | grep -e "{" -e "(" `# filter in lines with these key words, too` \ + | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \ + | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \ + | sed -r -e "s/\{.*//g" `# remove from the { onwards` \ + | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \ + | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \ + | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \ + | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \ + | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \ + | tr -d "\n" `# remove actual LF characters` + ) + + if [ -z "$new_public_classes" ]; then + public_classes_note=" * This patch adds no public classes." + else + public_classes_note=" * This patch adds the following public classes _(experimental)_:" + public_classes_note="${public_classes_note}\n${new_public_classes}" fi } # post start message { start_message="\ - [QA tests have started](${BUILD_URL}consoleFull) for \ + [Test build ${BUILD_DISPLAY_NAME} has started](${BUILD_URL}consoleFull) for \ PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." 
- + start_message="${start_message}\n${merge_note}" # start_message="${start_message}\n${public_classes_note}" - + post_message "$start_message" } @@ -141,25 +192,45 @@ function post_message () { test_result="$?" if [ "$test_result" -eq "124" ]; then - fail_message="**[Tests timed out](${BUILD_URL}consoleFull)** \ + fail_message="**[Test build ${BUILD_DISPLAY_NAME} timed out](${BUILD_URL}consoleFull)** \ for PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL}) \ after a configured wait of \`${TESTS_TIMEOUT}\`." post_message "$fail_message" exit $test_result + elif [ "$test_result" -eq "0" ]; then + test_result_note=" * This patch **passes all tests**." else - if [ "$test_result" -eq "0" ]; then - test_result_note=" * This patch **passes** unit tests." + if [ "$test_result" -eq "$BLOCK_GENERAL" ]; then + failing_test="some tests" + elif [ "$test_result" -eq "$BLOCK_RAT" ]; then + failing_test="RAT tests" + elif [ "$test_result" -eq "$BLOCK_SCALA_STYLE" ]; then + failing_test="Scala style tests" + elif [ "$test_result" -eq "$BLOCK_PYTHON_STYLE" ]; then + failing_test="Python style tests" + elif [ "$test_result" -eq "$BLOCK_BUILD" ]; then + failing_test="to build" + elif [ "$test_result" -eq "$BLOCK_SPARK_UNIT_TESTS" ]; then + failing_test="Spark unit tests" + elif [ "$test_result" -eq "$BLOCK_PYSPARK_UNIT_TESTS" ]; then + failing_test="PySpark unit tests" + elif [ "$test_result" -eq "$BLOCK_MIMA" ]; then + failing_test="MiMa tests" else - test_result_note=" * This patch **fails** unit tests." + failing_test="some tests" fi + + test_result_note=" * This patch **fails $failing_test**." fi + + send_archived_logs } # post end message { result_message="\ - [QA tests have finished](${BUILD_URL}consoleFull) for \ + [Test build ${BUILD_DISPLAY_NAME} has finished](${BUILD_URL}consoleFull) for \ PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})." result_message="${result_message}\n${test_result_note}" diff --git a/dev/scalastyle b/dev/scalastyle index efb5f291ea3b7..c3c6012e74ffa 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -17,7 +17,7 @@ # limitations under the License. # -echo -e "q\n" | sbt/sbt -Phive scalastyle > scalastyle.txt +echo -e "q\n" | sbt/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt # Check style with YARN alpha built too echo -e "q\n" | sbt/sbt -Pyarn-alpha -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \ >> scalastyle.txt @@ -25,7 +25,9 @@ echo -e "q\n" | sbt/sbt -Pyarn-alpha -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn- echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 yarn/scalastyle \ >> scalastyle.txt -ERRORS=$(cat scalastyle.txt | grep -e "\") +ERRORS=$(cat scalastyle.txt | awk '{if($1~/error/)print}') +rm scalastyle.txt + if test ! -z "$ERRORS"; then echo -e "Scalastyle checks failed at following occurrences:\n$ERRORS" exit 1 diff --git a/docs/README.md b/docs/README.md index 79708c3df9106..119484038083f 100644 --- a/docs/README.md +++ b/docs/README.md @@ -25,8 +25,7 @@ installing via the Ruby Gem dependency manager. Since the exact HTML output varies between versions of Jekyll and its dependencies, we list specific versions here in some cases: - $ sudo gem install jekyll -v 1.4.3 - $ sudo gem uninstall kramdown -v 1.4.1 + $ sudo gem install jekyll $ sudo gem install jekyll-redirect-from Execute `jekyll` from the `docs/` directory. 
Compiling the site with Jekyll will create a directory @@ -44,7 +43,7 @@ You can modify the default Jekyll build as follows: ## Pygments We also use pygments (http://pygments.org) for syntax highlighting in documentation markdown pages, -so you will also need to install that (it requires Python) by running `sudo easy_install Pygments`. +so you will also need to install that (it requires Python) by running `sudo pip install Pygments`. To mark a block of code in your markdown to be syntax highlighted by jekyll during the compile phase, use the following sytax: @@ -54,19 +53,24 @@ phase, use the following sytax: // supported languages too. {% endhighlight %} -## API Docs (Scaladoc and Epydoc) +## Sphinx + +We use Sphinx to generate Python API docs, so you will need to install it by running +`sudo pip install sphinx`. + +## API Docs (Scaladoc and Sphinx) You can build just the Spark scaladoc by running `sbt/sbt doc` from the SPARK_PROJECT_ROOT directory. -Similarly, you can build just the PySpark epydoc by running `epydoc --config epydoc.conf` from the -SPARK_PROJECT_ROOT/pyspark directory. Documentation is only generated for classes that are listed as +Similarly, you can build just the PySpark docs by running `make html` from the +SPARK_PROJECT_ROOT/python/docs directory. Documentation is only generated for classes that are listed as public in `__init__.py`. When you run `jekyll` in the `docs` directory, it will also copy over the scaladoc for the various Spark subprojects into the `docs` directory (and then also into the `_site` directory). We use a jekyll plugin to run `sbt/sbt doc` before building the site so if you haven't run it (recently) it may take some time as it generates all of the scaladoc. The jekyll plugin also generates the -PySpark docs using [epydoc](http://epydoc.sourceforge.net/). +PySpark docs [Sphinx](http://sphinx-doc.org/). NOTE: To skip the step of building and copying over the Scala and Python API docs, run `SKIP_API=1 jekyll`. diff --git a/docs/_config.yml b/docs/_config.yml index 7bc3a78e2d265..a96a76dd9ab5e 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -8,10 +8,13 @@ gems: kramdown: entity_output: numeric -# These allow the documentation to be updated with nerw releases +include: + - _static + +# These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 1.0.0-SNAPSHOT -SPARK_VERSION_SHORT: 1.0.0 +SPARK_VERSION: 1.3.0-SNAPSHOT +SPARK_VERSION_SHORT: 1.3.0 SCALA_BINARY_VERSION: "2.10" SCALA_VERSION: "2.10.4" MESOS_VERSION: 0.18.1 diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 3b02e090aec28..4566a2fff562b 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -63,19 +63,20 @@ puts "cp -r " + source + "/. " + dest cp_r(source + "/.", dest) - # Build Epydoc for Python - puts "Moving to python directory and building epydoc." - cd("../python") - puts `epydoc --config epydoc.conf` + # Build Sphinx docs for Python - puts "Moving back into docs dir." - cd("../docs") + puts "Moving to python/docs directory and building sphinx." + cd("../python/docs") + puts `make html` + + puts "Moving back into home dir." + cd("../../") puts "Making directory api/python" - mkdir_p "api/python" + mkdir_p "docs/api/python" - puts "cp -r ../python/docs/. api/python" - cp_r("../python/docs/.", "api/python") + puts "cp -r python/docs/_build/html/. 
docs/api/python" + cp_r("python/docs/_build/html/.", "docs/api/python") cd("..") end diff --git a/docs/building-spark.md b/docs/building-spark.md index 901c157162fee..40a47410e683a 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -67,11 +67,13 @@ For Apache Hadoop 2.x, 0.23.x, Cloudera CDH, and other Hadoop versions with YARN
    YARN version | Profile required
    0.23.x to 2.1.x | yarn-alpha
    0.23.x to 2.1.x | yarn-alpha (Deprecated.)
    2.2.x and later | yarn
+Note: Support for YARN-alpha APIs will be removed in Spark 1.3 (see SPARK-3445). + Examples: {% highlight bash %} @@ -90,8 +92,11 @@ mvn -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 -DskipTests clean package # Apache Hadoop 2.3.X mvn -Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0 -DskipTests clean package -# Apache Hadoop 2.4.X -mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -DskipTests clean package +# Apache Hadoop 2.4.X or 2.5.X +mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=VERSION -DskipTests clean package + +Versions of Hadoop after 2.5.X may or may not work with the -Phadoop-2.4 profile (they were +released after this version of Spark). # Different versions of HDFS and YARN. mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -DskipTests clean package @@ -99,20 +104,34 @@ mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -Dski # Building With Hive and JDBC Support To enable Hive integration for Spark SQL along with its JDBC server and CLI, -add the `-Phive` profile to your existing build options. +add the `-Phive` and `-Phive-thriftserver` profiles to your existing build options. +By default Spark will build with Hive 0.13.1 bindings. You can also build for +Hive 0.12.0 using the `-Phive-0.12.0` profile. {% highlight bash %} -# Apache Hadoop 2.4.X with Hive support -mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -DskipTests clean package +# Apache Hadoop 2.4.X with Hive 13 support +mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-thriftserver -DskipTests clean package + +# Apache Hadoop 2.4.X with Hive 12 support +mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -Phive-0.12.0 -Phive-thriftserver -DskipTests clean package {% endhighlight %} +# Building for Scala 2.11 +To produce a Spark package compiled with Scala 2.11, use the `-Dscala-2.11` property: + + mvn -Pyarn -Phadoop-2.4 -Dscala-2.11 -DskipTests clean package + +Scala 2.11 support in Spark is experimental and does not support a few features. +Specifically, Spark's external Kafka library and JDBC component are not yet +supported in Scala 2.11 builds. + # Spark Tests in Maven Tests are run by default via the [ScalaTest Maven plugin](http://www.scalatest.org/user_guide/using_the_scalatest_maven_plugin). Some of the tests require Spark to be packaged first, so always run `mvn package` with `-DskipTests` the first time. The following is an example of a correct (build, test) sequence: - mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive clean package - mvn -Pyarn -Phadoop-2.3 -Phive test + mvn -Pyarn -Phadoop-2.3 -DskipTests -Phive -Phive-thriftserver clean package + mvn -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver test The ScalaTest plugin also supports running only a specific test suite as follows: @@ -171,10 +190,25 @@ can be set to control the SBT build. For example: sbt/sbt -Pyarn -Phadoop-2.3 assembly +# Testing with SBT + +Some of the tests require Spark to be packaged first, so always run `sbt/sbt assembly` the first time.
The following is an example of a correct (build, test) sequence: + + sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver assembly + sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver test + +To run only a specific test suite as follows: + + sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver "test-only org.apache.spark.repl.ReplSuite" + +To run test suites of a specific sub project as follows: + + sbt/sbt -Pyarn -Phadoop-2.3 -Phive -Phive-thriftserver core/test + # Speeding up Compilation with Zinc [Zinc](https://github.com/typesafehub/zinc) is a long-running server version of SBT's incremental compiler. When run locally as a background process, it speeds up builds of Scala-based projects like Spark. Developers who regularly recompile Spark with Maven will be the most interested in Zinc. The project site gives instructions for building and running `zinc`; OS X users can -install it using `brew install zinc`. \ No newline at end of file +install it using `brew install zinc`. diff --git a/docs/configuration.md b/docs/configuration.md index 1c33855365170..0b77f5ab645c9 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -21,16 +21,22 @@ application. These properties can be set directly on a [SparkConf](api/scala/index.html#org.apache.spark.SparkConf) passed to your `SparkContext`. `SparkConf` allows you to configure some of the common properties (e.g. master URL and application name), as well as arbitrary key-value pairs through the -`set()` method. For example, we could initialize an application as follows: +`set()` method. For example, we could initialize an application with two threads as follows: + +Note that we run with local[2], meaning two threads - which represents "minimal" parallelism, +which can help detect bugs that only exist when we run in a distributed context. {% highlight scala %} val conf = new SparkConf() - .setMaster("local") + .setMaster("local[2]") .setAppName("CountingSheep") .set("spark.executor.memory", "1g") val sc = new SparkContext(conf) {% endhighlight %} +Note that we can have more than 1 thread in local mode, and in cases like spark streaming, we may actually +require one to prevent any sort of starvation issues. + ## Dynamically Loading Spark Properties In some cases, you may want to avoid hard-coding certain configurations in a `SparkConf`. For instance, if you'd like to run the same application with different masters or different @@ -46,7 +52,7 @@ Then, you can supply configuration values at runtime: --conf "spark.executor.extraJavaOptions=-XX:+PrintGCDetails -XX:+PrintGCTimeStamps" myApp.jar {% endhighlight %} -The Spark shell and [`spark-submit`](cluster-overview.html#launching-applications-with-spark-submit) +The Spark shell and [`spark-submit`](submitting-applications.html) tool support two ways to load configurations dynamically. The first are command line options, such as `--master`, as shown above. `spark-submit` can accept any Spark property using the `--conf` flag, but uses special flags for properties that play a part in launching the Spark application. @@ -103,6 +109,26 @@ of the most common options to set are: (e.g. 512m, 2g). + + spark.driver.memory + 512m + + Amount of memory to use for the driver process, i.e. where SparkContext is initialized. + (e.g. 512m, 2g). + + + + spark.driver.maxResultSize + 1g + + Limit of total size of serialized results of all partitions for each Spark action (e.g. collect). + Should be at least 1M, or 0 for unlimited. Jobs will be aborted if the total size + is above this limit. 
+ Having a high limit may cause out-of-memory errors in driver (depends on spark.driver.memory + and memory overhead of objects in JVM). Setting a proper limit can protect the driver from + out-of-memory errors. + + spark.serializer org.apache.spark.serializer.
    JavaSerializer @@ -116,12 +142,23 @@ of the most common options to set are: org.apache.spark.Serializer. + + spark.kryo.classesToRegister + (none) + + If you use Kryo serialization, give a comma-separated list of custom class names to register + with Kryo. + See the tuning guide for more details. + + spark.kryo.registrator (none) - If you use Kryo serialization, set this class to register your custom classes with Kryo. - It should be set to a class that extends + If you use Kryo serialization, set this class to register your custom classes with Kryo. This + property is useful if you need to register your classes in a custom way, e.g. to specify a custom + field serializer. Otherwise spark.kryo.classesToRegister is simpler. It should be + set to a class that extends KryoRegistrator. See the tuning guide for more details. @@ -153,14 +190,6 @@ Apart from these, the following properties are also available, and may be useful #### Runtime Environment - - - - - @@ -195,6 +224,7 @@ Apart from these, the following properties are also available, and may be useful (Experimental) Whether to give user-added jars precedence over Spark's own jars when loading classes in Executors. This feature can be used to mitigate conflicts between Spark's dependencies and user dependencies. It is currently an experimental feature. + (Currently, this setting does not work for YARN, see SPARK-2996 for more details). @@ -348,6 +378,16 @@ Apart from these, the following properties are also available, and may be useful map-side aggregation and there are at most this many reduce partitions. + + + + +
    Property Name | Default | Meaning
    spark.executor.memory | 512m - Amount of memory to use per executor process, in the same format as JVM memory strings - (e.g. 512m, 2g). -
    spark.executor.extraJavaOptions (none)
    spark.shuffle.blockTransferService | netty + Implementation to use for transferring shuffle and cached blocks between executors. There + are two implementations available: netty and nio. Netty-based + block transfer is intended to be simpler but equally efficient and is the default option + starting in 1.2. +
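As a rough illustration of several of the properties above, the following sketch builds a `SparkConf` that raises the driver result-size limit and registers application classes with Kryo via `spark.kryo.classesToRegister`. The `Checkin` and `Venue` classes are hypothetical placeholders for whatever your application serializes; note that `spark.driver.memory` is normally supplied through `spark-submit` or `spark-defaults.conf`, since setting it after the driver JVM has started has no effect.

{% highlight scala %}
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.serializer.KryoSerializer

// Hypothetical application classes, used only to show class registration.
case class Checkin(userId: Long, venueId: Long)
case class Venue(id: Long, name: String)

val conf = new SparkConf()
  .setMaster("local[2]")
  .setAppName("ConfigExample")
  // Driver-side limit described in the table above.
  .set("spark.driver.maxResultSize", "2g")
  // Switch to Kryo and register classes by name (spark.kryo.classesToRegister).
  .set("spark.serializer", classOf[KryoSerializer].getName)
  .set("spark.kryo.classesToRegister",
    Seq(classOf[Checkin], classOf[Venue]).map(_.getName).mkString(","))

val sc = new SparkContext(conf)
{% endhighlight %}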
    #### Spark UI @@ -357,14 +397,23 @@ Apart from these, the following properties are also available, and may be useful spark.ui.port 4040 - Port for your application's dashboard, which shows memory and workload data + Port for your application's dashboard, which shows memory and workload data. spark.ui.retainedStages 1000 - How many stages the Spark UI remembers before garbage collecting. + How many stages the Spark UI and status APIs remember before garbage + collecting. + + + + spark.ui.retainedJobs + 1000 + + How many stages the Spark UI and status APIs remember before garbage + collecting. @@ -514,6 +563,9 @@ Apart from these, the following properties are also available, and may be useful spark.default.parallelism + For distributed shuffle operations like reduceByKey and join, the + largest number of partitions in a parent RDD. For operations like parallelize + with no parent RDDs, it depends on the cluster manager:
    • Local mode: number of cores on the local machine
    • Mesos fine grained mode: 8
    • @@ -521,8 +573,8 @@ Apart from these, the following properties are also available, and may be useful
    - Default number of tasks to use across the cluster for distributed shuffle operations - (groupByKey, reduceByKey, etc) when not set by user. + Default number of partitions in RDDs returned by transformations like join, + reduceByKey, and parallelize when not set by user. @@ -619,6 +671,15 @@ Apart from these, the following properties are also available, and may be useful output directories. We recommend that users do not disable this except if trying to achieve compatibility with previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand. + + spark.hadoop.cloneConf + false + If set to true, clones a new Hadoop Configuration object for each task. This + option should be enabled to work around Configuration thread-safety issues (see + SPARK-2546 for more details). + This is disabled by default in order to avoid unexpected performance regressions for jobs that + are not affected by these issues. + spark.executor.heartbeatInterval 10000 @@ -717,7 +778,7 @@ Apart from these, the following properties are also available, and may be useful spark.akka.heartbeat.pauses - 600 + 6000 This is set to a larger value to disable failure detector that comes inbuilt akka. It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart beat pause @@ -872,8 +933,8 @@ Apart from these, the following properties are also available, and may be useful spark.scheduler.revive.interval 1000 - The interval length for the scheduler to revive the worker resource offers to run tasks. - (in milliseconds) + The interval length for the scheduler to revive the worker resource offers to run tasks + (in milliseconds). @@ -885,7 +946,7 @@ Apart from these, the following properties are also available, and may be useful to wait for before scheduling begins. Specified as a double between 0 and 1. Regardless of whether the minimum ratio of resources has been reached, the maximum amount of time it will wait before scheduling begins is controlled by config - spark.scheduler.maxRegisteredResourcesWaitingTime + spark.scheduler.maxRegisteredResourcesWaitingTime. diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md index 530798f2b8022..66bf5f1a855ed 100644 --- a/docs/ec2-scripts.md +++ b/docs/ec2-scripts.md @@ -12,16 +12,14 @@ on the [Amazon Web Services site](http://aws.amazon.com/). `spark-ec2` is designed to manage multiple named clusters. You can launch a new cluster (telling the script its size and giving it a name), -shutdown an existing cluster, or log into a cluster. Each cluster -launches a set of instances, which are tagged with the cluster name, -and placed into EC2 security groups. If you don't specify a security -group, the `spark-ec2` script will create security groups based on the -cluster name you request. For example, a cluster named +shutdown an existing cluster, or log into a cluster. Each cluster is +identified by placing its machines into EC2 security groups whose names +are derived from the name of the cluster. For example, a cluster named `test` will contain a master node in a security group called `test-master`, and a number of slave nodes in a security group called -`test-slaves`. You can also specify a security group prefix to be used -in place of the cluster name. Machines in a cluster can be identified -by looking for the "Name" tag of the instance in the Amazon EC2 Console. +`test-slaves`. The `spark-ec2` script will create these security groups +for you based on the cluster name you request. 
You can also use them to +identify machines belonging to each cluster in the Amazon EC2 Console. # Before You Start diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index fdb9f98e214e5..e298c51f8a5b7 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -6,6 +6,47 @@ title: GraphX Programming Guide * This will become a table of contents (this text will be scraped). {:toc} + + +[EdgeRDD]: api/scala/index.html#org.apache.spark.graphx.EdgeRDD +[Edge]: api/scala/index.html#org.apache.spark.graphx.Edge +[EdgeTriplet]: api/scala/index.html#org.apache.spark.graphx.EdgeTriplet +[Graph]: api/scala/index.html#org.apache.spark.graphx.Graph +[GraphOps]: api/scala/index.html#org.apache.spark.graphx.GraphOps +[Graph.mapVertices]: api/scala/index.html#org.apache.spark.graphx.Graph@mapVertices[VD2]((VertexId,VD)⇒VD2)(ClassTag[VD2]):Graph[VD2,ED] +[Graph.reverse]: api/scala/index.html#org.apache.spark.graphx.Graph@reverse:Graph[VD,ED] +[Graph.subgraph]: api/scala/index.html#org.apache.spark.graphx.Graph@subgraph((EdgeTriplet[VD,ED])⇒Boolean,(VertexId,VD)⇒Boolean):Graph[VD,ED] +[Graph.mask]: api/scala/index.html#org.apache.spark.graphx.Graph@mask[VD2,ED2](Graph[VD2,ED2])(ClassTag[VD2],ClassTag[ED2]):Graph[VD,ED] +[Graph.groupEdges]: api/scala/index.html#org.apache.spark.graphx.Graph@groupEdges((ED,ED)⇒ED):Graph[VD,ED] +[GraphOps.joinVertices]: api/scala/index.html#org.apache.spark.graphx.GraphOps@joinVertices[U](RDD[(VertexId,U)])((VertexId,VD,U)⇒VD)(ClassTag[U]):Graph[VD,ED] +[Graph.outerJoinVertices]: api/scala/index.html#org.apache.spark.graphx.Graph@outerJoinVertices[U,VD2](RDD[(VertexId,U)])((VertexId,VD,Option[U])⇒VD2)(ClassTag[U],ClassTag[VD2]):Graph[VD2,ED] +[Graph.aggregateMessages]: api/scala/index.html#org.apache.spark.graphx.Graph@aggregateMessages[A]((EdgeContext[VD,ED,A])⇒Unit,(A,A)⇒A,TripletFields)(ClassTag[A]):VertexRDD[A] +[EdgeContext]: api/scala/index.html#org.apache.spark.graphx.EdgeContext +[Graph.mapReduceTriplets]: api/scala/index.html#org.apache.spark.graphx.Graph@mapReduceTriplets[A](mapFunc:org.apache.spark.graphx.EdgeTriplet[VD,ED]=>Iterator[(org.apache.spark.graphx.VertexId,A)],reduceFunc:(A,A)=>A,activeSetOpt:Option[(org.apache.spark.graphx.VertexRDD[_],org.apache.spark.graphx.EdgeDirection)])(implicitevidence$10:scala.reflect.ClassTag[A]):org.apache.spark.graphx.VertexRDD[A] +[GraphOps.collectNeighborIds]: api/scala/index.html#org.apache.spark.graphx.GraphOps@collectNeighborIds(EdgeDirection):VertexRDD[Array[VertexId]] +[GraphOps.collectNeighbors]: api/scala/index.html#org.apache.spark.graphx.GraphOps@collectNeighbors(EdgeDirection):VertexRDD[Array[(VertexId,VD)]] +[RDD Persistence]: programming-guide.html#rdd-persistence +[Graph.cache]: api/scala/index.html#org.apache.spark.graphx.Graph@cache():Graph[VD,ED] +[GraphOps.pregel]: api/scala/index.html#org.apache.spark.graphx.GraphOps@pregel[A](A,Int,EdgeDirection)((VertexId,VD,A)⇒VD,(EdgeTriplet[VD,ED])⇒Iterator[(VertexId,A)],(A,A)⇒A)(ClassTag[A]):Graph[VD,ED] +[PartitionStrategy]: api/scala/index.html#org.apache.spark.graphx.PartitionStrategy$ +[GraphLoader.edgeListFile]: api/scala/index.html#org.apache.spark.graphx.GraphLoader$@edgeListFile(SparkContext,String,Boolean,Int):Graph[Int,Int] +[Graph.apply]: api/scala/index.html#org.apache.spark.graphx.Graph$@apply[VD,ED](RDD[(VertexId,VD)],RDD[Edge[ED]],VD)(ClassTag[VD],ClassTag[ED]):Graph[VD,ED] +[Graph.fromEdgeTuples]: 
api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdgeTuples[VD](RDD[(VertexId,VertexId)],VD,Option[PartitionStrategy])(ClassTag[VD]):Graph[VD,Int] +[Graph.fromEdges]: api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdges[VD,ED](RDD[Edge[ED]],VD)(ClassTag[VD],ClassTag[ED]):Graph[VD,ED] +[PartitionStrategy]: api/scala/index.html#org.apache.spark.graphx.PartitionStrategy +[Graph.partitionBy]: api/scala/index.html#org.apache.spark.graphx.Graph$@partitionBy(partitionStrategy:org.apache.spark.graphx.PartitionStrategy):org.apache.spark.graphx.Graph[VD,ED] +[PageRank]: api/scala/index.html#org.apache.spark.graphx.lib.PageRank$ +[ConnectedComponents]: api/scala/index.html#org.apache.spark.graphx.lib.ConnectedComponents$ +[TriangleCount]: api/scala/index.html#org.apache.spark.graphx.lib.TriangleCount$ +[Graph.partitionBy]: api/scala/index.html#org.apache.spark.graphx.Graph@partitionBy(PartitionStrategy):Graph[VD,ED] +[EdgeContext.sendToSrc]: api/scala/index.html#org.apache.spark.graphx.EdgeContext@sendToSrc(msg:A):Unit +[EdgeContext.sendToDst]: api/scala/index.html#org.apache.spark.graphx.EdgeContext@sendToDst(msg:A):Unit +[TripletFields]: api/java/org/apache/spark/graphx/TripletFields.html +[TripletFields.All]: api/java/org/apache/spark/graphx/TripletFields.html#All +[TripletFields.None]: api/java/org/apache/spark/graphx/TripletFields.html#None +[TripletFields.Src]: api/java/org/apache/spark/graphx/TripletFields.html#Src +[TripletFields.Dst]: api/java/org/apache/spark/graphx/TripletFields.html#Dst +

- [figure: Data-Parallel vs. Graph-Parallel]

    +1. To improve performance we have introduced a new version of +[`mapReduceTriplets`][Graph.mapReduceTriplets] called +[`aggregateMessages`][Graph.aggregateMessages] which takes the messages previously returned from +[`mapReduceTriplets`][Graph.mapReduceTriplets] through a callback ([`EdgeContext`][EdgeContext]) +rather than by return value. +We are deprecating [`mapReduceTriplets`][Graph.mapReduceTriplets] and encourage users to consult +the [transition guide](#mrTripletsTransition). -However, the same restrictions that enable these substantial performance gains also make it -difficult to express many of the important stages in a typical graph-analytics pipeline: -constructing the graph, modifying its structure, or expressing computation that spans multiple -graphs. Furthermore, how we look at data depends on our objectives and the same raw data may have -many different table and graph views. - -

- [figure: Tables and Graphs]

    - -As a consequence, it is often necessary to be able to move between table and graph views of the same -physical data and to leverage the properties of each view to easily and efficiently express -computation. However, existing graph analytics pipelines must compose graph-parallel and data- -parallel systems, leading to extensive data movement and duplication and a complicated programming -model. - -

- [figure: Graph Analytics Pipeline]

    - -The goal of the GraphX project is to unify graph-parallel and data-parallel computation in one -system with a single composable API. The GraphX API enables users to view data both as a graph and -as collections (i.e., RDDs) without data movement or duplication. By incorporating recent advances -in graph-parallel systems, GraphX is able to optimize the execution of graph operations. - -## GraphX Replaces the Spark Bagel API - -Prior to the release of GraphX, graph computation in Spark was expressed using Bagel, an -implementation of Pregel. GraphX improves upon Bagel by exposing a richer property graph API, a -more streamlined version of the Pregel abstraction, and system optimizations to improve performance -and reduce memory overhead. While we plan to eventually deprecate Bagel, we will continue to -support the [Bagel API](api/scala/index.html#org.apache.spark.bagel.package) and -[Bagel programming guide](bagel-programming-guide.html). However, we encourage Bagel users to -explore the new GraphX API and comment on issues that may complicate the transition from Bagel. - -## Migrating from Spark 0.9.1 - -GraphX in Spark {{site.SPARK_VERSION}} contains one user-facing interface change from Spark 0.9.1. [`EdgeRDD`][EdgeRDD] may now store adjacent vertex attributes to construct the triplets, so it has gained a type parameter. The edges of a graph of type `Graph[VD, ED]` are of type `EdgeRDD[ED, VD]` rather than `EdgeRDD[ED]`. - -[EdgeRDD]: api/scala/index.html#org.apache.spark.graphx.EdgeRDD +2. In Spark 1.0 and 1.1, the type signature of [`EdgeRDD`][EdgeRDD] switched from +`EdgeRDD[ED]` to `EdgeRDD[ED, VD]` to enable some caching optimizations. We have since discovered +a more elegant solution and have restored the type signature to the more natural `EdgeRDD[ED]` type. # Getting Started @@ -108,9 +96,10 @@ import org.apache.spark.rdd.RDD If you are not using the Spark shell you will also need a `SparkContext`. To learn more about getting started with Spark refer to the [Spark Quick Start Guide](quick-start.html). -# The Property Graph +# The Property Graph + The [property graph](api/scala/index.html#org.apache.spark.graphx.Graph) is a directed multigraph with user defined objects attached to each vertex and edge. A directed multigraph is a directed graph with potentially multiple parallel edges sharing the same source and destination vertex. The @@ -123,7 +112,7 @@ identifiers. The property graph is parameterized over the vertex (`VD`) and edge (`ED`) types. These are the types of the objects associated with each vertex and edge respectively. -> GraphX optimizes the representation of vertex and edge types when they are plain old data-types +> GraphX optimizes the representation of vertex and edge types when they are primitive data types > (e.g., int, double, etc...) reducing the in memory footprint by storing them in specialized > arrays. @@ -142,8 +131,8 @@ var graph: Graph[VertexProperty, String] = null Like RDDs, property graphs are immutable, distributed, and fault-tolerant. Changes to the values or structure of the graph are accomplished by producing a new graph with the desired changes. Note that substantial parts of the original graph (i.e., unaffected structure, attributes, and indicies) -are reused in the new graph reducing the cost of this inherently functional data-structure. The -graph is partitioned across the executors using a range of vertex-partitioning heuristics. As with +are reused in the new graph reducing the cost of this inherently functional data structure. 
The +graph is partitioned across the executors using a range of vertex partitioning heuristics. As with RDDs, each partition of the graph can be recreated on a different machine in the event of a failure. Logically the property graph corresponds to a pair of typed collections (RDDs) encoding the @@ -153,12 +142,12 @@ the vertices and edges of the graph: {% highlight scala %} class Graph[VD, ED] { val vertices: VertexRDD[VD] - val edges: EdgeRDD[ED, VD] + val edges: EdgeRDD[ED] } {% endhighlight %} -The classes `VertexRDD[VD]` and `EdgeRDD[ED, VD]` extend and are optimized versions of `RDD[(VertexID, -VD)]` and `RDD[Edge[ED]]` respectively. Both `VertexRDD[VD]` and `EdgeRDD[ED, VD]` provide additional +The classes `VertexRDD[VD]` and `EdgeRDD[ED]` extend and are optimized versions of `RDD[(VertexID, +VD)]` and `RDD[Edge[ED]]` respectively. Both `VertexRDD[VD]` and `EdgeRDD[ED]` provide additional functionality built around graph computation and leverage internal optimizations. We discuss the `VertexRDD` and `EdgeRDD` API in greater detail in the section on [vertex and edge RDDs](#vertex_and_edge_rdds) but for now they can be thought of as simply RDDs of the form: @@ -211,7 +200,6 @@ In the above example we make use of the [`Edge`][Edge] case class. Edges have a `dstId` corresponding to the source and destination vertex identifiers. In addition, the `Edge` class has an `attr` member which stores the edge property. -[Edge]: api/scala/index.html#org.apache.spark.graphx.Edge We can deconstruct a graph into the respective vertex and edge views by using the `graph.vertices` and `graph.edges` members respectively. @@ -237,7 +225,6 @@ The triplet view logically joins the vertex and edge properties yielding an `RDD[EdgeTriplet[VD, ED]]` containing instances of the [`EdgeTriplet`][EdgeTriplet] class. This *join* can be expressed in the following SQL expression: -[EdgeTriplet]: api/scala/index.html#org.apache.spark.graphx.EdgeTriplet {% highlight sql %} SELECT src.id, dst.id, src.attr, e.attr, dst.attr @@ -278,9 +265,6 @@ core operators are defined in [`GraphOps`][GraphOps]. However, thanks to Scala operators in `GraphOps` are automatically available as members of `Graph`. 
For example, we can compute the in-degree of each vertex (defined in `GraphOps`) by the following: -[Graph]: api/scala/index.html#org.apache.spark.graphx.Graph -[GraphOps]: api/scala/index.html#org.apache.spark.graphx.GraphOps - {% highlight scala %} val graph: Graph[(String, String), String] // Use the implicit GraphOps.inDegrees operator @@ -310,7 +294,7 @@ class Graph[VD, ED] { val degrees: VertexRDD[Int] // Views of the graph as collections ============================================================= val vertices: VertexRDD[VD] - val edges: EdgeRDD[ED, VD] + val edges: EdgeRDD[ED] val triplets: RDD[EdgeTriplet[VD, ED]] // Functions for caching graphs ================================================================== def persist(newLevel: StorageLevel = StorageLevel.MEMORY_ONLY): Graph[VD, ED] @@ -341,10 +325,10 @@ class Graph[VD, ED] { // Aggregate information about adjacent triplets ================================================= def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexID]] def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexID, VD)]] - def mapReduceTriplets[A: ClassTag]( - mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexID, A)], - reduceFunc: (A, A) => A, - activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None) + def aggregateMessages[Msg: ClassTag]( + sendMsg: EdgeContext[VD, ED, Msg] => Unit, + mergeMsg: (Msg, Msg) => Msg, + tripletFields: TripletFields = TripletFields.All) : VertexRDD[A] // Iterative graph-parallel computation ========================================================== def pregel[A](initialMsg: A, maxIterations: Int, activeDirection: EdgeDirection)( @@ -363,8 +347,7 @@ class Graph[VD, ED] { ## Property Operators -In direct analogy to the RDD `map` operator, the property -graph contains the following: +Like the RDD `map` operator, the property graph contains the following: {% highlight scala %} class Graph[VD, ED] { @@ -377,7 +360,7 @@ class Graph[VD, ED] { Each of these operators yields a new graph with the vertex or edge properties modified by the user defined `map` function. -> Note that in all cases the graph structure is unaffected. This is a key feature of these operators +> Note that in each case the graph structure is unaffected. This is a key feature of these operators > which allows the resulting graph to reuse the structural indices of the original graph. The > following snippets are logically equivalent, but the first one does not preserve the structural > indices and would not benefit from the GraphX system optimizations: @@ -390,14 +373,13 @@ val newGraph = Graph(newVertices, graph.edges) val newGraph = graph.mapVertices((id, attr) => mapUdf(id, attr)) {% endhighlight %} -[Graph.mapVertices]: api/scala/index.html#org.apache.spark.graphx.Graph@mapVertices[VD2]((VertexId,VD)⇒VD2)(ClassTag[VD2]):Graph[VD2,ED] These operators are often used to initialize the graph for a particular computation or project away -unnecessary properties. For example, given a graph with the out-degrees as the vertex properties +unnecessary properties. 
For example, given a graph with the out degrees as the vertex properties (we describe how to construct such a graph later), we initialize it for PageRank: {% highlight scala %} -// Given a graph where the vertex property is the out-degree +// Given a graph where the vertex property is the out degree val inputGraph: Graph[Int, String] = graph.outerJoinVertices(graph.outDegrees)((vid, _, degOpt) => degOpt.getOrElse(0)) // Construct a graph where each edge contains the weight @@ -406,9 +388,10 @@ val outputGraph: Graph[Double, Double] = inputGraph.mapTriplets(triplet => 1.0 / triplet.srcAttr).mapVertices((id, _) => 1.0) {% endhighlight %} -## Structural Operators +## Structural Operators + Currently GraphX supports only a simple set of commonly used structural operators and we expect to add more in the future. The following is a list of the basic structural operators. @@ -425,9 +408,8 @@ class Graph[VD, ED] { The [`reverse`][Graph.reverse] operator returns a new graph with all the edge directions reversed. This can be useful when, for example, trying to compute the inverse PageRank. Because the reverse operation does not modify vertex or edge properties or change the number of edges, it can be -implemented efficiently without data-movement or duplication. +implemented efficiently without data movement or duplication. -[Graph.reverse]: api/scala/index.html#org.apache.spark.graphx.Graph@reverse:Graph[VD,ED] The [`subgraph`][Graph.subgraph] operator takes vertex and edge predicates and returns the graph containing only the vertices that satisfy the vertex predicate (evaluate to true) and edges that @@ -435,7 +417,6 @@ satisfy the edge predicate *and connect vertices that satisfy the vertex predica operator can be used in number of situations to restrict the graph to the vertices and edges of interest or eliminate broken links. For example in the following code we remove broken links: -[Graph.subgraph]: api/scala/index.html#org.apache.spark.graphx.Graph@subgraph((EdgeTriplet[VD,ED])⇒Boolean,(VertexId,VD)⇒Boolean):Graph[VD,ED] {% highlight scala %} // Create an RDD for the vertices @@ -469,13 +450,12 @@ validGraph.triplets.map( > Note in the above example only the vertex predicate is provided. The `subgraph` operator defaults > to `true` if the vertex or edge predicates are not provided. -The [`mask`][Graph.mask] operator also constructs a subgraph by returning a graph that contains the +The [`mask`][Graph.mask] operator constructs a subgraph by returning a graph that contains the vertices and edges that are also found in the input graph. This can be used in conjunction with the `subgraph` operator to restrict a graph based on the properties in another related graph. For example, we might run connected components using the graph with missing vertices and then restrict the answer to the valid subgraph. -[Graph.mask]: api/scala/index.html#org.apache.spark.graphx.Graph@mask[VD2,ED2](Graph[VD2,ED2])(ClassTag[VD2],ClassTag[ED2]):Graph[VD,ED] {% highlight scala %} // Run Connected Components @@ -490,10 +470,9 @@ The [`groupEdges`][Graph.groupEdges] operator merges parallel edges (i.e., dupli pairs of vertices) in the multigraph. In many numerical applications, parallel edges can be *added* (their weights combined) into a single edge thereby reducing the size of the graph. -[Graph.groupEdges]: api/scala/index.html#org.apache.spark.graphx.Graph@groupEdges((ED,ED)⇒ED):Graph[VD,ED] + ## Join Operators - In many cases it is necessary to join data from external collections (RDDs) with graphs. 
For example, we might have extra user properties that we want to merge with an existing graph or we @@ -514,10 +493,8 @@ returns a new graph with the vertex properties obtained by applying the user def to the result of the joined vertices. Vertices without a matching value in the RDD retain their original value. -[GraphOps.joinVertices]: api/scala/index.html#org.apache.spark.graphx.GraphOps@joinVertices[U](RDD[(VertexId,U)])((VertexId,VD,U)⇒VD)(ClassTag[U]):Graph[VD,ED] - -> Note that if the RDD contains more than one value for a given vertex only one will be used. It -> is therefore recommended that the input RDD be first made unique using the following which will +> Note that if the RDD contains more than one value for a given vertex only one will be used. It +> is therefore recommended that the input RDD be made unique using the following which will > also *pre-index* the resulting values to substantially accelerate the subsequent join. > {% highlight scala %} val nonUniqueCosts: RDD[(VertexID, Double)] @@ -533,8 +510,6 @@ property type. Because not all vertices may have a matching value in the input function takes an `Option` type. For example, we can setup a graph for PageRank by initializing vertex properties with their `outDegree`. -[Graph.outerJoinVertices]: api/scala/index.html#org.apache.spark.graphx.Graph@outerJoinVertices[U,VD2](RDD[(VertexId,U)])((VertexId,VD,Option[U])⇒VD2)(ClassTag[U],ClassTag[VD2]):Graph[VD2,ED] - {% highlight scala %} val outDegrees: VertexRDD[Int] = graph.outDegrees @@ -555,65 +530,76 @@ val joinedGraph = graph.joinVertices(uniqueCosts, (id: VertexID, oldCost: Double, extraCost: Double) => oldCost + extraCost) {% endhighlight %} +> + + ## Neighborhood Aggregation -A key part of graph computation is aggregating information about the neighborhood of each vertex. -For example we might want to know the number of followers each user has or the average age of the +A key step in may graph analytics tasks is aggregating information about the neighborhood of each +vertex. +For example, we might want to know the number of followers each user has or the average age of the the followers of each user. Many iterative graph algorithms (e.g., PageRank, Shortest Path, and connected components) repeatedly aggregate properties of neighboring vertices (e.g., current PageRank Value, shortest path to the source, and smallest reachable vertex id). -### Map Reduce Triplets (mapReduceTriplets) - +> To improve performance the primary aggregation operator changed from +`graph.mapReduceTriplets` to the new `graph.AggregateMessages`. While the changes in the API are +relatively small, we provide a transition guide below. -[Graph.mapReduceTriplets]: api/scala/index.html#org.apache.spark.graphx.Graph@mapReduceTriplets[A](mapFunc:org.apache.spark.graphx.EdgeTriplet[VD,ED]=>Iterator[(org.apache.spark.graphx.VertexId,A)],reduceFunc:(A,A)=>A,activeSetOpt:Option[(org.apache.spark.graphx.VertexRDD[_],org.apache.spark.graphx.EdgeDirection)])(implicitevidence$10:scala.reflect.ClassTag[A]):org.apache.spark.graphx.VertexRDD[A] + -The core (heavily optimized) aggregation primitive in GraphX is the -[`mapReduceTriplets`][Graph.mapReduceTriplets] operator: +### Aggregate Messages (aggregateMessages) + +The core aggregation operation in GraphX is [`aggregateMessages`][Graph.aggregateMessages]. +This operator applies a user defined `sendMsg` function to each edge triplet in the graph +and then uses the `mergeMsg` function to aggregate those messages at their destination vertex. 
{% highlight scala %} class Graph[VD, ED] { - def mapReduceTriplets[A]( - map: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], - reduce: (A, A) => A) - : VertexRDD[A] + def aggregateMessages[Msg: ClassTag]( + sendMsg: EdgeContext[VD, ED, Msg] => Unit, + mergeMsg: (Msg, Msg) => Msg, + tripletFields: TripletFields = TripletFields.All) + : VertexRDD[Msg] } {% endhighlight %} -The [`mapReduceTriplets`][Graph.mapReduceTriplets] operator takes a user defined map function which -is applied to each triplet and can yield *messages* destined to either (none or both) vertices in -the triplet. To facilitate optimized pre-aggregation, we currently only support messages destined -to the source or destination vertex of the triplet. The user defined `reduce` function combines the -messages destined to each vertex. The `mapReduceTriplets` operator returns a `VertexRDD[A]` -containing the aggregate message (of type `A`) destined to each vertex. Vertices that do not +The user defined `sendMsg` function takes an [`EdgeContext`][EdgeContext], which exposes the +source and destination attributes along with the edge attribute and functions +([`sendToSrc`][EdgeContext.sendToSrc], and [`sendToDst`][EdgeContext.sendToDst]) to send +messages to the source and destination attributes. Think of `sendMsg` as the map +function in map-reduce. +The user defined `mergeMsg` function takes two messages destined to the same vertex and +yields a single message. Think of `mergeMsg` as the reduce function in map-reduce. +The [`aggregateMessages`][Graph.aggregateMessages] operator returns a `VertexRDD[Msg]` +containing the aggregate message (of type `Msg`) destined to each vertex. Vertices that did not receive a message are not included in the returned `VertexRDD`. -
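The example above compares `triplet.srcAttr` with `triplet.dstAttr`, so it needs the default `TripletFields.All`. As a minimal sketch of the hint described earlier (reusing the same `graph` of ages from the example above), the following computes the average age of *all* followers of each user; only the source attribute is read, so `TripletFields.Src` can be passed explicitly:

{% highlight scala %}
import org.apache.spark.graphx.{TripletFields, VertexRDD}

// Count and total age of all followers of each user; only src attributes are needed.
val followerAges: VertexRDD[(Int, Double)] =
  graph.aggregateMessages[(Int, Double)](
    ctx => ctx.sendToDst((1, ctx.srcAttr)),   // send a (count, age) message to the followee
    (a, b) => (a._1 + b._1, a._2 + b._2),     // sum counts and ages
    TripletFields.Src)                        // ship only source vertex attributes
val avgFollowerAge: VertexRDD[Double] =
  followerAges.mapValues { case (count, totalAge) => totalAge / count }
{% endhighlight %}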

    Note that mapReduceTriplets takes an additional optional activeSet -(not shown above see API docs for details) which restricts the map phase to edges adjacent to the -vertices in the provided VertexRDD:

    - -{% highlight scala %} - activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None -{% endhighlight %} - -

    The EdgeDirection specifies which edges adjacent to the vertex set are included in the map -phase. If the direction is In, then the user defined map function will -only be run only on edges with the destination vertex in the active set. If the direction is -Out, then the map function will only be run only on edges originating from -vertices in the active set. If the direction is Either, then the map -function will be run only on edges with either vertex in the active set. If the direction is -Both, then the map function will be run only on edges with both vertices -in the active set. The active set must be derived from the set of vertices in the graph. -Restricting computation to triplets adjacent to a subset of the vertices is often necessary in -incremental iterative computation and is a key part of the GraphX implementation of Pregel.

    - -In the following example we use the `mapReduceTriplets` operator to compute the average age of the -more senior followers of each user. + + +In addition, [`aggregateMessages`][Graph.aggregateMessages] takes an optional +`tripletsFields` which indicates what data is accessed in the [`EdgeContext`][EdgeContext] +(i.e., the source vertex attribute but not the destination vertex attribute). +The possible options for the `tripletsFields` are defined in [`TripletFields`][TripletFields] and +the default value is [`TripletFields.All`][TripletFields.All] which indicates that the user +defined `sendMsg` function may access any of the fields in the [`EdgeContext`][EdgeContext]. +The `tripletFields` argument can be used to notify GraphX that only part of the +[`EdgeContext`][EdgeContext] will be needed allowing GraphX to select an optimized join strategy. +For example if we are computing the average age of the followers of each user we would only require +the source field and so we would use [`TripletFields.Src`][TripletFields.Src] to indicate that we +only require the source field + +> In earlier versions of GraphX we used byte code inspection to infer the +[`TripletFields`][TripletFields] however we have found that bytecode inspection to be +slightly unreliable and instead opted for more explicit user control. + +In the following example we use the [`aggregateMessages`][Graph.aggregateMessages] operator to +compute the average age of the more senior followers of each user. {% highlight scala %} // Import random graph generation library @@ -622,14 +608,11 @@ import org.apache.spark.graphx.util.GraphGenerators val graph: Graph[Double, Int] = GraphGenerators.logNormalGraph(sc, numVertices = 100).mapVertices( (id, _) => id.toDouble ) // Compute the number of older followers and their total age -val olderFollowers: VertexRDD[(Int, Double)] = graph.mapReduceTriplets[(Int, Double)]( +val olderFollowers: VertexRDD[(Int, Double)] = graph.aggregateMessages[(Int, Double)]( triplet => { // Map Function if (triplet.srcAttr > triplet.dstAttr) { // Send message to destination vertex containing counter and age - Iterator((triplet.dstId, (1, triplet.srcAttr))) - } else { - // Don't send a message for this triplet - Iterator.empty + triplet.sendToDst(1, triplet.srcAttr) } }, // Add counter and age @@ -642,10 +625,57 @@ val avgAgeOfOlderFollowers: VertexRDD[Double] = avgAgeOfOlderFollowers.collect.foreach(println(_)) {% endhighlight %} -> Note that the `mapReduceTriplets` operation performs optimally when the messages (and the sums of -> messages) are constant sized (e.g., floats and addition instead of lists and concatenation). More -> precisely, the result of `mapReduceTriplets` should ideally be sub-linear in the degree of each -> vertex. +> The `aggregateMessages` operation performs optimally when the messages (and the sums of +> messages) are constant sized (e.g., floats and addition instead of lists and concatenation). 
+ + + +### Map Reduce Triplets Transition Guide (Legacy) + +In earlier versions of GraphX we neighborhood aggregation was accomplished using the +[`mapReduceTriplets`][Graph.mapReduceTriplets] operator: + +{% highlight scala %} +class Graph[VD, ED] { + def mapReduceTriplets[Msg]( + map: EdgeTriplet[VD, ED] => Iterator[(VertexId, Msg)], + reduce: (Msg, Msg) => Msg) + : VertexRDD[Msg] +} +{% endhighlight %} + +The [`mapReduceTriplets`][Graph.mapReduceTriplets] operator takes a user defined map function which +is applied to each triplet and can yield *messages* which are aggregated using the user defined +`reduce` function. +However, we found the user of the returned iterator to be expensive and it inhibited our ability to +apply additional optimizations (e.g., local vertex renumbering). +In [`aggregateMessages`][Graph.aggregateMessages] we introduced the EdgeContext which exposes the +triplet fields and also functions to explicitly send messages to the source and destination vertex. +Furthermore we removed bytecode inspection and instead require the user to indicate what fields +in the triplet are actually required. + +The following code block using `mapReduceTriplets`: + +{% highlight scala %} +val graph: Graph[Int, Float] = ... +def msgFun(triplet: Triplet[Int, Float]): Iterator[(Int, String)] = { + Iterator((triplet.dstId, "Hi")) +} +def reduceFun(a: Int, b: Int): Int = a + b +val result = graph.mapReduceTriplets[String](msgFun, reduceFun) +{% endhighlight %} + +can be rewritten using `aggregateMessages` as: + +{% highlight scala %} +val graph: Graph[Int, Float] = ... +def msgFun(triplet: EdgeContext[Int, Float, String]) { + triplet.sendToDst("Hi") +} +def reduceFun(a: Int, b: Int): Int = a + b +val result = graph.aggregateMessages[String](msgFun, reduceFun) +{% endhighlight %} + ### Computing Degree Information @@ -673,10 +703,6 @@ attributes at each vertex. This can be easily accomplished using the [`collectNeighborIds`][GraphOps.collectNeighborIds] and the [`collectNeighbors`][GraphOps.collectNeighbors] operators. -[GraphOps.collectNeighborIds]: api/scala/index.html#org.apache.spark.graphx.GraphOps@collectNeighborIds(EdgeDirection):VertexRDD[Array[VertexId]] -[GraphOps.collectNeighbors]: api/scala/index.html#org.apache.spark.graphx.GraphOps@collectNeighbors(EdgeDirection):VertexRDD[Array[(VertexId,VD)]] - - {% highlight scala %} class GraphOps[VD, ED] { def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexId]] @@ -684,36 +710,34 @@ class GraphOps[VD, ED] { } {% endhighlight %} -> Note that these operators can be quite costly as they duplicate information and require +> These operators can be quite costly as they duplicate information and require > substantial communication. If possible try expressing the same computation using the -> `mapReduceTriplets` operator directly. +> [`aggregateMessages`][Graph.aggregateMessages] operator directly. ## Caching and Uncaching In Spark, RDDs are not persisted in memory by default. To avoid recomputation, they must be explicitly cached when using them multiple times (see the [Spark Programming Guide][RDD Persistence]). Graphs in GraphX behave the same way. **When using a graph multiple times, make sure to call [`Graph.cache()`][Graph.cache] on it first.** -[RDD Persistence]: programming-guide.html#rdd-persistence -[Graph.cache]: api/scala/index.html#org.apache.spark.graphx.Graph@cache():Graph[VD,ED] In iterative computations, *uncaching* may also be necessary for best performance. 
By default, cached RDDs and graphs will remain in memory until memory pressure forces them to be evicted in LRU order. For iterative computation, intermediate results from previous iterations will fill up the cache. Though they will eventually be evicted, the unnecessary data stored in memory will slow down garbage collection. It would be more efficient to uncache intermediate results as soon as they are no longer necessary. This involves materializing (caching and forcing) a graph or RDD every iteration, uncaching all other datasets, and only using the materialized dataset in future iterations. However, because graphs are composed of multiple RDDs, it can be difficult to unpersist them correctly. **For iterative computation we recommend using the Pregel API, which correctly unpersists intermediate results.** -# Pregel API -Graphs are inherently recursive data-structures as properties of vertices depend on properties of +# Pregel API + +Graphs are inherently recursive data structures as properties of vertices depend on properties of their neighbors which in turn depend on properties of *their* neighbors. As a consequence many important graph algorithms iteratively recompute the properties of each vertex until a fixed-point condition is reached. A range of graph-parallel abstractions have been proposed -to express these iterative algorithms. GraphX exposes a Pregel-like operator which is a fusion of -the widely used Pregel and GraphLab abstractions. +to express these iterative algorithms. GraphX exposes a variant of the Pregel API. -At a high-level the Pregel operator in GraphX is a bulk-synchronous parallel messaging abstraction -*constrained to the topology of the graph*. The Pregel operator executes in a series of super-steps -in which vertices receive the *sum* of their inbound messages from the previous super- step, compute +At a high level the Pregel operator in GraphX is a bulk-synchronous parallel messaging abstraction +*constrained to the topology of the graph*. The Pregel operator executes in a series of super steps +in which vertices receive the *sum* of their inbound messages from the previous super step, compute a new value for the vertex property, and then send messages to neighboring vertices in the next -super-step. Unlike Pregel and instead more like GraphLab messages are computed in parallel as a +super step. Unlike Pregel, messages are computed in parallel as a function of the edge triplet and the message computation has access to both the source and -destination vertex attributes. Vertices that do not receive a message are skipped within a super- +destination vertex attributes. Vertices that do not receive a message are skipped within a super step. The Pregel operators terminates iteration and returns the final graph when there are no messages remaining. @@ -724,8 +748,6 @@ messages remaining. 
The following is the type signature of the [Pregel operator][GraphOps.pregel] as well as a *sketch* of its implementation (note calls to graph.cache have been removed): -[GraphOps.pregel]: api/scala/index.html#org.apache.spark.graphx.GraphOps@pregel[A](A,Int,EdgeDirection)((VertexId,VD,A)⇒VD,(EdgeTriplet[VD,ED])⇒Iterator[(VertexId,A)],(A,A)⇒A)(ClassTag[A]):Graph[VD,ED] - {% highlight scala %} class GraphOps[VD, ED] { def pregel[A] @@ -795,9 +817,10 @@ val sssp = initialGraph.pregel(Double.PositiveInfinity)( println(sssp.vertices.collect.mkString("\n")) {% endhighlight %} -# Graph Builders +# Graph Builders + GraphX provides several ways of building a graph from a collection of vertices and edges in an RDD or on disk. None of the graph builders repartitions the graph's edges by default; instead, edges are left in their default partitions (such as their original blocks in HDFS). [`Graph.groupEdges`][Graph.groupEdges] requires the graph to be repartitioned because it assumes identical edges will be colocated on the same partition, so you must call [`Graph.partitionBy`][Graph.partitionBy] before calling `groupEdges`. {% highlight scala %} @@ -848,18 +871,12 @@ object Graph { [`Graph.fromEdgeTuples`][Graph.fromEdgeTuples] allows creating a graph from only an RDD of edge tuples, assigning the edges the value 1, and automatically creating any vertices mentioned by edges and assigning them the default value. It also supports deduplicating the edges; to deduplicate, pass `Some` of a [`PartitionStrategy`][PartitionStrategy] as the `uniqueEdges` parameter (for example, `uniqueEdges = Some(PartitionStrategy.RandomVertexCut)`). A partition strategy is necessary to colocate identical edges on the same partition so they can be deduplicated. -[PartitionStrategy]: api/scala/index.html#org.apache.spark.graphx.PartitionStrategy$ - -[GraphLoader.edgeListFile]: api/scala/index.html#org.apache.spark.graphx.GraphLoader$@edgeListFile(SparkContext,String,Boolean,Int):Graph[Int,Int] -[Graph.apply]: api/scala/index.html#org.apache.spark.graphx.Graph$@apply[VD,ED](RDD[(VertexId,VD)],RDD[Edge[ED]],VD)(ClassTag[VD],ClassTag[ED]):Graph[VD,ED] -[Graph.fromEdgeTuples]: api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdgeTuples[VD](RDD[(VertexId,VertexId)],VD,Option[PartitionStrategy])(ClassTag[VD]):Graph[VD,Int] -[Graph.fromEdges]: api/scala/index.html#org.apache.spark.graphx.Graph$@fromEdges[VD,ED](RDD[Edge[ED]],VD)(ClassTag[VD],ClassTag[ED]):Graph[VD,ED] + # Vertex and Edge RDDs - GraphX exposes `RDD` views of the vertices and edges stored within the graph. However, because -GraphX maintains the vertices and edges in optimized data-structures and these data-structures +GraphX maintains the vertices and edges in optimized data structures and these data structures provide additional functionality, the vertices and edges are returned as `VertexRDD` and `EdgeRDD` respectively. In this section we review some of the additional useful functionality in these types. @@ -870,7 +887,7 @@ The `VertexRDD[A]` extends `RDD[(VertexID, A)]` and adds the additional constrai attribute of type `A`. Internally, this is achieved by storing the vertex attributes in a reusable hash-map data-structure. As a consequence if two `VertexRDD`s are derived from the same base `VertexRDD` (e.g., by `filter` or `mapValues`) they can be joined in constant time without hash -evaluations. To leverage this indexed data-structure, the `VertexRDD` exposes the following +evaluations. 
To leverage this indexed data structure, the `VertexRDD` exposes the following additional functionality: {% highlight scala %} @@ -893,7 +910,7 @@ class VertexRDD[VD] extends RDD[(VertexID, VD)] { Notice, for example, how the `filter` operator returns an `VertexRDD`. Filter is actually implemented using a `BitSet` thereby reusing the index and preserving the ability to do fast joins with other `VertexRDD`s. Likewise, the `mapValues` operators do not allow the `map` function to -change the `VertexID` thereby enabling the same `HashMap` data-structures to be reused. Both the +change the `VertexID` thereby enabling the same `HashMap` data structures to be reused. Both the `leftJoin` and `innerJoin` are able to identify when joining two `VertexRDD`s derived from the same `HashMap` and implement the join by linear scan rather than costly point lookups. @@ -916,21 +933,19 @@ val setC: VertexRDD[Double] = setA.innerJoin(setB)((id, a, b) => a + b) ## EdgeRDDs -The `EdgeRDD[ED, VD]`, which extends `RDD[Edge[ED]]` organizes the edges in blocks partitioned using one +The `EdgeRDD[ED]`, which extends `RDD[Edge[ED]]` organizes the edges in blocks partitioned using one of the various partitioning strategies defined in [`PartitionStrategy`][PartitionStrategy]. Within each partition, edge attributes and adjacency structure, are stored separately enabling maximum reuse when changing attribute values. -[PartitionStrategy]: api/scala/index.html#org.apache.spark.graphx.PartitionStrategy - The three additional functions exposed by the `EdgeRDD` are: {% highlight scala %} // Transform the edge attributes while preserving the structure -def mapValues[ED2](f: Edge[ED] => ED2): EdgeRDD[ED2, VD] +def mapValues[ED2](f: Edge[ED] => ED2): EdgeRDD[ED2] // Revere the edges reusing both attributes and structure -def reverse: EdgeRDD[ED, VD] +def reverse: EdgeRDD[ED] // Join two `EdgeRDD`s partitioned using the same partitioning strategy. -def innerJoin[ED2, ED3](other: EdgeRDD[ED2, VD])(f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3, VD] +def innerJoin[ED2, ED3](other: EdgeRDD[ED2])(f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3] {% endhighlight %} In most applications we have found that operations on the `EdgeRDD` are accomplished through the @@ -960,7 +975,6 @@ the [`Graph.partitionBy`][Graph.partitionBy] operator. The default partitioning the initial partitioning of the edges as provided on graph construction. However, users can easily switch to 2D-partitioning or other heuristics included in GraphX. -[Graph.partitionBy]: api/scala/index.html#org.apache.spark.graphx.Graph$@partitionBy(partitionStrategy:org.apache.spark.graphx.PartitionStrategy):org.apache.spark.graphx.Graph[VD,ED]
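As a rough illustration of switching heuristics, here is a minimal sketch (not part of the original guide; the file path and merge function are illustrative) that repartitions a graph's edges with the 2D heuristic and then merges identical edges with `groupEdges`:

{% highlight scala %}
import org.apache.spark.graphx._

// Load an edge list, repartition the edges with the 2D heuristic, and merge
// any duplicate edges that are now colocated on the same partition.
val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt")
  .partitionBy(PartitionStrategy.EdgePartition2D)
val merged = graph.groupEdges((a, b) => a + b)
{% endhighlight %}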

    +# Graph Algorithms + GraphX includes a set of graph algorithms to simplify analytics tasks. The algorithms are contained in the `org.apache.spark.graphx.lib` package and can be accessed directly as methods on `Graph` via [`GraphOps`][GraphOps]. This section describes the algorithms and how they are used. -## PageRank +## PageRank + PageRank measures the importance of each vertex in a graph, assuming an edge from *u* to *v* represents an endorsement of *v*'s importance by *u*. For example, if a Twitter user is followed by many others, the user will be ranked highly. GraphX comes with static and dynamic implementations of PageRank as methods on the [`PageRank` object][PageRank]. Static PageRank runs for a fixed number of iterations, while dynamic PageRank runs until the ranks converge (i.e., stop changing by more than a specified tolerance). [`GraphOps`][GraphOps] allows calling these algorithms directly as methods on `Graph`. GraphX also includes an example social network dataset that we can run PageRank on. A set of users is given in `graphx/data/users.txt`, and a set of relationships between users is given in `graphx/data/followers.txt`. We compute the PageRank of each user as follows: -[PageRank]: api/scala/index.html#org.apache.spark.graphx.lib.PageRank$ - {% highlight scala %} // Load the edges as a graph val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt") @@ -1014,8 +1028,6 @@ println(ranksByUsername.collect().mkString("\n")) The connected components algorithm labels each connected component of the graph with the ID of its lowest-numbered vertex. For example, in a social network, connected components can approximate clusters. GraphX contains an implementation of the algorithm in the [`ConnectedComponents` object][ConnectedComponents], and we compute the connected components of the example social network dataset from the [PageRank section](#pagerank) as follows: -[ConnectedComponents]: api/scala/index.html#org.apache.spark.graphx.lib.ConnectedComponents$ - {% highlight scala %} // Load the graph as in the PageRank example val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt") @@ -1037,9 +1049,6 @@ println(ccByUsername.collect().mkString("\n")) A vertex is part of a triangle when it has two adjacent vertices with an edge between them. GraphX implements a triangle counting algorithm in the [`TriangleCount` object][TriangleCount] that determines the number of triangles passing through each vertex, providing a measure of clustering. We compute the triangle count of the social network dataset from the [PageRank section](#pagerank). 
*Note that `TriangleCount` requires the edges to be in canonical orientation (`srcId < dstId`) and the graph to be partitioned using [`Graph.partitionBy`][Graph.partitionBy].* -[TriangleCount]: api/scala/index.html#org.apache.spark.graphx.lib.TriangleCount$ -[Graph.partitionBy]: api/scala/index.html#org.apache.spark.graphx.Graph@partitionBy(PartitionStrategy):Graph[VD,ED] - {% highlight scala %} // Load the edges in canonical order and partition the graph for triangle count val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt", true).partitionBy(PartitionStrategy.RandomVertexCut) diff --git a/docs/img/data_parallel_vs_graph_parallel.png b/docs/img/data_parallel_vs_graph_parallel.png deleted file mode 100644 index d3918f01d8f3b..0000000000000 Binary files a/docs/img/data_parallel_vs_graph_parallel.png and /dev/null differ diff --git a/docs/img/graph_analytics_pipeline.png b/docs/img/graph_analytics_pipeline.png deleted file mode 100644 index 6d606e01894ae..0000000000000 Binary files a/docs/img/graph_analytics_pipeline.png and /dev/null differ diff --git a/docs/img/tables_and_graphs.png b/docs/img/tables_and_graphs.png deleted file mode 100644 index ec37bb45a62f0..0000000000000 Binary files a/docs/img/tables_and_graphs.png and /dev/null differ diff --git a/docs/index.md b/docs/index.md index edd622ec90f64..171d6ddad62f3 100644 --- a/docs/index.md +++ b/docs/index.md @@ -112,6 +112,7 @@ options for deployment: **External Resources:** * [Spark Homepage](http://spark.apache.org) +* [Spark Wiki](https://cwiki.apache.org/confluence/display/SPARK) * [Mailing Lists](http://spark.apache.org/mailing-lists.html): ask questions about Spark here * [AMP Camps](http://ampcamp.berkeley.edu/): a series of training camps at UC Berkeley that featured talks and exercises about Spark, Spark Streaming, Mesos, and more. [Videos](http://ampcamp.berkeley.edu/3/), diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index d10bd63746629..c696ae9c8e8c8 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -34,7 +34,7 @@ a given dataset, the algorithm returns the best clustering result). * *initializationSteps* determines the number of steps in the k-means\|\| algorithm. * *epsilon* determines the distance threshold within which we consider k-means to have converged. -## Examples +### Examples
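To make the k-means parameters listed above concrete, here is a minimal hedged sketch (the data path is the sample file shipped with Spark, and the parameter values are illustrative) that passes *k*, *maxIterations*, *runs*, and *initializationMode* explicitly:

{% highlight scala %}
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

// Parse whitespace-separated numeric rows into vectors.
val parsedData = sc.textFile("data/mllib/kmeans_data.txt")
  .map(s => Vectors.dense(s.split(' ').map(_.toDouble)))
  .cache()

// k = 2, maxIterations = 20, runs = 10, initializationMode = k-means||
val model = KMeans.train(parsedData, 2, 20, 10, KMeans.K_MEANS_PARALLEL)
println("Within Set Sum of Squared Errors = " + model.computeCost(parsedData))
{% endhighlight %}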

    @@ -69,7 +69,7 @@ println("Within Set Sum of Squared Errors = " + WSSSE) All of MLlib's methods use Java-friendly types, so you can import and call them there the same way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by -calling `.rdd()` on your `JavaRDD` object. A standalone application example +calling `.rdd()` on your `JavaRDD` object. A self-contained application example that is equivalent to the provided example in Scala is given below: {% highlight java %} @@ -113,12 +113,6 @@ public class KMeansExample { } } {% endhighlight %} - -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency.
    @@ -153,3 +147,103 @@ print("Within Set Sum of Squared Error = " + str(WSSSE))
    + +In order to run the above application, follow the instructions +provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) +section of the Spark +Quick Start guide. Be sure to also include *spark-mllib* to your build file as +a dependency. + +## Streaming clustering + +When data arrive in a stream, we may want to estimate clusters dynamically, +updating them as new data arrive. MLlib provides support for streaming k-means clustering, +with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm +uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign +all points to their nearest cluster, compute new cluster centers, then update each cluster using: + +`\begin{equation} + c_{t+1} = \frac{c_tn_t\alpha + x_tm_t}{n_t\alpha+m_t} +\end{equation}` +`\begin{equation} + n_{t+1} = n_t + m_t +\end{equation}` + +Where `$c_t$` is the previous center for the cluster, `$n_t$` is the number of points assigned +to the cluster thus far, `$x_t$` is the new cluster center from the current batch, and `$m_t$` +is the number of points added to the cluster in the current batch. The decay factor `$\alpha$` +can be used to ignore the past: with `$\alpha$=1` all data will be used from the beginning; +with `$\alpha$=0` only the most recent data will be used. This is analogous to an +exponentially-weighted moving average. + +The decay can be specified using a `halfLife` parameter, which determines the +correct decay factor `a` such that, for data acquired +at time `t`, its contribution by time `t + halfLife` will have dropped to 0.5. +The unit of time can be specified either as `batches` or `points` and the update rule +will be adjusted accordingly. + +### Examples + +This example shows how to estimate clusters on streaming data. + +
    + +
+ +First we import the necessary classes. + +{% highlight scala %} + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.clustering.StreamingKMeans + +{% endhighlight %} + +Then we make an input stream of vectors for training, as well as a stream of labeled data +points for testing. We assume a StreamingContext `ssc` has been created; see the +[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info. + +{% highlight scala %} + +val trainingData = ssc.textFileStream("/training/data/dir").map(Vectors.parse) +val testData = ssc.textFileStream("/testing/data/dir").map(LabeledPoint.parse) + +{% endhighlight %} + +We create a model with random clusters and specify the number of clusters to find. + +{% highlight scala %} + +val numDimensions = 3 +val numClusters = 2 +val model = new StreamingKMeans() + .setK(numClusters) + .setDecayFactor(1.0) + .setRandomCenters(numDimensions, 0.0) + +{% endhighlight %} + +Now register the streams for training and testing and start the job, printing +the predicted cluster assignments on new data points as they arrive. + +{% highlight scala %} + +model.trainOn(trainingData) +model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() + +ssc.start() +ssc.awaitTermination() + +{% endhighlight %} + +As you add new text files with data, the cluster centers will update. Each training +point should be formatted as `[x1, x2, x3]`, and each test data point +should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or identifier +(e.g. a true category assignment). Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` +you will see predictions. With new data, the cluster centers will change! + +
    + +
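For reference, the `halfLife`-based decay described above can also be set on the model builder instead of a raw decay factor. A hedged sketch, assuming the `setHalfLife` setter is available in your Spark version (the half-life value and dimensions are illustrative):

{% highlight scala %}
import org.apache.spark.mllib.clustering.StreamingKMeans

// Forget half of a cluster's history every 5 batches rather than
// specifying the decay factor directly.
val modelWithHalfLife = new StreamingKMeans()
  .setK(2)
  .setHalfLife(5.0, "batches")
  .setRandomCenters(3, 0.0)
{% endhighlight %}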
    diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index d5c539db791be..2094963392295 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -110,7 +110,7 @@ val model = ALS.trainImplicit(ratings, rank, numIterations, alpha) All of MLlib's methods use Java-friendly types, so you can import and call them there the same way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by -calling `.rdd()` on your `JavaRDD` object. A standalone application example +calling `.rdd()` on your `JavaRDD` object. A self-contained application example that is equivalent to the provided example in Scala is given bellow: {% highlight java %} @@ -184,12 +184,6 @@ public class CollaborativeFiltering { } } {% endhighlight %} - -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency.
    @@ -229,6 +223,12 @@ model = ALS.trainImplicit(ratings, rank, numIterations, alpha = 0.01)
    +In order to run the above application, follow the instructions +provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) +section of the Spark +Quick Start guide. Be sure to also include *spark-mllib* to your build file as +a dependency. + ## Tutorial The [training exercises](https://databricks-training.s3.amazonaws.com/index.html) from the Spark Summit 2014 include a hands-on tutorial for diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index 21cb35b4270ca..870fed6cc5024 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -121,9 +121,9 @@ public class SVD { The same code applies to `IndexedRowMatrix` if `U` is defined as an `IndexedRowMatrix`. -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark +In order to run the above application, follow the instructions +provided in the [Self-Contained +Applications](quick-start.html#self-contained-applications) section of the Spark quick-start guide. Be sure to also include *spark-mllib* to your build file as a dependency. @@ -200,10 +200,11 @@ public class PCA { } {% endhighlight %} -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency. + +In order to run the above application, follow the instructions +provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) +section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency. diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 1511ae6dda4ed..197bc77d506c6 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -83,7 +83,7 @@ val idf = new IDF().fit(tf) val tfidf: RDD[Vector] = idf.transform(tf) {% endhighlight %} -MLLib's IDF implementation provides an option for ignoring terms which occur in less than a +MLlib's IDF implementation provides an option for ignoring terms which occur in less than a minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature can be used by passing the `minDocFreq` value to the IDF constructor. @@ -95,8 +95,49 @@ tf.cache() val idf = new IDF(minDocFreq = 2).fit(tf) val tfidf: RDD[Vector] = idf.transform(tf) {% endhighlight %} + +
    + +TF and IDF are implemented in [HashingTF](api/python/pyspark.mllib.html#pyspark.mllib.feature.HashingTF) +and [IDF](api/python/pyspark.mllib.html#pyspark.mllib.feature.IDF). +`HashingTF` takes an RDD of list as the input. +Each record could be an iterable of strings or other types. + +{% highlight python %} +from pyspark import SparkContext +from pyspark.mllib.feature import HashingTF + +sc = SparkContext() +# Load documents (one per line). +documents = sc.textFile("...").map(lambda line: line.split(" ")) + +hashingTF = HashingTF() +tf = hashingTF.transform(documents) +{% endhighlight %} + +While applying `HashingTF` only needs a single pass to the data, applying `IDF` needs two passes: +first to compute the IDF vector and second to scale the term frequencies by IDF. +{% highlight python %} +from pyspark.mllib.feature import IDF + +# ... continue from the previous example +tf.cache() +idf = IDF().fit(tf) +tfidf = idf.transform(tf) +{% endhighlight %} + +MLLib's IDF implementation provides an option for ignoring terms which occur in less than a +minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature +can be used by passing the `minDocFreq` value to the IDF constructor. + +{% highlight python %} +# ... continue from the previous example +tf.cache() +idf = IDF(minDocFreq=2).fit(tf) +tfidf = idf.transform(tf) +{% endhighlight %}
    @@ -162,6 +203,23 @@ for((synonym, cosineSimilarity) <- synonyms) { } {% endhighlight %} +
    +{% highlight python %} +from pyspark import SparkContext +from pyspark.mllib.feature import Word2Vec + +sc = SparkContext(appName='Word2Vec') +inp = sc.textFile("text8_lines").map(lambda row: row.split(" ")) + +word2vec = Word2Vec() +model = word2vec.fit(inp) + +synonyms = model.findSynonyms('china', 40) + +for word, cosine_distance in synonyms: + print "{}: {}".format(word, cosine_distance) +{% endhighlight %} +
    ## StandardScaler @@ -223,6 +281,29 @@ val data1 = data.map(x => (x.label, scaler1.transform(x.features))) val data2 = data.map(x => (x.label, scaler2.transform(Vectors.dense(x.features.toArray)))) {% endhighlight %} + +
+ +{% highlight python %} +from pyspark.mllib.util import MLUtils +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.feature import StandardScaler + +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") +label = data.map(lambda x: x.label) +features = data.map(lambda x: x.features) + +scaler1 = StandardScaler().fit(features) +scaler2 = StandardScaler(withMean=True, withStd=True).fit(features) + +# data1 will be unit variance. +data1 = label.zip(scaler1.transform(features)) + +# Without converting the features into dense vectors, transformation with zero mean will raise an +# exception on sparse vectors. +# data2 will be unit variance and zero mean. +data2 = label.zip(scaler2.transform(features.map(lambda x: Vectors.dense(x.toArray())))) +{% endhighlight %} +
    ## Normalizer @@ -267,4 +348,25 @@ val data1 = data.map(x => (x.label, normalizer1.transform(x.features))) val data2 = data.map(x => (x.label, normalizer2.transform(x.features))) {% endhighlight %} + +
    +{% highlight python %} +from pyspark.mllib.util import MLUtils +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.feature import Normalizer + +data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") +labels = data.map(lambda x: x.label) +features = data.map(lambda x: x.features) + +normalizer1 = Normalizer() +normalizer2 = Normalizer(p=float("inf")) + +# Each sample in data1 will be normalized using $L^2$ norm. +data1 = labels.zip(normalizer1.transform(features)) + +# Each sample in data2 will be normalized using $L^\infty$ norm. +data2 = labels.zip(normalizer2.transform(features)) +{% endhighlight %} +
    diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index d31bec3e1bd01..bc914a1899801 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -247,7 +247,7 @@ val modelL1 = svmAlg.run(training) All of MLlib's methods use Java-friendly types, so you can import and call them there the same way you do in Scala. The only caveat is that the methods take Scala RDD objects, while the Spark Java API uses a separate `JavaRDD` class. You can convert a Java RDD to a Scala one by -calling `.rdd()` on your `JavaRDD` object. A standalone application example +calling `.rdd()` on your `JavaRDD` object. A self-contained application example that is equivalent to the provided example in Scala is given bellow: {% highlight java %} @@ -323,9 +323,9 @@ svmAlg.optimizer() final SVMModel modelL1 = svmAlg.run(training.rdd()); {% endhighlight %} -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark +In order to run the above application, follow the instructions +provided in the [Self-Contained +Applications](quick-start.html#self-contained-applications) section of the Spark quick-start guide. Be sure to also include *spark-mllib* to your build file as a dependency. @@ -482,12 +482,6 @@ public class LinearRegression { } } {% endhighlight %} - -In order to run the above standalone application, follow the instructions -provided in the [Standalone -Applications](quick-start.html#standalone-applications) section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency.
    @@ -519,6 +513,12 @@ print("Mean Squared Error = " + str(MSE))
    +In order to run the above application, follow the instructions +provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) +section of the Spark +quick-start guide. Be sure to also include *spark-mllib* to your build file as +a dependency. + ## Streaming linear regression When data arrive in a streaming fashion, it is useful to fit regression models online, diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md index 7f9d4c6563944..d5b044d94fdd7 100644 --- a/docs/mllib-naive-bayes.md +++ b/docs/mllib-naive-bayes.md @@ -88,11 +88,11 @@ JavaPairRDD predictionAndLabel = return new Tuple2(model.predict(p.features()), p.label()); } }); -double accuracy = 1.0 * predictionAndLabel.filter(new Function, Boolean>() { +double accuracy = predictionAndLabel.filter(new Function, Boolean>() { @Override public Boolean call(Tuple2 pl) { - return pl._1() == pl._2(); + return pl._1().equals(pl._2()); } - }).count() / test.count(); + }).count() / (double) test.count(); {% endhighlight %} diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md index c4632413991f1..ca8c29218f52d 100644 --- a/docs/mllib-statistics.md +++ b/docs/mllib-statistics.md @@ -197,7 +197,7 @@ print Statistics.corr(data, method="pearson") ## Stratified sampling -Unlike the other statistics functions, which reside in MLLib, stratified sampling methods, +Unlike the other statistics functions, which reside in MLlib, stratified sampling methods, `sampleByKey` and `sampleByKeyExact`, can be performed on RDD's of key-value pairs. For stratified sampling, the keys can be thought of as a label and the value as a specific attribute. For example the key can be man or woman, or document ids, and the respective values can be the list of ages @@ -380,6 +380,46 @@ for (ChiSqTestResult result : featureTestResults) { {% endhighlight %} +
+[`Statistics`](api/python/index.html#pyspark.mllib.stat.Statistics$) provides methods to +run Pearson's chi-squared tests. The following example demonstrates how to run and interpret +hypothesis tests. + +{% highlight python %} +from pyspark import SparkContext +from pyspark.mllib.linalg import Vectors, Matrices +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.stat import Statistics + +sc = SparkContext() + +vec = Vectors.dense(...) # a vector composed of the frequencies of events + +# compute the goodness of fit. If a second vector to test against is not supplied as a parameter, +# the test runs against a uniform distribution. +goodnessOfFitTestResult = Statistics.chiSqTest(vec) +print goodnessOfFitTestResult # summary of the test including the p-value, degrees of freedom, + # test statistic, the method used, and the null hypothesis. + +mat = Matrices.dense(...) # a contingency matrix + +# conduct Pearson's independence test on the input contingency matrix +independenceTestResult = Statistics.chiSqTest(mat) +print independenceTestResult # summary of the test including the p-value, degrees of freedom... + +obs = sc.parallelize(...) # an RDD of LabeledPoint(label, features) + +# The contingency table is constructed from an RDD of LabeledPoint and used to conduct +# the independence test. Returns an array containing the ChiSquaredTestResult for every feature +# against the label. +featureTestResults = Statistics.chiSqTest(obs) + +for i, result in enumerate(featureTestResults): + print "Column %d:" % (i + 1) + print result +{% endhighlight %} +
    + ## Random data generation diff --git a/docs/monitoring.md b/docs/monitoring.md index d07ec4a57a2cc..f32cdef240d31 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -77,6 +77,13 @@ follows: one implementation, provided by Spark, which looks for application logs stored in the file system. + + spark.history.fs.logDirectory + file:/tmp/spark-events + + Directory that contains application event logs to be loaded by the history server + + spark.history.fs.updateInterval 10 diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 8e8cc1dd983f8..c60de6e970531 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -117,6 +117,8 @@ The first thing a Spark program must do is to create a [SparkContext](api/scala/ how to access a cluster. To create a `SparkContext` you first need to build a [SparkConf](api/scala/index.html#org.apache.spark.SparkConf) object that contains information about your application. +Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before creating a new one. + {% highlight scala %} val conf = new SparkConf().setAppName(appName).setMaster(master) new SparkContext(conf) @@ -211,17 +213,17 @@ For a complete list of options, run `pyspark --help`. Behind the scenes, It is also possible to launch the PySpark shell in [IPython](http://ipython.org), the enhanced Python interpreter. PySpark works with IPython 1.0.0 and later. To -use IPython, set the `PYSPARK_PYTHON` variable to `ipython` when running `bin/pyspark`: +use IPython, set the `PYSPARK_DRIVER_PYTHON` variable to `ipython` when running `bin/pyspark`: {% highlight bash %} -$ PYSPARK_PYTHON=ipython ./bin/pyspark +$ PYSPARK_DRIVER_PYTHON=ipython ./bin/pyspark {% endhighlight %} -You can customize the `ipython` command by setting `PYSPARK_PYTHON_OPTS`. For example, to launch +You can customize the `ipython` command by setting `PYSPARK_DRIVER_PYTHON_OPTS`. For example, to launch the [IPython Notebook](http://ipython.org/notebook.html) with PyLab plot support: {% highlight bash %} -$ PYSPARK_PYTHON=ipython PYSPARK_PYTHON_OPTS="notebook --pylab inline" ./bin/pyspark +$ PYSPARK_DRIVER_PYTHON=ipython PYSPARK_DRIVER_PYTHON_OPTS="notebook --pylab inline" ./bin/pyspark {% endhighlight %} @@ -1131,7 +1133,7 @@ method. The code below shows this: {% highlight scala %} scala> val broadcastVar = sc.broadcast(Array(1, 2, 3)) -broadcastVar: spark.Broadcast[Array[Int]] = spark.Broadcast(b5c40191-a864-4c7d-b9bf-d87e1a4e787c) +broadcastVar: org.apache.spark.broadcast.Broadcast[Array[Int]] = Broadcast(0) scala> broadcastVar.value res0: Array[Int] = Array(1, 2, 3) @@ -1304,6 +1306,12 @@ vecAccum = sc.accumulator(Vector(...), VectorAccumulatorParam()) +For accumulator updates performed inside actions only, Spark guarantees that each task's update to the accumulator +will only be applied once, i.e. restarted tasks will not update the value. In transformations, users should be aware +of that each task's update may be applied more than once if tasks or job stages are re-executed. + + + # Deploying to a Cluster The [application submission guide](submitting-applications.html) describes how to submit applications to a cluster. diff --git a/docs/quick-start.md b/docs/quick-start.md index 23313d8aa6152..bf643bb70e153 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -8,7 +8,7 @@ title: Quick Start This tutorial provides a quick introduction to using Spark. 
We will first introduce the API through Spark's interactive shell (in Python or Scala), -then show how to write standalone applications in Java, Scala, and Python. +then show how to write applications in Java, Scala, and Python. See the [programming guide](programming-guide.html) for a more complete reference. To follow along with this guide, first download a packaged release of Spark from the @@ -215,8 +215,8 @@ a cluster, as described in the [programming guide](programming-guide.html#initia -# Standalone Applications -Now say we wanted to write a standalone application using the Spark API. We will walk through a +# Self-Contained Applications +Now say we wanted to write a self-contained application using the Spark API. We will walk through a simple application in both Scala (with SBT), Java (with Maven), and Python.
    @@ -244,6 +244,9 @@ object SimpleApp { } {% endhighlight %} +Note that applications should define a `main()` method instead of extending `scala.App`. +Subclasses of `scala.App` may not work correctly. + This program just counts the number of lines containing 'a' and the number containing 'b' in the Spark README. Note that you'll need to replace YOUR_SPARK_HOME with the location where Spark is installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkContext, @@ -387,7 +390,7 @@ Lines with a: 46, Lines with b: 23
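To illustrate the `main()` recommendation above, a minimal hedged sketch of a self-contained application with an explicit `main()` method rather than a subclass of `scala.App` (the object name and input path are illustrative):

{% highlight scala %}
import org.apache.spark.{SparkConf, SparkContext}

// Define an explicit main() method instead of extending scala.App.
object SimpleCountApp {
  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("Simple Count Application")
    val sc = new SparkContext(conf)
    val numLines = sc.textFile("YOUR_SPARK_HOME/README.md").count()
    println("Number of lines: " + numLines)
    sc.stop()
  }
}
{% endhighlight %}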
    -Now we will show how to write a standalone application using the Python API (PySpark). +Now we will show how to write an application using the Python API (PySpark). As an example, we'll create a simple Spark application, `SimpleApp.py`: diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 695813a2ba881..dfe2db4b3fce8 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -4,7 +4,7 @@ title: Running Spark on YARN --- Support for running on [YARN (Hadoop -NextGen)](http://hadoop.apache.org/docs/r2.0.2-alpha/hadoop-yarn/hadoop-yarn-site/YARN.html) +NextGen)](http://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/YARN.html) was added to Spark in version 0.6.0, and improved in subsequent releases. # Preparations @@ -39,7 +39,7 @@ Most of the configs are the same for Spark on YARN as for other deployment modes spark.yarn.preserve.staging.files false - Set to true to preserve the staged files (Spark jar, app jar, distributed cache files) at the end of the job rather then delete them. + Set to true to preserve the staged files (Spark jar, app jar, distributed cache files) at the end of the job rather than delete them. @@ -159,7 +159,7 @@ For example: lib/spark-examples*.jar \ 10 -The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Viewing Logs" section below for how to see driver and executor logs. +The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. To launch a Spark application in yarn-client mode, do the same, but replace "yarn-cluster" with "yarn-client". To run spark-shell: @@ -181,7 +181,7 @@ In YARN terminology, executors and application masters run inside "containers". yarn logs -applicationId -will print out the contents of all log files from all containers from the given application. +will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. diff --git a/docs/security.md b/docs/security.md index ec0523184d665..1e206a139fb72 100644 --- a/docs/security.md +++ b/docs/security.md @@ -7,7 +7,6 @@ Spark currently supports authentication via a shared secret. 
Authentication can * For Spark on [YARN](running-on-yarn.html) deployments, configuring `spark.authenticate` to `true` will automatically handle generating and distributing the shared secret. Each application will use a unique shared secret. * For other types of Spark deployments, the Spark parameter `spark.authenticate.secret` should be configured on each of the nodes. This secret will be used by all the Master/Workers and applications. -* **IMPORTANT NOTE:** *The experimental Netty shuffle path (`spark.shuffle.use.netty`) is not secured, so do not use Netty for shuffles if running with authentication.* ## Web UI diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 368c3d0008b07..5500da83b2b66 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -14,7 +14,7 @@ title: Spark SQL Programming Guide Spark SQL allows relational queries expressed in SQL, HiveQL, or Scala to be executed using Spark. At the core of this component is a new type of RDD, [SchemaRDD](api/scala/index.html#org.apache.spark.sql.SchemaRDD). SchemaRDDs are composed of -[Row](api/scala/index.html#org.apache.spark.sql.catalyst.expressions.Row) objects, along with +[Row](api/scala/index.html#org.apache.spark.sql.package@Row:org.apache.spark.sql.catalyst.expressions.Row.type) objects, along with a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table in a traditional relational database. A SchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io) file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). @@ -582,19 +582,27 @@ Configuration of Parquet can be done using the `setConf` method on SQLContext or spark.sql.parquet.cacheMetadata - false + true Turns on caching of Parquet schema metadata. Can speed up querying of static data. spark.sql.parquet.compression.codec - snappy + gzip Sets the compression codec use when writing Parquet files. Acceptable values include: uncompressed, snappy, gzip, lzo. + + spark.sql.hive.convertMetastoreParquet + true + + When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of the built in + support. + + ## JSON Datasets @@ -720,7 +728,7 @@ anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD) Spark SQL also supports reading and writing data stored in [Apache Hive](http://hive.apache.org/). However, since Hive has a large number of dependencies, it is not included in the default Spark assembly. -In order to use Hive you must first run "`sbt/sbt -Phive assembly/assembly`" (or use `-Phive` for maven). +Hive support is enabled by adding the `-Phive` and `-Phive-thriftserver` flags to Spark's build. This command builds a new assembly jar that includes Hive. Note that this Hive assembly jar must also be present on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries (SerDes) in order to access data stored in Hive. @@ -815,7 +823,7 @@ Configuration of in-memory caching can be done using the `setConf` method on SQL Property NameDefaultMeaning spark.sql.inMemoryColumnarStorage.compressed - false + true When set to true Spark SQL will automatically select a compression codec for each column based on statistics of the data. @@ -823,7 +831,7 @@ Configuration of in-memory caching can be done using the `setConf` method on SQL spark.sql.inMemoryColumnarStorage.batchSize - 1000 + 10000 Controls the size of batches for columnar caching. 
Larger batch sizes can improve memory utilization and compression, but risk OOMs when caching data. @@ -841,7 +849,7 @@ that these options will be deprecated in future release as more optimizations ar Property NameDefaultMeaning spark.sql.autoBroadcastJoinThreshold - 10000 + 10485760 (10 MB) Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently @@ -1051,7 +1059,6 @@ in Hive deployments. **Major Hive Features** -* Spark SQL does not currently support inserting to tables using dynamic partitioning. * Tables with buckets: bucket is the hash partitioning within a Hive table partition. Spark SQL doesn't support buckets yet. @@ -1215,7 +1222,7 @@ import org.apache.spark.sql._ DecimalType - scala.math.sql.BigDecimal + scala.math.BigDecimal DecimalType diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 5c21e912ea160..44a1f3ad7560b 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -68,7 +68,9 @@ import org.apache.spark._ import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ -// Create a local StreamingContext with two working thread and batch interval of 1 second +// Create a local StreamingContext with two working thread and batch interval of 1 second. +// The master requires 2 cores to prevent from a starvation scenario. + val conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount") val ssc = new StreamingContext(conf, Seconds(1)) {% endhighlight %} @@ -212,6 +214,67 @@ The complete code can be found in the Spark Streaming example [JavaNetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java).
    +
    +
    +First, we import StreamingContext, which is the main entry point for all streaming functionality. We create a local StreamingContext with two execution threads, and batch interval of 1 second. + +{% highlight python %} +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +# Create a local StreamingContext with two working thread and batch interval of 1 second +sc = SparkContext("local[2]", "NetworkWordCount") +ssc = StreamingContext(sc, 1) +{% endhighlight %} + +Using this context, we can create a DStream that represents streaming data from a TCP +source hostname, e.g. `localhost`, and port, e.g. `9999` + +{% highlight python %} +# Create a DStream that will connect to hostname:port, like localhost:9999 +lines = ssc.socketTextStream("localhost", 9999) +{% endhighlight %} + +This `lines` DStream represents the stream of data that will be received from the data +server. Each record in this DStream is a line of text. Next, we want to split the lines by +space into words. + +{% highlight python %} +# Split each line into words +words = lines.flatMap(lambda line: line.split(" ")) +{% endhighlight %} + +`flatMap` is a one-to-many DStream operation that creates a new DStream by +generating multiple new records from each record in the source DStream. In this case, +each line will be split into multiple words and the stream of words is represented as the +`words` DStream. Next, we want to count these words. + +{% highlight python %} +# Count each word in each batch +pairs = words.map(lambda word: (word, 1)) +wordCounts = pairs.reduceByKey(lambda x, y: x + y) + +# Print the first ten elements of each RDD generated in this DStream to the console +wordCounts.pprint() +{% endhighlight %} + +The `words` DStream is further mapped (one-to-one transformation) to a DStream of `(word, +1)` pairs, which is then reduced to get the frequency of words in each batch of data. +Finally, `wordCounts.pprint()` will print a few of the counts generated every second. + +Note that when these lines are executed, Spark Streaming only sets up the computation it +will perform when it is started, and no real processing has started yet. To start the processing +after all the transformations have been setup, we finally call + +{% highlight python %} +ssc.start() # Start the computation +ssc.awaitTermination() # Wait for the computation to terminate +{% endhighlight %} + +The complete code can be found in the Spark Streaming example +[NetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/network_wordcount.py). +
    +
    @@ -236,6 +299,11 @@ $ ./bin/run-example streaming.NetworkWordCount localhost 9999 $ ./bin/run-example streaming.JavaNetworkWordCount localhost 9999 {% endhighlight %} +
    +{% highlight bash %} +$ ./bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999 +{% endhighlight %} +
    @@ -259,8 +327,11 @@ hello world +
    + +
    {% highlight bash %} -# TERMINAL 2: RUNNING NetworkWordCount or JavaNetworkWordCount +# TERMINAL 2: RUNNING NetworkWordCount $ ./bin/run-example streaming.NetworkWordCount localhost 9999 ... @@ -271,6 +342,37 @@ Time: 1357008430000 ms (world,1) ... {% endhighlight %} +
    + +
    +{% highlight bash %} +# TERMINAL 2: RUNNING JavaNetworkWordCount + +$ ./bin/run-example streaming.JavaNetworkWordCount localhost 9999 +... +------------------------------------------- +Time: 1357008430000 ms +------------------------------------------- +(hello,1) +(world,1) +... +{% endhighlight %} +
    +
    +{% highlight bash %} +# TERMINAL 2: RUNNING network_wordcount.py + +$ ./bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999 +... +------------------------------------------- +Time: 2014-10-14 15:25:21 +------------------------------------------- +(hello,1) +(world,1) +... +{% endhighlight %} +
    +
    @@ -398,9 +500,34 @@ JavaSparkContext sc = ... //existing JavaSparkContext JavaStreamingContext ssc = new JavaStreamingContext(sc, new Duration(1000)); {% endhighlight %} +
    + +A [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext) object can be created from a [SparkContext](api/python/pyspark.html#pyspark.SparkContext) object. + +{% highlight python %} +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +sc = SparkContext(master, appName) +ssc = StreamingContext(sc, 1) +{% endhighlight %} + +The `appName` parameter is a name for your application to show on the cluster UI. +`master` is a [Spark, Mesos or YARN cluster URL](submitting-applications.html#master-urls), +or a special __"local[\*]"__ string to run in local mode. In practice, when running on a cluster, +you will not want to hardcode `master` in the program, +but rather [launch the application with `spark-submit`](submitting-applications.html) and +receive it there. However, for local testing and unit tests, you can pass "local[\*]" to run Spark Streaming +in-process (detects the number of cores in the local system). + +The batch interval must be set based on the latency requirements of your application +and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size) +section for more details. +
    After a context is defined, you have to do the follow steps. + 1. Define the input sources. 1. Setup the streaming computations. 1. Start the receiving and procesing of data using `streamingContext.start()`. @@ -461,11 +588,13 @@ Every input DStream (except file stream) is associated with a single [Receiver]( A receiver is run within a Spark worker/executor as a long-running task, hence it occupies one of the cores allocated to the Spark Streaming application. Hence, it is important to remember that Spark Streaming application needs to be allocated enough cores to process the received data, as well as, to run the receiver(s). Therefore, few important points to remember are: -##### Points to remember: +##### Points to remember {:.no_toc} -- If the number of cores allocated to the application is less than or equal to the number of input DStreams / receivers, then the system will receive data, but not be able to process them. -- When running locally, if you master URL is set to "local", then there is only one core to run tasks. That is insufficient for programs with even one input DStream (file streams are okay) as the receiver will occupy that core and there will be no core left to process the data. - +- If the number of threads allocated to the application is less than or equal to the number of input DStreams / receivers, then the system will receive data, but not be able to process them. +- When running locally, if you master URL is set to "local", then there is only one core to run tasks. That is insufficient for programs using a DStream as the receiver (file streams are okay). So, a "local" master URL in a streaming app is generally going to cause starvation for the processor. +Thus in any streaming app, you generally will want to allocate more than one thread (i.e. set your master to "local[2]") when testing locally. +See [Spark Properties] (configuration.html#spark-properties.html). + ### Basic Sources {:.no_toc} @@ -483,6 +612,9 @@ methods for creating DStreams from files and Akka actors as input sources.
    streamingContext.fileStream(dataDirectory);
    +
    + streamingContext.textFileStream(dataDirectory) +
    Spark Streaming will monitor the directory `dataDirectory` and process any files created in that directory (files written in nested directories not supported). Note that @@ -494,7 +626,7 @@ methods for creating DStreams from files and Akka actors as input sources. For simple text files, there is an easier method `streamingContext.textFileStream(dataDirectory)`. And file streams do not require running a receiver, hence does not require allocating cores. -- **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka actors by using `streamingContext.actorStream(actorProps, actor-name)`. See the [Custom Receiver Guide](#implementing-and-using-a-custom-actor-based-receiver) for more details. +- **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka actors by using `streamingContext.actorStream(actorProps, actor-name)`. See the [Custom Receiver Guide](streaming-custom-receivers.html#implementing-and-using-a-custom-actor-based-receiver) for more details. - **Queue of RDDs as a Stream:** For testing a Spark Streaming application with test data, one can also create a DStream based on a queue of RDDs, using `streamingContext.queueStream(queueOfRDDs)`. Each RDD pushed into the queue will be treated as a batch of data in the DStream, and processed like a stream. @@ -684,13 +816,30 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi JavaPairDStream runningCounts = pairs.updateStateByKey(updateFunction); {% endhighlight %} + +
    + +{% highlight python %} +def updateFunction(newValues, runningCount): + if runningCount is None: + runningCount = 0 + return sum(newValues, runningCount) # add the new values with the previous running count to get the new count +{% endhighlight %} + +This is applied on a DStream containing words (say, the `pairs` DStream containing `(word, +1)` pairs in the [earlier example](#a-quick-example)). + +{% highlight python %} +runningCounts = pairs.updateStateByKey(updateFunction) +{% endhighlight %} +
    The update function will be called for each word, with `newValues` having a sequence of 1's (from the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete Scala code, take a look at the example -[StatefulNetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala). +[stateful_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/stateful_network_wordcount.py). #### Transform Operation {:.no_toc} @@ -732,6 +881,15 @@ JavaPairDStream cleanedDStream = wordCounts.transform( }); {% endhighlight %} + +
    + +{% highlight python %} +spamInfoRDD = sc.pickleFile(...) # RDD containing spam information + +# join data stream with spam information to do data cleaning +cleanedDStream = wordCounts.transform(lambda rdd: rdd.join(spamInfoRDD).filter(...)) +{% endhighlight %}
    @@ -793,6 +951,14 @@ Function2 reduceFunc = new Function2 windowedWordCounts = pairs.reduceByKeyAndWindow(reduceFunc, new Duration(30000), new Duration(10000)); {% endhighlight %} + +
    + +{% highlight python %} +# Reduce last 30 seconds of data, every 10 seconds +windowedWordCounts = pairs.reduceByKeyAndWindow(lambda x, y: x + y, lambda x, y: x - y, 30, 10) +{% endhighlight %} +
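One detail worth noting for the incremental form shown above: the variant that takes an inverse reduce function assumes checkpointing has been enabled on the context. A hedged Scala sketch, reusing the `ssc` and `pairs` values from the quick example:

{% highlight scala %}
import org.apache.spark.streaming.Seconds

// The inverse-function form requires a checkpoint directory to be set.
ssc.checkpoint("checkpoint/")

val windowedWordCounts = pairs.reduceByKeyAndWindow(
  (a: Int, b: Int) => a + b,   // add counts entering the window
  (a: Int, b: Int) => a - b,   // subtract counts leaving the window
  Seconds(30), Seconds(10))
{% endhighlight %}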
    @@ -860,6 +1026,7 @@ see [DStream](api/scala/index.html#org.apache.spark.streaming.dstream.DStream) and [PairDStreamFunctions](api/scala/index.html#org.apache.spark.streaming.dstream.PairDStreamFunctions). For the Java API, see [JavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaDStream.html) and [JavaPairDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaPairDStream.html). +For the Python API, see [DStream](api/python/pyspark.streaming.html#pyspark.streaming.DStream) *** @@ -872,9 +1039,12 @@ Currently, the following output operations are defined: - + + This is useful for development and debugging. +
    + PS: called pprint() in Python) + @@ -915,17 +1085,41 @@ For this purpose, a developer may inadvertantly try creating a connection object the Spark driver, but try to use it in a Spark worker to save records in the RDDs. For example (in Scala), +
    +
    + +{% highlight scala %} dstream.foreachRDD(rdd => { val connection = createNewConnection() // executed at the driver rdd.foreach(record => { connection.send(record) // executed at the worker }) }) +{% endhighlight %} - This is incorrect as this requires the connection object to be serialized and sent from the driver to the worker. Such connection objects are rarely transferrable across machines. This error may manifest as serialization errors (connection object not serializable), initialization errors (connection object needs to be initialized at the workers), etc. The correct solution is to create the connection object at the worker. +
    +
    + +{% highlight python %} +def sendRecord(rdd): + connection = createNewConnection() # executed at the driver + rdd.foreach(lambda record: connection.send(record)) + connection.close() + +dstream.foreachRDD(sendRecord) +{% endhighlight %} + +
    +
    + + This is incorrect as this requires the connection object to be serialized and sent from the driver to the worker. Such connection objects are rarely transferrable across machines. This error may manifest as serialization errors (connection object not serializable), initialization errors (connection object needs to be initialized at the workers), etc. The correct solution is to create the connection object at the worker. - However, this can lead to another common mistake - creating a new connection for every record. For example, +
    +
    + +{% highlight scala %} dstream.foreachRDD(rdd => { rdd.foreach(record => { val connection = createNewConnection() @@ -933,9 +1127,28 @@ For example (in Scala), connection.close() }) }) +{% endhighlight %} - Typically, creating a connection object has time and resource overheads. Therefore, creating and destroying a connection object for each record can incur unnecessarily high overheads and can significantly reduce the overall throughput of the system. A better solution is to use `rdd.foreachPartition` - create a single connection object and send all the records in a RDD partition using that connection. +
    +
    + +{% highlight python %} +def sendRecord(record): + connection = createNewConnection() + connection.send(record) + connection.close() + +dstream.foreachRDD(lambda rdd: rdd.foreach(sendRecord)) +{% endhighlight %} +
    +
    + + Typically, creating a connection object has time and resource overheads. Therefore, creating and destroying a connection object for each record can incur unnecessarily high overheads and can significantly reduce the overall throughput of the system. A better solution is to use `rdd.foreachPartition` - create a single connection object and send all the records in a RDD partition using that connection. + +
    +
    +{% highlight scala %} dstream.foreachRDD(rdd => { rdd.foreachPartition(partitionOfRecords => { val connection = createNewConnection() @@ -943,13 +1156,31 @@ For example (in Scala), connection.close() }) }) +{% endhighlight %} +
    + +
    +{% highlight python %} +def sendPartition(iter): + connection = createNewConnection() + for record in iter: + connection.send(record) + connection.close() + +dstream.foreachRDD(lambda rdd: rdd.foreachPartition(sendPartition)) +{% endhighlight %} +
    +
    - This amortizes the connection creation overheads over many records. + This amortizes the connection creation overheads over many records. - Finally, this can be further optimized by reusing connection objects across multiple RDDs/batches. One can maintain a static pool of connection objects than can be reused as RDDs of multiple batches are pushed to the external system, thus further reducing the overheads. - + +
    +
    +{% highlight scala %} dstream.foreachRDD(rdd => { rdd.foreachPartition(partitionOfRecords => { // ConnectionPool is a static, lazily initialized pool of connections @@ -958,8 +1189,25 @@ For example (in Scala), ConnectionPool.returnConnection(connection) // return to the pool for future reuse }) }) +{% endhighlight %} +
    - Note that the connections in the pool should be lazily created on demand and timed out if not used for a while. This achieves the most efficient sending of data to external systems. +
    +{% highlight python %} +def sendPartition(iter): + # ConnectionPool is a static, lazily initialized pool of connections + connection = ConnectionPool.getConnection() + for record in iter: + connection.send(record) + # return to the pool for future reuse + ConnectionPool.returnConnection(connection) + +dstream.foreachRDD(lambda rdd: rdd.foreachPartition(sendPartition)) +{% endhighlight %} +
    +
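Like `createNewConnection()`, the `ConnectionPool` above is not a Spark API; it stands for whatever pooling utility the application supplies. A minimal sketch of such a pool, assuming the hypothetical `createNewConnection()` factory sketched earlier and an arbitrary 60-second idle timeout, might look like this in Python:

{% highlight python %}
import threading
import time

# Minimal sketch of a static, lazily initialized connection pool with an idle
# timeout. The createNewConnection() factory and the 60-second timeout are
# assumptions for illustration, not part of Spark.
class ConnectionPool(object):
    _lock = threading.Lock()
    _idle = []                  # list of (connection, time it was returned)
    IDLE_TIMEOUT_SECS = 60

    @classmethod
    def getConnection(cls):
        with cls._lock:
            now = time.time()
            # Close and drop connections that have sat idle for too long.
            live = []
            for conn, returned_at in cls._idle:
                if now - returned_at > cls.IDLE_TIMEOUT_SECS:
                    conn.close()
                else:
                    live.append((conn, returned_at))
            cls._idle = live
            if cls._idle:
                return cls._idle.pop()[0]
        # Nothing pooled: create a connection lazily, on demand.
        return createNewConnection()

    @classmethod
    def returnConnection(cls, connection):
        with cls._lock:
            cls._idle.append((connection, time.time()))
{% endhighlight %}

Because the pool lives in ordinary per-process state, each worker process keeps its own set of connections, which is what allows them to be reused across the batches that run on that worker.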
    + +Note that the connections in the pool should be lazily created on demand and timed out if not used for a while. This achieves the most efficient sending of data to external systems. ##### Other points to remember: @@ -1376,6 +1624,44 @@ You can also explicitly create a `JavaStreamingContext` from the checkpoint data the computation by using `new JavaStreamingContext(checkpointDirectory)`. +
+ +This behavior is made simple by using `StreamingContext.getOrCreate`. This is used as follows. + +{% highlight python %} +# Function to create and set up a new StreamingContext +def functionToCreateContext(): + sc = SparkContext(...) # new context + ssc = StreamingContext(...) + lines = ssc.socketTextStream(...) # create DStreams + ... + ssc.checkpoint(checkpointDirectory) # set checkpoint directory + return ssc + +# Get StreamingContext from checkpoint data or create a new one +context = StreamingContext.getOrCreate(checkpointDirectory, functionToCreateContext) + +# Do additional setup on context that needs to be done, +# irrespective of whether it is being started or restarted +context. ... + +# Start the context +context.start() +context.awaitTermination() +{% endhighlight %} + +If the `checkpointDirectory` exists, then the context will be recreated from the checkpoint data. +If the directory does not exist (i.e., running for the first time), +then the function `functionToCreateContext` will be called to create a new +context and set up the DStreams. See the Python example +[recoverable_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/streaming/recoverable_network_wordcount.py). +This example appends the word counts of network data into a file. + +You can also explicitly create a `StreamingContext` from the checkpoint data and start the + computation by using `StreamingContext.getOrCreate(checkpointDirectory, None)`. + +
+ **Note**: If Spark Streaming and/or the Spark Streaming program is recompiled, @@ -1572,7 +1858,11 @@ package and renamed for better clarity. [TwitterUtils](api/java/index.html?org/apache/spark/streaming/twitter/TwitterUtils.html), [ZeroMQUtils](api/java/index.html?org/apache/spark/streaming/zeromq/ZeroMQUtils.html), and [MQTTUtils](api/java/index.html?org/apache/spark/streaming/mqtt/MQTTUtils.html) + - Python docs + * [StreamingContext](api/python/pyspark.streaming.html#pyspark.streaming.StreamingContext) + * [DStream](api/python/pyspark.streaming.html#pyspark.streaming.DStream) * More examples in [Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming) and [Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming) + and [Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/streaming) * [Paper](http://www.eecs.berkeley.edu/Pubs/TechRpts/2012/EECS-2012-259.pdf) and [video](http://youtu.be/g171ndOHgJ0) describing Spark Streaming. diff --git a/docs/tuning.md b/docs/tuning.md index 8fb2a0433b1a8..9b5c9adac6a4f 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -47,24 +47,11 @@ registration requirement, but we recommend trying it in any network-intensive ap Spark automatically includes Kryo serializers for the many commonly-used core Scala classes covered in the AllScalaRegistrar from the [Twitter chill](https://github.com/twitter/chill) library. -To register your own custom classes with Kryo, create a public class that extends -[`org.apache.spark.serializer.KryoRegistrator`](api/scala/index.html#org.apache.spark.serializer.KryoRegistrator) and set the -`spark.kryo.registrator` config property to point to it, as follows: +To register your own custom classes with Kryo, use the `registerKryoClasses` method. {% highlight scala %} -import com.esotericsoftware.kryo.Kryo -import org.apache.spark.serializer.KryoRegistrator - -class MyRegistrator extends KryoRegistrator { - override def registerClasses(kryo: Kryo) { - kryo.register(classOf[MyClass1]) - kryo.register(classOf[MyClass2]) - } -} - val conf = new SparkConf().setMaster(...).setAppName(...) -conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") -conf.set("spark.kryo.registrator", "mypackage.MyRegistrator") +conf.registerKryoClasses(Array(classOf[MyClass1], classOf[MyClass2])) val sc = new SparkContext(conf) {% endhighlight %} diff --git a/ec2/spark-ec2 b/ec2/spark-ec2 index 31f9771223e51..4aa908242eeaa 100755 --- a/ec2/spark-ec2 +++ b/ec2/spark-ec2 @@ -18,5 +18,9 @@ # limitations under the License. # -cd "`dirname $0`" -PYTHONPATH="./third_party/boto-2.4.1.zip/boto-2.4.1:$PYTHONPATH" python ./spark_ec2.py "$@" +# Preserve the user's CWD so that relative paths are passed correctly to +#+ the underlying Python script.
+SPARK_EC2_DIR="$(dirname $0)" + +PYTHONPATH="${SPARK_EC2_DIR}/third_party/boto-2.4.1.zip/boto-2.4.1:$PYTHONPATH" \ + python "${SPARK_EC2_DIR}/spark_ec2.py" "$@" diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 941dfb988b9fb..742c7765e728e 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -32,6 +32,7 @@ import tempfile import time import urllib2 +import warnings from optparse import OptionParser from sys import stderr import boto @@ -39,9 +40,11 @@ from boto import ec2 DEFAULT_SPARK_VERSION = "1.1.0" +SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) +MESOS_SPARK_EC2_BRANCH = "v4" # A URL prefix from which to fetch AMI information -AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/v2/ami-list" +AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/{b}/ami-list".format(b=MESOS_SPARK_EC2_BRANCH) class UsageError(Exception): @@ -61,8 +64,8 @@ def parse_args(): "-s", "--slaves", type="int", default=1, help="Number of slaves to launch (default: %default)") parser.add_option( - "-w", "--wait", type="int", default=120, - help="Seconds to wait for nodes to start (default: %default)") + "-w", "--wait", type="int", + help="DEPRECATED (no longer necessary) - Seconds to wait for nodes to start") parser.add_option( "-k", "--key-pair", help="Key pair to use on instances") @@ -83,7 +86,7 @@ def parse_args(): "-z", "--zone", default="", help="Availability zone to launch instances in, or 'all' to spread " + "slaves across multiple (an additional $0.01/Gb for bandwidth" + - "between zones applies)") + "between zones applies) (default: a single zone chosen at random)") parser.add_option("-a", "--ami", help="Amazon Machine Image ID to use") parser.add_option( "-v", "--spark-version", default=DEFAULT_SPARK_VERSION, @@ -135,7 +138,7 @@ def parse_args(): help="The SSH user you want to connect as (default: %default)") parser.add_option( "--delete-groups", action="store_true", default=False, - help="When destroying a cluster, delete the security groups that were created.") + help="When destroying a cluster, delete the security groups that were created") parser.add_option( "--use-existing-master", action="store_true", default=False, help="Launch fresh slaves, but use an existing stopped master if possible") @@ -149,9 +152,6 @@ def parse_args(): parser.add_option( "--user-data", type="string", default="", help="Path to a user-data file (most AMI's interpret this as an initialization script)") - parser.add_option( - "--security-group-prefix", type="string", default=None, - help="Use this prefix for the security group rather than the cluster name.") parser.add_option( "--authorized-address", type="string", default="0.0.0.0/0", help="Address to authorize on created security groups (default: %default)") @@ -195,18 +195,6 @@ def get_or_make_group(conn, name): return conn.create_security_group(name, "Spark EC2 group") -# Wait for a set of launched instances to exit the "pending" state -# (i.e. either to start running or to fail and be terminated) -def wait_for_instances(conn, instances): - while True: - for i in instances: - i.update() - if len([i for i in instances if i.state == 'pending']) > 0: - time.sleep(5) - else: - return - - # Check whether a given EC2 instance object is in a state we consider active, # i.e. not terminating or terminated. We count both stopping and stopped as # active since we can restart stopped clusters. @@ -314,12 +302,8 @@ def launch_cluster(conn, opts, cluster_name): user_data_content = user_data_file.read() print "Setting up security groups..." 
- if opts.security_group_prefix is None: - master_group = get_or_make_group(conn, cluster_name + "-master") - slave_group = get_or_make_group(conn, cluster_name + "-slaves") - else: - master_group = get_or_make_group(conn, opts.security_group_prefix + "-master") - slave_group = get_or_make_group(conn, opts.security_group_prefix + "-slaves") + master_group = get_or_make_group(conn, cluster_name + "-master") + slave_group = get_or_make_group(conn, cluster_name + "-slaves") authorized_address = opts.authorized_address if master_group.rules == []: # Group was just now created master_group.authorize(src_group=master_group) @@ -344,11 +328,12 @@ def launch_cluster(conn, opts, cluster_name): slave_group.authorize('tcp', 60060, 60060, authorized_address) slave_group.authorize('tcp', 60075, 60075, authorized_address) - # Check if instances are already running with the cluster name + # Check if instances are already running in our groups existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name, die_on_error=False) if existing_slaves or (existing_masters and not opts.use_existing_master): - print >> stderr, ("ERROR: There are already instances for name: %s " % cluster_name) + print >> stderr, ("ERROR: There are already instances running in " + + "group %s or %s" % (master_group.name, slave_group.name)) sys.exit(1) # Figure out Spark AMI @@ -422,13 +407,9 @@ def launch_cluster(conn, opts, cluster_name): for r in reqs: id_to_req[r.id] = r active_instance_ids = [] - outstanding_request_ids = [] for i in my_req_ids: - if i in id_to_req: - if id_to_req[i].state == "active": - active_instance_ids.append(id_to_req[i].instance_id) - else: - outstanding_request_ids.append(i) + if i in id_to_req and id_to_req[i].state == "active": + active_instance_ids.append(id_to_req[i].instance_id) if len(active_instance_ids) == opts.slaves: print "All %d slaves granted" % opts.slaves reservations = conn.get_all_instances(active_instance_ids) @@ -437,8 +418,8 @@ def launch_cluster(conn, opts, cluster_name): slave_nodes += r.instances break else: - print "%d of %d slaves granted, waiting longer for request ids including %s" % ( - len(active_instance_ids), opts.slaves, outstanding_request_ids[0:10]) + print "%d of %d slaves granted, waiting longer" % ( + len(active_instance_ids), opts.slaves) except: print "Canceling spot instance requests" conn.cancel_spot_instance_requests(my_req_ids) @@ -497,59 +478,34 @@ def launch_cluster(conn, opts, cluster_name): # Give the instances descriptive names for master in master_nodes: - name = '{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id) - tag_instance(master, name) - + master.add_tag( + key='Name', + value='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) for slave in slave_nodes: - name = '{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id) - tag_instance(slave, name) + slave.add_tag( + key='Name', + value='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) # Return all the instances return (master_nodes, slave_nodes) -def tag_instance(instance, name): - for i in range(0, 5): - try: - instance.add_tag(key='Name', value=name) - break - except: - print "Failed attempt %i of 5 to tag %s" % ((i + 1), name) - if i == 5: - raise "Error - failed max attempts to add name tag" - time.sleep(5) - # Get the EC2 instances in an existing cluster if available. 
# Returns a tuple of lists of EC2 instance objects for the masters and slaves def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): print "Searching for existing cluster " + cluster_name + "..." - # Search all the spot instance requests, and copy any tags from the spot - # instance request to the cluster. - spot_instance_requests = conn.get_all_spot_instance_requests() - for req in spot_instance_requests: - if req.state != u'active': - continue - name = req.tags.get(u'Name', "") - if name.startswith(cluster_name): - reservations = conn.get_all_instances(instance_ids=[req.instance_id]) - for res in reservations: - active = [i for i in res.instances if is_active(i)] - for instance in active: - if instance.tags.get(u'Name') is None: - tag_instance(instance, name) - # Now proceed to detect master and slaves instances. reservations = conn.get_all_instances() master_nodes = [] slave_nodes = [] for res in reservations: active = [i for i in res.instances if is_active(i)] for inst in active: - name = inst.tags.get(u'Name', "") - if name.startswith(cluster_name + "-master"): + group_names = [g.name for g in inst.groups] + if group_names == [cluster_name + "-master"]: master_nodes.append(inst) - elif name.startswith(cluster_name + "-slave"): + elif group_names == [cluster_name + "-slaves"]: slave_nodes.append(inst) if any((master_nodes, slave_nodes)): print "Found %d master(s), %d slaves" % (len(master_nodes), len(slave_nodes)) @@ -557,12 +513,12 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): return (master_nodes, slave_nodes) else: if master_nodes == [] and slave_nodes != []: - print >> sys.stderr, "ERROR: Could not find master in with name " + \ - cluster_name + "-master" + print >> sys.stderr, "ERROR: Could not find master in group " + cluster_name + "-master" else: print >> sys.stderr, "ERROR: Could not find any existing cluster" sys.exit(1) + # Deploy configuration files and run setup scripts on a newly launched # or started EC2 cluster. @@ -594,10 +550,23 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): # NOTE: We should clone the repository before running deploy_files to # prevent ec2-variables.sh from being overwritten - ssh(master, opts, "rm -rf spark-ec2 && git clone https://github.com/mesos/spark-ec2.git -b v3") + ssh( + host=master, + opts=opts, + command="rm -rf spark-ec2" + + " && " + + "git clone https://github.com/mesos/spark-ec2.git -b {b}".format(b=MESOS_SPARK_EC2_BRANCH) + ) print "Deploying files to master..." - deploy_files(conn, "deploy.generic", opts, master_nodes, slave_nodes, modules) + deploy_files( + conn=conn, + root_dir=SPARK_EC2_DIR + "/" + "deploy.generic", + opts=opts, + master_nodes=master_nodes, + slave_nodes=slave_nodes, + modules=modules + ) print "Running setup on master..." setup_spark_cluster(master, opts) @@ -619,14 +588,64 @@ def setup_spark_cluster(master, opts): print "Ganglia started at http://%s:5080/ganglia" % master -# Wait for a whole cluster (masters, slaves and ZooKeeper) to start up -def wait_for_cluster(conn, wait_secs, master_nodes, slave_nodes): - print "Waiting for instances to start up..." - time.sleep(5) - wait_for_instances(conn, master_nodes) - wait_for_instances(conn, slave_nodes) - print "Waiting %d more seconds..." % wait_secs - time.sleep(wait_secs) +def is_ssh_available(host, opts): + "Checks if SSH is available on the host." 
+ try: + with open(os.devnull, 'w') as devnull: + ret = subprocess.check_call( + ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3', + '%s@%s' % (opts.user, host), stringify_command('true')], + stdout=devnull, + stderr=devnull + ) + return ret == 0 + except subprocess.CalledProcessError as e: + return False + + +def is_cluster_ssh_available(cluster_instances, opts): + for i in cluster_instances: + if not is_ssh_available(host=i.ip_address, opts=opts): + return False + else: + return True + + +def wait_for_cluster_state(cluster_instances, cluster_state, opts): + """ + cluster_instances: a list of boto.ec2.instance.Instance + cluster_state: a string representing the desired state of all the instances in the cluster + value can be 'ssh-ready' or a valid value from boto.ec2.instance.InstanceState such as + 'running', 'terminated', etc. + (would be nice to replace this with a proper enum: http://stackoverflow.com/a/1695250) + """ + sys.stdout.write( + "Waiting for all instances in cluster to enter '{s}' state.".format(s=cluster_state) + ) + sys.stdout.flush() + + num_attempts = 0 + + while True: + time.sleep(3 * num_attempts) + + for i in cluster_instances: + s = i.update() # capture output to suppress print to screen in newer versions of boto + + if cluster_state == 'ssh-ready': + if all(i.state == 'running' for i in cluster_instances) and \ + is_cluster_ssh_available(cluster_instances, opts): + break + else: + if all(i.state == cluster_state for i in cluster_instances): + break + + num_attempts += 1 + + sys.stdout.write(".") + sys.stdout.flush() + + sys.stdout.write("\n") # Get number of local disks available for a given EC2 instance type. @@ -684,6 +703,8 @@ def get_num_disks(instance_type): # cluster (e.g. lists of masters and slaves). Files are only deployed to # the first master instance in the cluster, and we expect the setup # script to be run on that instance to copy them to other nodes. +# +# root_dir should be an absolute path to the directory with the files we want to deploy. def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): active_master = master_nodes[0].public_dns_name @@ -868,6 +889,16 @@ def real_main(): (opts, action, cluster_name) = parse_args() # Input parameter validation + if opts.wait is not None: + # NOTE: DeprecationWarnings are silent in 2.7+ by default. + # To show them, run Python with the -Wdefault switch. + # See: https://docs.python.org/3.5/whatsnew/2.7.html + warnings.warn( + "This option is deprecated and has no effect. " + "spark-ec2 automatically waits as long as necessary for clusters to startup.", + DeprecationWarning + ) + if opts.ebs_vol_num > 8: print >> stderr, "ebs-vol-num cannot be greater than 8" sys.exit(1) @@ -890,7 +921,11 @@ def real_main(): (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) else: (master_nodes, slave_nodes) = launch_cluster(conn, opts, cluster_name) - wait_for_cluster(conn, opts.wait, master_nodes, slave_nodes) + wait_for_cluster_state( + cluster_instances=(master_nodes + slave_nodes), + cluster_state='ssh-ready', + opts=opts + ) setup_cluster(conn, master_nodes, slave_nodes, opts, True) elif action == "destroy": @@ -914,12 +949,12 @@ def real_main(): # Delete security groups as well if opts.delete_groups: print "Deleting security groups (this will take some time)..." 
- if opts.security_group_prefix is None: - group_names = [cluster_name + "-master", cluster_name + "-slaves"] - else: - group_names = [opts.security_group_prefix + "-master", - opts.security_group_prefix + "-slaves"] - + group_names = [cluster_name + "-master", cluster_name + "-slaves"] + wait_for_cluster_state( + cluster_instances=(master_nodes + slave_nodes), + cluster_state='terminated', + opts=opts + ) attempt = 1 while attempt <= 3: print "Attempt %d" % attempt @@ -1019,7 +1054,11 @@ def real_main(): for inst in master_nodes: if inst.state not in ["shutting-down", "terminated"]: inst.start() - wait_for_cluster(conn, opts.wait, master_nodes, slave_nodes) + wait_for_cluster_state( + cluster_instances=(master_nodes + slave_nodes), + cluster_state='ssh-ready', + opts=opts + ) setup_cluster(conn, master_nodes, slave_nodes, opts, False) else: diff --git a/examples/pom.xml b/examples/pom.xml index eb49a0e5af22d..8713230e1e8ed 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml @@ -34,24 +34,6 @@ Spark Project Exampleshttp://spark.apache.org/ - - - kinesis-asl - - - org.apache.spark - spark-streaming-kinesis-asl_${scala.binary.version} - ${project.version} - - - org.apache.httpcomponents - httpclient - ${commons.httpclient.version} - - - - - @@ -102,12 +84,12 @@ org.apache.spark - spark-streaming-kafka_${scala.binary.version} + spark-streaming-flume_${scala.binary.version} ${project.version} org.apache.spark - spark-streaming-flume_${scala.binary.version} + spark-streaming-mqtt_${scala.binary.version} ${project.version} @@ -116,45 +98,151 @@ ${project.version} - org.apache.spark - spark-streaming-mqtt_${scala.binary.version} - ${project.version} + org.eclipse.jetty + jetty-server - - org.apache.hbase - hbase - ${hbase.version} - - - asm - asm - - - org.jboss.netty - netty - - + + org.apache.hbase + hbase-testing-util + ${hbase.version} + + + + org.apache.hbase + hbase-annotations + + + org.jruby + jruby-complete + + + + + org.apache.hbase + hbase-protocol + ${hbase.version} + + + org.apache.hbase + hbase-common + ${hbase.version} + + + + org.apache.hbase + hbase-annotations + + + + + org.apache.hbase + hbase-client + ${hbase.version} + + + + org.apache.hbase + hbase-annotations + + io.netty netty - - - commons-logging - commons-logging - - - org.jruby - jruby-complete - - - + + + + + org.apache.hbase + hbase-server + ${hbase.version} + + + org.apache.hadoop + hadoop-core + + + org.apache.hadoop + hadoop-client + + + org.apache.hadoop + hadoop-mapreduce-client-jobclient + + + org.apache.hadoop + hadoop-mapreduce-client-core + + + org.apache.hadoop + hadoop-auth + + + + org.apache.hbase + hbase-annotations + + + org.apache.hadoop + hadoop-annotations + + + org.apache.hadoop + hadoop-hdfs + + + org.apache.hbase + hbase-hadoop1-compat + + + org.apache.commons + commons-math + + + com.sun.jersey + jersey-core + + + org.slf4j + slf4j-api + + + com.sun.jersey + jersey-server + + + com.sun.jersey + jersey-core + + + com.sun.jersey + jersey-json + + + + commons-io + commons-io + + + + + org.apache.hbase + hbase-hadoop-compat + ${hbase.version} + + + org.apache.hbase + hbase-hadoop-compat + ${hbase.version} + test-jar + test + - org.eclipse.jetty - jetty-server + org.apache.commons + commons-math3 com.twitter algebird-core_${scala.binary.version} - 0.1.11 + 0.8.1 org.scalatest @@ -268,6 +356,10 @@ com.google.common.base.Optional** + + org.apache.commons.math3 + org.spark-project.commons.math3 + @@ -284,4 
+376,83 @@ + + + kinesis-asl + + + org.apache.spark + spark-streaming-kinesis-asl_${scala.binary.version} + ${project.version} + + + org.apache.httpcomponents + httpclient + ${commons.httpclient.version} + + + + + hbase-hadoop2 + + + hbase.profile + hadoop2 + + + + 0.98.7-hadoop2 + + + + hbase-hadoop1 + + + !hbase.profile + + + + 0.98.7-hadoop1 + + + + + scala-2.10 + + !scala-2.11 + + + + org.apache.spark + spark-streaming-kafka_${scala.binary.version} + ${project.version} + + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-scala-sources + generate-sources + + add-source + + + + src/main/scala + scala-2.10/src/main/scala + scala-2.10/src/main/java + + + + + + + + + diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java b/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java similarity index 100% rename from examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java rename to examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala similarity index 100% rename from examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala rename to examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala diff --git a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java index 6c177de359b60..31a79ddd3fff1 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java @@ -30,12 +30,25 @@ /** * Logistic regression based classification. + * + * This is an example implementation for learning how to use Spark. For more conventional use, + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ public final class JavaHdfsLR { private static final int D = 10; // Number of dimensions private static final Random rand = new Random(42); + static void showWarning() { + String warning = "WARN: This is a naive implementation of Logistic Regression " + + "and is given as an example!\n" + + "Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD " + + "or org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS " + + "for more conventional use."; + System.err.println(warning); + } + static class DataPoint implements Serializable { DataPoint(double[] x, double y) { this.x = x; @@ -109,6 +122,8 @@ public static void main(String[] args) { System.exit(1); } + showWarning(); + SparkConf sparkConf = new SparkConf().setAppName("JavaHdfsLR"); JavaSparkContext sc = new JavaSparkContext(sparkConf); JavaRDD lines = sc.textFile(args[0]); diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java index c22506491fbff..a5db8accdf138 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java @@ -45,10 +45,21 @@ * URL neighbor URL * ... * where URL and their neighbors are separated by space(s). 
+ * + * This is an example implementation for learning how to use Spark. For more conventional use, + * please refer to org.apache.spark.graphx.lib.PageRank */ public final class JavaPageRank { private static final Pattern SPACES = Pattern.compile("\\s+"); + static void showWarning() { + String warning = "WARN: This is a naive implementation of PageRank " + + "and is given as an example! \n" + + "Please use the PageRank implementation found in " + + "org.apache.spark.graphx.lib.PageRank for more conventional use."; + System.err.println(warning); + } + private static class Sum implements Function2 { @Override public Double call(Double a, Double b) { @@ -62,6 +73,8 @@ public static void main(String[] args) throws Exception { System.exit(1); } + showWarning(); + SparkConf sparkConf = new SparkConf().setAppName("JavaPageRank"); JavaSparkContext ctx = new JavaSparkContext(sparkConf); diff --git a/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java b/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java new file mode 100644 index 0000000000000..e68ec74c3ed54 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/JavaStatusTrackerDemo.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkJobInfo; +import org.apache.spark.SparkStageInfo; +import org.apache.spark.api.java.JavaFutureAction; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; + +import java.util.Arrays; +import java.util.List; + +/** + * Example of using Spark's status APIs from Java. + */ +public final class JavaStatusTrackerDemo { + + public static final String APP_NAME = "JavaStatusAPIDemo"; + + public static final class IdentityWithDelay implements Function { + @Override + public T call(T x) throws Exception { + Thread.sleep(2 * 1000); // 2 seconds + return x; + } + } + + public static void main(String[] args) throws Exception { + SparkConf sparkConf = new SparkConf().setAppName(APP_NAME); + final JavaSparkContext sc = new JavaSparkContext(sparkConf); + + // Example of implementing a progress reporter for a simple job. 
+ JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 5).map( + new IdentityWithDelay()); + JavaFutureAction> jobFuture = rdd.collectAsync(); + while (!jobFuture.isDone()) { + Thread.sleep(1000); // 1 second + List jobIds = jobFuture.jobIds(); + if (jobIds.isEmpty()) { + continue; + } + int currentJobId = jobIds.get(jobIds.size() - 1); + SparkJobInfo jobInfo = sc.statusTracker().getJobInfo(currentJobId); + SparkStageInfo stageInfo = sc.statusTracker().getStageInfo(jobInfo.stageIds()[0]); + System.out.println(stageInfo.numTasks() + " tasks total: " + stageInfo.numActiveTasks() + + " active, " + stageInfo.numCompletedTasks() + " complete"); + } + + System.out.println("Job results are: " + jobFuture.get()); + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java new file mode 100644 index 0000000000000..22ba68d8c354c --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +import java.util.List; + +import com.google.common.collect.Lists; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.Pipeline; +import org.apache.spark.ml.PipelineModel; +import org.apache.spark.ml.PipelineStage; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.feature.HashingTF; +import org.apache.spark.ml.feature.Tokenizer; +import org.apache.spark.sql.api.java.JavaSQLContext; +import org.apache.spark.sql.api.java.JavaSchemaRDD; +import org.apache.spark.sql.api.java.Row; +import org.apache.spark.SparkConf; + +/** + * A simple text classification pipeline that recognizes "spark" from input text. It uses the Java + * bean classes {@link LabeledDocument} and {@link Document} defined in the Scala counterpart of + * this example {@link SimpleTextClassificationPipeline}. Run with + *
    + * bin/run-example ml.JavaSimpleTextClassificationPipeline
    + * 
    + */ +public class JavaSimpleTextClassificationPipeline { + + public static void main(String[] args) { + SparkConf conf = new SparkConf().setAppName("JavaSimpleTextClassificationPipeline"); + JavaSparkContext jsc = new JavaSparkContext(conf); + JavaSQLContext jsql = new JavaSQLContext(jsc); + + // Prepare training documents, which are labeled. + List localTraining = Lists.newArrayList( + new LabeledDocument(0L, "a b c d e spark", 1.0), + new LabeledDocument(1L, "b d", 0.0), + new LabeledDocument(2L, "spark f g h", 1.0), + new LabeledDocument(3L, "hadoop mapreduce", 0.0)); + JavaSchemaRDD training = + jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. + Tokenizer tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words"); + HashingTF hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol()) + .setOutputCol("features"); + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01); + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {tokenizer, hashingTF, lr}); + + // Fit the pipeline to training documents. + PipelineModel model = pipeline.fit(training); + + // Prepare test documents, which are unlabeled. + List localTest = Lists.newArrayList( + new Document(4L, "spark i j k"), + new Document(5L, "l m n"), + new Document(6L, "mapreduce spark"), + new Document(7L, "apache hadoop")); + JavaSchemaRDD test = + jsql.applySchema(jsc.parallelize(localTest), Document.class); + + // Make predictions on test documents. + model.transform(test).registerAsTable("prediction"); + JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); + for (Row r: predictions.collect()) { + System.out.println(r); + } + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java index 8d381d4e0a943..95a430f1da234 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java @@ -32,7 +32,7 @@ import scala.Tuple2; /** - * Example using MLLib ALS from Java. + * Example using MLlib ALS from Java. */ public final class JavaALS { diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java new file mode 100644 index 0000000000000..4a5ac404ea5ea --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaGradientBoostedTreesRunner.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib; + +import scala.Tuple2; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.tree.GradientBoostedTrees; +import org.apache.spark.mllib.tree.configuration.BoostingStrategy; +import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel; +import org.apache.spark.mllib.util.MLUtils; + +/** + * Classification and regression using gradient-boosted decision trees. + */ +public final class JavaGradientBoostedTreesRunner { + + private static void usage() { + System.err.println("Usage: JavaGradientBoostedTreesRunner " + + " "); + System.exit(-1); + } + + public static void main(String[] args) { + String datapath = "data/mllib/sample_libsvm_data.txt"; + String algo = "Classification"; + if (args.length >= 1) { + datapath = args[0]; + } + if (args.length >= 2) { + algo = args[1]; + } + if (args.length > 2) { + usage(); + } + SparkConf sparkConf = new SparkConf().setAppName("JavaGradientBoostedTreesRunner"); + JavaSparkContext sc = new JavaSparkContext(sparkConf); + + JavaRDD data = MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD().cache(); + + // Set parameters. + // Note: All features are treated as continuous. + BoostingStrategy boostingStrategy = BoostingStrategy.defaultParams(algo); + boostingStrategy.setNumIterations(10); + boostingStrategy.treeStrategy().setMaxDepth(5); + + if (algo.equals("Classification")) { + // Compute the number of classes from the data. + Integer numClasses = data.map(new Function() { + @Override public Double call(LabeledPoint p) { + return p.label(); + } + }).countByValue().size(); + boostingStrategy.treeStrategy().setNumClassesForClassification(numClasses); + + // Train a GradientBoosting model for classification. + final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy); + + // Evaluate model on training instances and compute training error + JavaPairRDD predictionAndLabel = + data.mapToPair(new PairFunction() { + @Override public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double trainErr = + 1.0 * predictionAndLabel.filter(new Function, Boolean>() { + @Override public Boolean call(Tuple2 pl) { + return !pl._1().equals(pl._2()); + } + }).count() / data.count(); + System.out.println("Training error: " + trainErr); + System.out.println("Learned classification tree model:\n" + model); + } else if (algo.equals("Regression")) { + // Train a GradientBoosting model for classification. 
+ final GradientBoostedTreesModel model = GradientBoostedTrees.train(data, boostingStrategy); + + // Evaluate model on training instances and compute training error + JavaPairRDD predictionAndLabel = + data.mapToPair(new PairFunction() { + @Override public Tuple2 call(LabeledPoint p) { + return new Tuple2(model.predict(p.features()), p.label()); + } + }); + Double trainMSE = + predictionAndLabel.map(new Function, Double>() { + @Override public Double call(Tuple2 pl) { + Double diff = pl._1() - pl._2(); + return diff * diff; + } + }).reduce(new Function2() { + @Override public Double call(Double a, Double b) { + return a + b; + } + }) / data.count(); + System.out.println("Training Mean Squared Error: " + trainMSE); + System.out.println("Learned regression tree model:\n" + model); + } else { + usage(); + } + + sc.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java index f796123a25727..e575eedeb465c 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java @@ -30,7 +30,7 @@ import org.apache.spark.mllib.linalg.Vectors; /** - * Example using MLLib KMeans from Java. + * Example using MLlib KMeans from Java. */ public final class JavaKMeans { diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java index 5622df5ce03ff..99df259b4e8e6 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java @@ -57,7 +57,7 @@ public class JavaCustomReceiver extends Receiver { public static void main(String[] args) { if (args.length < 2) { - System.err.println("Usage: JavaNetworkWordCount "); + System.err.println("Usage: JavaCustomReceiver "); System.exit(1); } @@ -70,7 +70,7 @@ public static void main(String[] args) { // Create a input stream with the custom receiver on target ip:port and count the // words in input stream of \n delimited text (eg. generated by 'nc') JavaReceiverInputDStream lines = ssc.receiverStream( - new JavaCustomReceiver(args[1], Integer.parseInt(args[2]))); + new JavaCustomReceiver(args[0], Integer.parseInt(args[1]))); JavaDStream words = lines.flatMap(new FlatMapFunction() { @Override public Iterable call(String x) { diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java index 45bcedebb4117..3e9f0f4b8f127 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaNetworkWordCount.java @@ -25,7 +25,7 @@ import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.api.java.StorageLevels; -import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; @@ -35,8 +35,9 @@ /** * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. 
+ * * Usage: JavaNetworkWordCount - * and describe the TCP server that Spark Streaming would connect to receive data. + * and describe the TCP server that Spark Streaming would connect to receive data. * * To run this on your local machine, you need to first run a Netcat server * `$ nc -lk 9999` @@ -56,7 +57,7 @@ public static void main(String[] args) { // Create the context with a 1 second batch size SparkConf sparkConf = new SparkConf().setAppName("JavaNetworkWordCount"); - JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, new Duration(1000)); + JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1)); // Create a JavaReceiverInputDStream on target ip:port and count the // words in input stream of \n delimited text (eg. generated by 'nc') diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java new file mode 100644 index 0000000000000..bceda97f058ea --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.Charset; +import java.util.Arrays; +import java.util.regex.Pattern; + +import scala.Tuple2; +import com.google.common.collect.Lists; +import com.google.common.io.Files; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.streaming.Durations; +import org.apache.spark.streaming.Time; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.apache.spark.streaming.api.java.JavaStreamingContextFactory; + +/** + * Counts words in text encoded with UTF8 received from the network every second. + * + * Usage: JavaRecoverableNetworkWordCount + * and describe the TCP server that Spark Streaming would connect to receive + * data. 
directory to HDFS-compatible file system which checkpoint data + * file to which the word counts will be appended + * + * and must be absolute paths + * + * To run this on your local machine, you need to first run a Netcat server + * + * `$ nc -lk 9999` + * + * and run the example as + * + * `$ ./bin/run-example org.apache.spark.examples.streaming.JavaRecoverableNetworkWordCount \ + * localhost 9999 ~/checkpoint/ ~/out` + * + * If the directory ~/checkpoint/ does not exist (e.g. running for the first time), it will create + * a new StreamingContext (will print "Creating new context" to the console). Otherwise, if + * checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from + * the checkpoint data. + * + * Refer to the online documentation for more details. + */ +public final class JavaRecoverableNetworkWordCount { + private static final Pattern SPACE = Pattern.compile(" "); + + private static JavaStreamingContext createContext(String ip, + int port, + String checkpointDirectory, + String outputPath) { + + // If you do not see this printed, that means the StreamingContext has been loaded + // from the new checkpoint + System.out.println("Creating new context"); + final File outputFile = new File(outputPath); + if (outputFile.exists()) { + outputFile.delete(); + } + SparkConf sparkConf = new SparkConf().setAppName("JavaRecoverableNetworkWordCount"); + // Create the context with a 1 second batch size + JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1)); + ssc.checkpoint(checkpointDirectory); + + // Create a socket stream on target ip:port and count the + // words in input stream of \n delimited text (eg. generated by 'nc') + JavaReceiverInputDStream lines = ssc.socketTextStream(ip, port); + JavaDStream words = lines.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String x) { + return Lists.newArrayList(SPACE.split(x)); + } + }); + JavaPairDStream wordCounts = words.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(String s) { + return new Tuple2(s, 1); + } + }).reduceByKey(new Function2() { + @Override + public Integer call(Integer i1, Integer i2) { + return i1 + i2; + } + }); + + wordCounts.foreachRDD(new Function2, Time, Void>() { + @Override + public Void call(JavaPairRDD rdd, Time time) throws IOException { + String counts = "Counts at time " + time + " " + rdd.collect(); + System.out.println(counts); + System.out.println("Appending to " + outputFile.getAbsolutePath()); + Files.append(counts + "\n", outputFile, Charset.defaultCharset()); + return null; + } + }); + + return ssc; + } + + public static void main(String[] args) { + if (args.length != 4) { + System.err.println("You arguments were " + Arrays.asList(args)); + System.err.println( + "Usage: JavaRecoverableNetworkWordCount \n" + + " . and describe the TCP server that Spark\n" + + " Streaming would connect to receive data. 
directory to\n" + + " HDFS-compatible file system which checkpoint data file to which\n" + + " the word counts will be appended\n" + + "\n" + + "In local mode, should be 'local[n]' with n > 1\n" + + "Both and must be absolute paths"); + System.exit(1); + } + + final String ip = args[0]; + final int port = Integer.parseInt(args[1]); + final String checkpointDirectory = args[2]; + final String outputPath = args[3]; + JavaStreamingContextFactory factory = new JavaStreamingContextFactory() { + @Override + public JavaStreamingContext create() { + return createContext(ip, port, checkpointDirectory, outputPath); + } + }; + JavaStreamingContext ssc = JavaStreamingContext.getOrCreate(checkpointDirectory, factory); + ssc.start(); + ssc.awaitTermination(); + } +} diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py new file mode 100644 index 0000000000000..540dae785f6ea --- /dev/null +++ b/examples/src/main/python/mllib/dataset_example.py @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +An example of how to use SchemaRDD as a dataset for ML. Run with:: + bin/spark-submit examples/src/main/python/mllib/dataset_example.py +""" + +import os +import sys +import tempfile +import shutil + +from pyspark import SparkContext +from pyspark.sql import SQLContext +from pyspark.mllib.util import MLUtils +from pyspark.mllib.stat import Statistics + + +def summarize(dataset): + print "schema: %s" % dataset.schema().json() + labels = dataset.map(lambda r: r.label) + print "label average: %f" % labels.mean() + features = dataset.map(lambda r: r.features) + summary = Statistics.colStats(features) + print "features average: %r" % summary.mean() + +if __name__ == "__main__": + if len(sys.argv) > 2: + print >> sys.stderr, "Usage: dataset_example.py " + exit(-1) + sc = SparkContext(appName="DatasetExample") + sqlCtx = SQLContext(sc) + if len(sys.argv) == 2: + input = sys.argv[1] + else: + input = "data/mllib/sample_libsvm_data.txt" + points = MLUtils.loadLibSVMFile(sc, input) + dataset0 = sqlCtx.inferSchema(points).setName("dataset0").cache() + summarize(dataset0) + tempdir = tempfile.NamedTemporaryFile(delete=False).name + os.unlink(tempdir) + print "Save dataset as a Parquet file to %s." % tempdir + dataset0.saveAsParquetFile(tempdir) + print "Load it back and summarize it again." 
+ dataset1 = sqlCtx.parquetFile(tempdir).setName("dataset1").cache() + summarize(dataset1) + shutil.rmtree(tempdir) diff --git a/examples/src/main/python/mllib/word2vec.py b/examples/src/main/python/mllib/word2vec.py new file mode 100644 index 0000000000000..99fef4276a369 --- /dev/null +++ b/examples/src/main/python/mllib/word2vec.py @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# This example uses text8 file from http://mattmahoney.net/dc/text8.zip +# The file was downloadded, unziped and split into multiple lines using +# +# wget http://mattmahoney.net/dc/text8.zip +# unzip text8.zip +# grep -o -E '\w+(\W+\w+){0,15}' text8 > text8_lines +# This was done so that the example can be run in local mode + + +import sys + +from pyspark import SparkContext +from pyspark.mllib.feature import Word2Vec + +USAGE = ("bin/spark-submit --driver-memory 4g " + "examples/src/main/python/mllib/word2vec.py text8_lines") + +if __name__ == "__main__": + if len(sys.argv) < 2: + print USAGE + sys.exit("Argument for file not provided") + file_path = sys.argv[1] + sc = SparkContext(appName='Word2Vec') + inp = sc.textFile(file_path).map(lambda row: row.split(" ")) + + word2vec = Word2Vec() + model = word2vec.fit(inp) + + synonyms = model.findSynonyms('china', 40) + + for word, cosine_distance in synonyms: + print "{}: {}".format(word, cosine_distance) + sc.stop() diff --git a/examples/src/main/python/pagerank.py b/examples/src/main/python/pagerank.py index b539c4128cdcc..a5f25d78c1146 100755 --- a/examples/src/main/python/pagerank.py +++ b/examples/src/main/python/pagerank.py @@ -15,6 +15,11 @@ # limitations under the License. # +""" +This is an example implementation of PageRank. For more conventional use, +Please refer to PageRank implementation provided by graphx +""" + import re import sys from operator import add @@ -40,6 +45,9 @@ def parseNeighbors(urls): print >> sys.stderr, "Usage: pagerank " exit(-1) + print >> sys.stderr, """WARN: This is a naive implementation of PageRank and is + given as an example! Please refer to PageRank implementation provided by graphx""" + # Initialize the spark context. sc = SparkContext(appName="PythonPageRank") diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py index eefa022f1927c..d2c5ca48c6cb8 100644 --- a/examples/src/main/python/sql.py +++ b/examples/src/main/python/sql.py @@ -48,7 +48,7 @@ # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. 
- path = os.environ['SPARK_HOME'] + "examples/src/main/resources/people.json" + path = os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json") # Create a SchemaRDD from the file(s) pointed to by path people = sqlContext.jsonFile(path) # root diff --git a/examples/src/main/python/streaming/hdfs_wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py new file mode 100644 index 0000000000000..f7ffb5379681e --- /dev/null +++ b/examples/src/main/python/streaming/hdfs_wordcount.py @@ -0,0 +1,49 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in new text files created in the given directory + Usage: hdfs_wordcount.py + is the directory that Spark Streaming will use to find and read new text files. + + To run this on your local machine on directory `localdir`, run this example + $ bin/spark-submit examples/src/main/python/streaming/hdfs_wordcount.py localdir + + Then create a text file in `localdir` and the words in the file will get counted. +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +if __name__ == "__main__": + if len(sys.argv) != 2: + print >> sys.stderr, "Usage: hdfs_wordcount.py " + exit(-1) + + sc = SparkContext(appName="PythonStreamingHDFSWordCount") + ssc = StreamingContext(sc, 1) + + lines = ssc.textFileStream(sys.argv[1]) + counts = lines.flatMap(lambda line: line.split(" "))\ + .map(lambda x: (x, 1))\ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/network_wordcount.py b/examples/src/main/python/streaming/network_wordcount.py new file mode 100644 index 0000000000000..cfa9c1ff5bfbc --- /dev/null +++ b/examples/src/main/python/streaming/network_wordcount.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the network every second. 
+ Usage: network_wordcount.py + and describe the TCP server that Spark Streaming would connect to receive data. + + To run this on your local machine, you need to first run a Netcat server + `$ nc -lk 9999` + and then run the example + `$ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: network_wordcount.py " + exit(-1) + sc = SparkContext(appName="PythonStreamingNetworkWordCount") + ssc = StreamingContext(sc, 1) + + lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2])) + counts = lines.flatMap(lambda line: line.split(" "))\ + .map(lambda word: (word, 1))\ + .reduceByKey(lambda a, b: a+b) + counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/recoverable_network_wordcount.py b/examples/src/main/python/streaming/recoverable_network_wordcount.py new file mode 100644 index 0000000000000..fc6827c82bf9b --- /dev/null +++ b/examples/src/main/python/streaming/recoverable_network_wordcount.py @@ -0,0 +1,80 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in text encoded with UTF8 received from the network every second. + + Usage: recoverable_network_wordcount.py + and describe the TCP server that Spark Streaming would connect to receive + data. directory to HDFS-compatible file system which checkpoint data + file to which the word counts will be appended + + To run this on your local machine, you need to first run a Netcat server + `$ nc -lk 9999` + + and then run the example + `$ bin/spark-submit examples/src/main/python/streaming/recoverable_network_wordcount.py \ + localhost 9999 ~/checkpoint/ ~/out` + + If the directory ~/checkpoint/ does not exist (e.g. running for the first time), it will create + a new StreamingContext (will print "Creating new context" to the console). Otherwise, if + checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from + the checkpoint data. +""" + +import os +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + + +def createContext(host, port, outputPath): + # If you do not see this printed, that means the StreamingContext has been loaded + # from the new checkpoint + print "Creating new context" + if os.path.exists(outputPath): + os.remove(outputPath) + sc = SparkContext(appName="PythonStreamingRecoverableNetworkWordCount") + ssc = StreamingContext(sc, 1) + + # Create a socket stream on target ip:port and count the + # words in input stream of \n delimited text (eg. 
generated by 'nc') + lines = ssc.socketTextStream(host, port) + words = lines.flatMap(lambda line: line.split(" ")) + wordCounts = words.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y) + + def echo(time, rdd): + counts = "Counts at time %s %s" % (time, rdd.collect()) + print counts + print "Appending to " + os.path.abspath(outputPath) + with open(outputPath, 'a') as f: + f.write(counts + "\n") + + wordCounts.foreachRDD(echo) + return ssc + +if __name__ == "__main__": + if len(sys.argv) != 5: + print >> sys.stderr, "Usage: recoverable_network_wordcount.py "\ + " " + exit(-1) + host, port, checkpoint, output = sys.argv[1:] + ssc = StreamingContext.getOrCreate(checkpoint, + lambda: createContext(host, int(port), output)) + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/python/streaming/stateful_network_wordcount.py b/examples/src/main/python/streaming/stateful_network_wordcount.py new file mode 100644 index 0000000000000..18a9a5a452ffb --- /dev/null +++ b/examples/src/main/python/streaming/stateful_network_wordcount.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" + Counts words in UTF8 encoded, '\n' delimited text received from the + network every second. + + Usage: stateful_network_wordcount.py + and describe the TCP server that Spark Streaming + would connect to receive data. 
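The recoverable_network_wordcount.py example above hinges on StreamingContext.getOrCreate: on restart the driver rebuilds the whole streaming job from checkpoint data instead of re-running the setup code. For reference, the same pattern in the Scala API looks roughly like the sketch below; the host, port and checkpoint directory are placeholder values, not arguments taken from this patch.

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.StreamingContext._

    object RecoverableNetworkWordCountSketch {
      // Placeholder value; the Python example reads this from the command line.
      val checkpointDir = "/tmp/streaming-checkpoint"

      def createContext(): StreamingContext = {
        // Runs only when no usable checkpoint exists yet.
        val conf = new SparkConf().setAppName("RecoverableNetworkWordCountSketch")
        val ssc = new StreamingContext(conf, Seconds(1))
        ssc.checkpoint(checkpointDir)
        val counts = ssc.socketTextStream("localhost", 9999)
          .flatMap(_.split(" "))
          .map((_, 1))
          .reduceByKey(_ + _)
        counts.print()
        ssc
      }

      def main(args: Array[String]): Unit = {
        // Restores the StreamingContext and its DStream lineage from the checkpoint if present,
        // otherwise calls createContext() to build everything from scratch.
        val ssc = StreamingContext.getOrCreate(checkpointDir, createContext _)
        ssc.start()
        ssc.awaitTermination()
      }
    }

Only code inside the create function is skipped on recovery, which is why the "Creating new context" message and the output-path handling in the Python version live inside createContext.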
+ + To run this on your local machine, you need to first run a Netcat server + `$ nc -lk 9999` + and then run the example + `$ bin/spark-submit examples/src/main/python/streaming/stateful_network_wordcount.py \ + localhost 9999` +""" + +import sys + +from pyspark import SparkContext +from pyspark.streaming import StreamingContext + +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, "Usage: stateful_network_wordcount.py " + exit(-1) + sc = SparkContext(appName="PythonStreamingStatefulNetworkWordCount") + ssc = StreamingContext(sc, 1) + ssc.checkpoint("checkpoint") + + def updateFunc(new_values, last_sum): + return sum(new_values) + (last_sum or 0) + + lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2])) + running_counts = lines.flatMap(lambda line: line.split(" "))\ + .map(lambda word: (word, 1))\ + .updateStateByKey(updateFunc) + + running_counts.pprint() + + ssc.start() + ssc.awaitTermination() diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala index 1f576319b3ca8..3d5259463003d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -17,11 +17,7 @@ package org.apache.spark.examples -import scala.math.sqrt - -import cern.colt.matrix._ -import cern.colt.matrix.linalg._ -import cern.jet.math._ +import org.apache.commons.math3.linear._ /** * Alternating least squares matrix factorization. @@ -30,84 +26,70 @@ import cern.jet.math._ * please refer to org.apache.spark.mllib.recommendation.ALS */ object LocalALS { + // Parameters set through command line arguments var M = 0 // Number of movies var U = 0 // Number of users var F = 0 // Number of features var ITERATIONS = 0 - val LAMBDA = 0.01 // Regularization coefficient - // Some COLT objects - val factory2D = DoubleFactory2D.dense - val factory1D = DoubleFactory1D.dense - val algebra = Algebra.DEFAULT - val blas = SeqBlas.seqBlas - - def generateR(): DoubleMatrix2D = { - val mh = factory2D.random(M, F) - val uh = factory2D.random(U, F) - algebra.mult(mh, algebra.transpose(uh)) + def generateR(): RealMatrix = { + val mh = randomMatrix(M, F) + val uh = randomMatrix(U, F) + mh.multiply(uh.transpose()) } - def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D], - us: Array[DoubleMatrix1D]): Double = - { - val r = factory2D.make(M, U) + def rmse(targetR: RealMatrix, ms: Array[RealVector], us: Array[RealVector]): Double = { + val r = new Array2DRowRealMatrix(M, U) for (i <- 0 until M; j <- 0 until U) { - r.set(i, j, blas.ddot(ms(i), us(j))) + r.setEntry(i, j, ms(i).dotProduct(us(j))) } - blas.daxpy(-1, targetR, r) - val sumSqs = r.aggregate(Functions.plus, Functions.square) - sqrt(sumSqs / (M * U)) + val diffs = r.subtract(targetR) + var sumSqs = 0.0 + for (i <- 0 until M; j <- 0 until U) { + val diff = diffs.getEntry(i, j) + sumSqs += diff * diff + } + math.sqrt(sumSqs / (M.toDouble * U.toDouble)) } - def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], - R: DoubleMatrix2D) : DoubleMatrix1D = - { - val XtX = factory2D.make(F, F) - val Xty = factory1D.make(F) + def updateMovie(i: Int, m: RealVector, us: Array[RealVector], R: RealMatrix) : RealVector = { + var XtX: RealMatrix = new Array2DRowRealMatrix(F, F) + var Xty: RealVector = new ArrayRealVector(F) // For each user that rated the movie for (j <- 0 until U) { val u = us(j) // Add u * u^t to XtX - blas.dger(1, u, u, XtX) + XtX = XtX.add(u.outerProduct(u)) 
// Add u * rating to Xty - blas.daxpy(R.get(i, j), u, Xty) + Xty = Xty.add(u.mapMultiply(R.getEntry(i, j))) } - // Add regularization coefs to diagonal terms + // Add regularization coefficients to diagonal terms for (d <- 0 until F) { - XtX.set(d, d, XtX.get(d, d) + LAMBDA * U) + XtX.addToEntry(d, d, LAMBDA * U) } // Solve it with Cholesky - val ch = new CholeskyDecomposition(XtX) - val Xty2D = factory2D.make(Xty.toArray, F) - val solved2D = ch.solve(Xty2D) - solved2D.viewColumn(0) + new CholeskyDecomposition(XtX).getSolver.solve(Xty) } - def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D], - R: DoubleMatrix2D) : DoubleMatrix1D = - { - val XtX = factory2D.make(F, F) - val Xty = factory1D.make(F) + def updateUser(j: Int, u: RealVector, ms: Array[RealVector], R: RealMatrix) : RealVector = { + var XtX: RealMatrix = new Array2DRowRealMatrix(F, F) + var Xty: RealVector = new ArrayRealVector(F) // For each movie that the user rated for (i <- 0 until M) { val m = ms(i) // Add m * m^t to XtX - blas.dger(1, m, m, XtX) + XtX = XtX.add(m.outerProduct(m)) // Add m * rating to Xty - blas.daxpy(R.get(i, j), m, Xty) + Xty = Xty.add(m.mapMultiply(R.getEntry(i, j))) } - // Add regularization coefs to diagonal terms + // Add regularization coefficients to diagonal terms for (d <- 0 until F) { - XtX.set(d, d, XtX.get(d, d) + LAMBDA * M) + XtX.addToEntry(d, d, LAMBDA * M) } // Solve it with Cholesky - val ch = new CholeskyDecomposition(XtX) - val Xty2D = factory2D.make(Xty.toArray, F) - val solved2D = ch.solve(Xty2D) - solved2D.viewColumn(0) + new CholeskyDecomposition(XtX).getSolver.solve(Xty) } def showWarning() { @@ -135,21 +117,28 @@ object LocalALS { showWarning() - printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS) + println(s"Running with M=$M, U=$U, F=$F, iters=$ITERATIONS") val R = generateR() // Initialize m and u randomly - var ms = Array.fill(M)(factory1D.random(F)) - var us = Array.fill(U)(factory1D.random(F)) + var ms = Array.fill(M)(randomVector(F)) + var us = Array.fill(U)(randomVector(F)) // Iteratively update movies then users for (iter <- 1 to ITERATIONS) { - println("Iteration " + iter + ":") + println(s"Iteration $iter:") ms = (0 until M).map(i => updateMovie(i, ms(i), us, R)).toArray us = (0 until U).map(j => updateUser(j, us(j), ms, R)).toArray println("RMSE = " + rmse(R, ms, us)) println() } } + + private def randomVector(n: Int): RealVector = + new ArrayRealVector(Array.fill(n)(math.random)) + + private def randomMatrix(rows: Int, cols: Int): RealMatrix = + new Array2DRowRealMatrix(Array.fill(rows, cols)(math.random)) + } diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala index 931faac5463c4..ac2ea35bbd0e0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala @@ -25,7 +25,8 @@ import breeze.linalg.{Vector, DenseVector} * Logistic regression based classification. * * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to org.apache.spark.mllib.classification.LogisticRegression + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. 
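The updated comments in these local examples point readers at org.apache.spark.mllib.classification.LogisticRegressionWithSGD and LogisticRegressionWithLBFGS rather than the hand-rolled gradient loop. For comparison, the library route looks roughly like the sketch below; the input path and the LogisticRegressionSketch wrapper are illustrative, not part of the patch.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
    import org.apache.spark.mllib.util.MLUtils

    object LogisticRegressionSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("LogisticRegressionSketch"))

        // Load labeled points in LIBSVM format and cache them for the iterative optimizer.
        val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").cache()

        // Train a binary model; the algorithm object also exposes an optimizer field for
        // tuning iteration count, convergence tolerance and regularization.
        val model = new LogisticRegressionWithLBFGS().setNumClasses(2).run(data)

        println(s"Learned weights: ${model.weights}, intercept: ${model.intercept}")
        sc.stop()
      }
    }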
*/ object LocalFileLR { val D = 10 // Numer of dimensions @@ -41,7 +42,8 @@ object LocalFileLR { def showWarning() { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! - |Please use the LogisticRegression method found in org.apache.spark.mllib.classification + |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS |for more conventional use. """.stripMargin) } diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index 2d75b9d2590f8..92a683ad57ea1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -25,7 +25,8 @@ import breeze.linalg.{Vector, DenseVector} * Logistic regression based classification. * * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to org.apache.spark.mllib.classification.LogisticRegression + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ object LocalLR { val N = 10000 // Number of data points @@ -48,7 +49,8 @@ object LocalLR { def showWarning() { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! - |Please use the LogisticRegression method found in org.apache.spark.mllib.classification + |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS |for more conventional use. 
""".stripMargin) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index fde8ffeedf8b4..6c0ac8013ce34 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -17,11 +17,7 @@ package org.apache.spark.examples -import scala.math.sqrt - -import cern.colt.matrix._ -import cern.colt.matrix.linalg._ -import cern.jet.math._ +import org.apache.commons.math3.linear._ import org.apache.spark._ @@ -32,62 +28,53 @@ import org.apache.spark._ * please refer to org.apache.spark.mllib.recommendation.ALS */ object SparkALS { + // Parameters set through command line arguments var M = 0 // Number of movies var U = 0 // Number of users var F = 0 // Number of features var ITERATIONS = 0 - val LAMBDA = 0.01 // Regularization coefficient - // Some COLT objects - val factory2D = DoubleFactory2D.dense - val factory1D = DoubleFactory1D.dense - val algebra = Algebra.DEFAULT - val blas = SeqBlas.seqBlas - - def generateR(): DoubleMatrix2D = { - val mh = factory2D.random(M, F) - val uh = factory2D.random(U, F) - algebra.mult(mh, algebra.transpose(uh)) + def generateR(): RealMatrix = { + val mh = randomMatrix(M, F) + val uh = randomMatrix(U, F) + mh.multiply(uh.transpose()) } - def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D], - us: Array[DoubleMatrix1D]): Double = - { - val r = factory2D.make(M, U) + def rmse(targetR: RealMatrix, ms: Array[RealVector], us: Array[RealVector]): Double = { + val r = new Array2DRowRealMatrix(M, U) for (i <- 0 until M; j <- 0 until U) { - r.set(i, j, blas.ddot(ms(i), us(j))) + r.setEntry(i, j, ms(i).dotProduct(us(j))) } - blas.daxpy(-1, targetR, r) - val sumSqs = r.aggregate(Functions.plus, Functions.square) - sqrt(sumSqs / (M * U)) + val diffs = r.subtract(targetR) + var sumSqs = 0.0 + for (i <- 0 until M; j <- 0 until U) { + val diff = diffs.getEntry(i, j) + sumSqs += diff * diff + } + math.sqrt(sumSqs / (M.toDouble * U.toDouble)) } - def update(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D], - R: DoubleMatrix2D) : DoubleMatrix1D = - { + def update(i: Int, m: RealVector, us: Array[RealVector], R: RealMatrix) : RealVector = { val U = us.size - val F = us(0).size - val XtX = factory2D.make(F, F) - val Xty = factory1D.make(F) + val F = us(0).getDimension + var XtX: RealMatrix = new Array2DRowRealMatrix(F, F) + var Xty: RealVector = new ArrayRealVector(F) // For each user that rated the movie for (j <- 0 until U) { val u = us(j) // Add u * u^t to XtX - blas.dger(1, u, u, XtX) + XtX = XtX.add(u.outerProduct(u)) // Add u * rating to Xty - blas.daxpy(R.get(i, j), u, Xty) + Xty = Xty.add(u.mapMultiply(R.getEntry(i, j))) } // Add regularization coefs to diagonal terms for (d <- 0 until F) { - XtX.set(d, d, XtX.get(d, d) + LAMBDA * U) + XtX.addToEntry(d, d, LAMBDA * U) } // Solve it with Cholesky - val ch = new CholeskyDecomposition(XtX) - val Xty2D = factory2D.make(Xty.toArray, F) - val solved2D = ch.solve(Xty2D) - solved2D.viewColumn(0) + new CholeskyDecomposition(XtX).getSolver.solve(Xty) } def showWarning() { @@ -118,7 +105,7 @@ object SparkALS { showWarning() - printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS) + println(s"Running with M=$M, U=$U, F=$F, iters=$ITERATIONS") val sparkConf = new SparkConf().setAppName("SparkALS") val sc = new SparkContext(sparkConf) @@ -126,21 +113,21 @@ object SparkALS { val R = generateR() // Initialize m and u randomly - var ms = 
Array.fill(M)(factory1D.random(F)) - var us = Array.fill(U)(factory1D.random(F)) + var ms = Array.fill(M)(randomVector(F)) + var us = Array.fill(U)(randomVector(F)) // Iteratively update movies then users val Rc = sc.broadcast(R) var msb = sc.broadcast(ms) var usb = sc.broadcast(us) for (iter <- 1 to ITERATIONS) { - println("Iteration " + iter + ":") + println(s"Iteration $iter:") ms = sc.parallelize(0 until M, slices) .map(i => update(i, msb.value(i), usb.value, Rc.value)) .collect() msb = sc.broadcast(ms) // Re-broadcast ms because it was updated us = sc.parallelize(0 until U, slices) - .map(i => update(i, usb.value(i), msb.value, algebra.transpose(Rc.value))) + .map(i => update(i, usb.value(i), msb.value, Rc.value.transpose())) .collect() usb = sc.broadcast(us) // Re-broadcast us because it was updated println("RMSE = " + rmse(R, ms, us)) @@ -149,4 +136,11 @@ object SparkALS { sc.stop() } + + private def randomVector(n: Int): RealVector = + new ArrayRealVector(Array.fill(n)(math.random)) + + private def randomMatrix(rows: Int, cols: Int): RealMatrix = + new Array2DRowRealMatrix(Array.fill(rows, cols)(math.random)) + } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index 3258510894372..9099c2fcc90b3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -32,7 +32,8 @@ import org.apache.spark.scheduler.InputFormatInfo * Logistic regression based classification. * * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to org.apache.spark.mllib.classification.LogisticRegression + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ object SparkHdfsLR { val D = 10 // Numer of dimensions @@ -54,7 +55,8 @@ object SparkHdfsLR { def showWarning() { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! - |Please use the LogisticRegression method found in org.apache.spark.mllib.classification + |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS |for more conventional use. """.stripMargin) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index fc23308fc4adf..257a7d29f922a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -30,7 +30,8 @@ import org.apache.spark._ * Usage: SparkLR [slices] * * This is an example implementation for learning how to use Spark. For more conventional use, - * please refer to org.apache.spark.mllib.classification.LogisticRegression + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ object SparkLR { val N = 10000 // Number of data points @@ -53,7 +54,8 @@ object SparkLR { def showWarning() { System.err.println( """WARN: This is a naive implementation of Logistic Regression and is given as an example! 
- |Please use the LogisticRegression method found in org.apache.spark.mllib.classification + |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS |for more conventional use. """.stripMargin) } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index 4c7e006da0618..8d092b6506d33 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -28,13 +28,28 @@ import org.apache.spark.{SparkConf, SparkContext} * URL neighbor URL * ... * where URL and their neighbors are separated by space(s). + * + * This is an example implementation for learning how to use Spark. For more conventional use, + * please refer to org.apache.spark.graphx.lib.PageRank */ object SparkPageRank { + + def showWarning() { + System.err.println( + """WARN: This is a naive implementation of PageRank and is given as an example! + |Please use the PageRank implementation found in org.apache.spark.graphx.lib.PageRank + |for more conventional use. + """.stripMargin) + } + def main(args: Array[String]) { if (args.length < 1) { System.err.println("Usage: SparkPageRank ") System.exit(1) } + + showWarning() + val sparkConf = new SparkConf().setAppName("PageRank") val iters = if (args.length > 0) args(1).toInt else 10 val ctx = new SparkContext(sparkConf) diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala index 96d13612e46dd..4393b99e636b6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala @@ -32,11 +32,24 @@ import org.apache.spark.storage.StorageLevel /** * Logistic regression based classification. * This example uses Tachyon to persist rdds during computation. + * + * This is an example implementation for learning how to use Spark. For more conventional use, + * please refer to either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + * org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS based on your needs. */ object SparkTachyonHdfsLR { val D = 10 // Numer of dimensions val rand = new Random(42) + def showWarning() { + System.err.println( + """WARN: This is a naive implementation of Logistic Regression and is given as an example! + |Please use either org.apache.spark.mllib.classification.LogisticRegressionWithSGD or + |org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS + |for more conventional use. 
+ """.stripMargin) + } + case class DataPoint(x: Vector[Double], y: Double) def parsePoint(line: String): DataPoint = { @@ -51,6 +64,9 @@ object SparkTachyonHdfsLR { } def main(args: Array[String]) { + + showWarning() + val inputPath = args(0) val sparkConf = new SparkConf().setAppName("SparkTachyonHdfsLR") val conf = new Configuration() diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala index e06f4dcd54442..e322d4ce5a745 100644 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala @@ -18,17 +18,7 @@ package org.apache.spark.examples.bagel import org.apache.spark._ -import org.apache.spark.SparkContext._ -import org.apache.spark.serializer.KryoRegistrator - import org.apache.spark.bagel._ -import org.apache.spark.bagel.Bagel._ - -import scala.collection.mutable.ArrayBuffer - -import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream} - -import com.esotericsoftware.kryo._ class PageRankUtils extends Serializable { def computeWithCombiner(numVertices: Long, epsilon: Double)( @@ -99,13 +89,6 @@ class PRMessage() extends Message[String] with Serializable { } } -class PRKryoRegistrator extends KryoRegistrator { - def registerClasses(kryo: Kryo) { - kryo.register(classOf[PRVertex]) - kryo.register(classOf[PRMessage]) - } -} - class CustomPartitioner(partitions: Int) extends Partitioner { def numPartitions = partitions diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala index e4db3ec51313d..859abedf2a55e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala @@ -38,8 +38,7 @@ object WikipediaPageRank { } val sparkConf = new SparkConf() sparkConf.setAppName("WikipediaPageRank") - sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - sparkConf.set("spark.kryo.registrator", classOf[PRKryoRegistrator].getName) + sparkConf.registerKryoClasses(Array(classOf[PRVertex], classOf[PRMessage])) val inputFile = args(0) val threshold = args(1).toDouble diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index c4317a6aec798..828cffb01ca1e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -46,28 +46,15 @@ object Analytics extends Logging { } val options = mutable.Map(optionsList: _*) - def pickPartitioner(v: String): PartitionStrategy = { - // TODO: Use reflection rather than listing all the partitioning strategies here. 
- v match { - case "RandomVertexCut" => RandomVertexCut - case "EdgePartition1D" => EdgePartition1D - case "EdgePartition2D" => EdgePartition2D - case "CanonicalRandomVertexCut" => CanonicalRandomVertexCut - case _ => throw new IllegalArgumentException("Invalid PartitionStrategy: " + v) - } - } - - val conf = new SparkConf() - .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator") - .set("spark.locality.wait", "100000") + val conf = new SparkConf().set("spark.locality.wait", "100000") + GraphXUtils.registerKryoClasses(conf) val numEPart = options.remove("numEPart").map(_.toInt).getOrElse { println("Set the number of edge partitions using --numEPart.") sys.exit(1) } val partitionStrategy: Option[PartitionStrategy] = options.remove("partStrategy") - .map(pickPartitioner(_)) + .map(PartitionStrategy.fromString(_)) val edgeStorageLevel = options.remove("edgeStorageLevel") .map(StorageLevel.fromString(_)).getOrElse(StorageLevel.MEMORY_ONLY) val vertexStorageLevel = options.remove("vertexStorageLevel") @@ -90,7 +77,7 @@ object Analytics extends Logging { val sc = new SparkContext(conf.setAppName("PageRank(" + fname + ")")) val unpartitionedGraph = GraphLoader.edgeListFile(sc, fname, - minEdgePartitions = numEPart, + numEdgePartitions = numEPart, edgeStorageLevel = edgeStorageLevel, vertexStorageLevel = vertexStorageLevel).cache() val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_)) @@ -107,7 +94,7 @@ object Analytics extends Logging { if (!outFname.isEmpty) { logWarning("Saving pageranks of pages to " + outFname) - pr.map{case (id, r) => id + "\t" + r}.saveAsTextFile(outFname) + pr.map { case (id, r) => id + "\t" + r }.saveAsTextFile(outFname) } sc.stop() @@ -123,13 +110,13 @@ object Analytics extends Logging { val sc = new SparkContext(conf.setAppName("ConnectedComponents(" + fname + ")")) val unpartitionedGraph = GraphLoader.edgeListFile(sc, fname, - minEdgePartitions = numEPart, + numEdgePartitions = numEPart, edgeStorageLevel = edgeStorageLevel, vertexStorageLevel = vertexStorageLevel).cache() val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_)) val cc = ConnectedComponents.run(graph) - println("Components: " + cc.vertices.map{ case (vid,data) => data}.distinct()) + println("Components: " + cc.vertices.map { case (vid, data) => data }.distinct()) sc.stop() case "triangles" => @@ -144,10 +131,10 @@ object Analytics extends Logging { val sc = new SparkContext(conf.setAppName("TriangleCount(" + fname + ")")) val graph = GraphLoader.edgeListFile(sc, fname, canonicalOrientation = true, - minEdgePartitions = numEPart, + numEdgePartitions = numEPart, edgeStorageLevel = edgeStorageLevel, vertexStorageLevel = vertexStorageLevel) - // TriangleCount requires the graph to be partitioned + // TriangleCount requires the graph to be partitioned .partitionBy(partitionStrategy.getOrElse(RandomVertexCut)).cache() val triangles = TriangleCount.run(graph) println("Triangles: " + triangles.vertices.map { diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala index 5f35a5836462e..3ec20d594b784 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala @@ -18,7 +18,7 @@ package org.apache.spark.examples.graphx import 
org.apache.spark.SparkContext._ -import org.apache.spark.graphx.PartitionStrategy +import org.apache.spark.graphx.{GraphXUtils, PartitionStrategy} import org.apache.spark.{SparkContext, SparkConf} import org.apache.spark.graphx.util.GraphGenerators import java.io.{PrintWriter, FileOutputStream} @@ -67,7 +67,7 @@ object SynthBenchmark { options.foreach { case ("app", v) => app = v - case ("niter", v) => niter = v.toInt + case ("niters", v) => niter = v.toInt case ("nverts", v) => numVertices = v.toInt case ("numEPart", v) => numEPart = Some(v.toInt) case ("partStrategy", v) => partitionStrategy = Some(PartitionStrategy.fromString(v)) @@ -80,8 +80,7 @@ object SynthBenchmark { val conf = new SparkConf() .setAppName(s"GraphX Synth Benchmark (nverts = $numVertices, app = $app)") - .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator") + GraphXUtils.registerKryoClasses(conf) val sc = new SparkContext(conf) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala new file mode 100644 index 0000000000000..ee7897d9062d9 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +import scala.beans.BeanInfo + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.feature.{HashingTF, Tokenizer} +import org.apache.spark.sql.SQLContext + +@BeanInfo +case class LabeledDocument(id: Long, text: String, label: Double) + +@BeanInfo +case class Document(id: Long, text: String) + +/** + * A simple text classification pipeline that recognizes "spark" from input text. This is to show + * how to create and configure an ML pipeline. Run with + * {{{ + * bin/run-example ml.SimpleTextClassificationPipeline + * }}} + */ +object SimpleTextClassificationPipeline { + + def main(args: Array[String]) { + val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext._ + + // Prepare training documents, which are labeled. + val training = sparkContext.parallelize(Seq( + LabeledDocument(0L, "a b c d e spark", 1.0), + LabeledDocument(1L, "b d", 0.0), + LabeledDocument(2L, "spark f g h", 1.0), + LabeledDocument(3L, "hadoop mapreduce", 0.0))) + + // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. 
+ val tokenizer = new Tokenizer() + .setInputCol("text") + .setOutputCol("words") + val hashingTF = new HashingTF() + .setNumFeatures(1000) + .setInputCol(tokenizer.getOutputCol) + .setOutputCol("features") + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(0.01) + val pipeline = new Pipeline() + .setStages(Array(tokenizer, hashingTF, lr)) + + // Fit the pipeline to training documents. + val model = pipeline.fit(training) + + // Prepare test documents, which are unlabeled. + val test = sparkContext.parallelize(Seq( + Document(4L, "spark i j k"), + Document(5L, "l m n"), + Document(6L, "mapreduce spark"), + Document(7L, "apache hadoop"))) + + // Make predictions on test documents. + model.transform(test) + .select('id, 'text, 'score, 'prediction) + .collect() + .foreach(println) + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala new file mode 100644 index 0000000000000..ae6057758d6fc --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AbstractParams.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scala.reflect.runtime.universe._ + +/** + * Abstract class for parameter case classes. + * This overrides the [[toString]] method to print all case class fields by name and value. + * @tparam T Concrete parameter class. + */ +abstract class AbstractParams[T: TypeTag] { + + private def tag: TypeTag[T] = typeTag[T] + + /** + * Finds all case class fields in concrete class instance, and outputs them in JSON-style format: + * { + * [field name]:\t[field value]\n + * [field name]:\t[field value]\n + * ... 
+ * } + */ + override def toString: String = { + val tpe = tag.tpe + val allAccessors = tpe.declarations.collect { + case m: MethodSymbol if m.isCaseAccessor => m + } + val mirror = runtimeMirror(getClass.getClassLoader) + val instanceMirror = mirror.reflect(this) + allAccessors.map { f => + val paramName = f.name.toString + val fieldMirror = instanceMirror.reflectField(f) + val paramValue = fieldMirror.get + s" $paramName:\t$paramValue" + }.mkString("{\n", ",\n", "\n}") + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index a6f78d2441db1..a113653810b93 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -55,7 +55,7 @@ object BinaryClassification { stepSize: Double = 1.0, algorithm: Algorithm = LR, regType: RegType = L2, - regParam: Double = 0.1) + regParam: Double = 0.01) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala index d6b2fe430e5a4..e49129c4e7844 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala @@ -35,6 +35,7 @@ import org.apache.spark.{SparkConf, SparkContext} object Correlations { case class Params(input: String = "data/mllib/sample_linear_regression_data.txt") + extends AbstractParams[Params] def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala new file mode 100644 index 0000000000000..cb1abbd18fd4d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scopt.OptionParser + +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, RowMatrix} +import org.apache.spark.{SparkConf, SparkContext} + +/** + * Compute the similar columns of a matrix, using cosine similarity. + * + * The input matrix must be stored in row-oriented dense format, one line per row with its entries + * separated by space. For example, + * {{{ + * 0.5 1.0 + * 2.0 3.0 + * 4.0 5.0 + * }}} + * represents a 3-by-2 matrix, whose first row is (0.5, 1.0). 
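Stepping back to the AbstractParams base class introduced above: its reflective toString is what lets the Params case classes throughout this patch (Correlations, CosineSimilarity, DatasetExample, DecisionTreeRunner, and so on) print themselves without per-class boilerplate. A rough illustration with a made-up DemoParams class, assuming it lives in the same package so AbstractParams resolves directly:

    package org.apache.spark.examples.mllib

    // Hypothetical parameter class; any case class extending AbstractParams behaves the same way.
    case class DemoParams(input: String = "data/demo.txt", k: Int = 3, threshold: Double = 0.1)
      extends AbstractParams[DemoParams]

    object DemoParamsApp {
      def main(args: Array[String]): Unit = {
        // The case accessors are walked via runtime reflection, so this prints roughly:
        // {
        //   input:      data/demo.txt,
        //   k:          3,
        //   threshold:  0.1
        // }
        println(DemoParams())
      }
    }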
+ * + * Example invocation: + * + * bin/run-example mllib.CosineSimilarity \ + * --threshold 0.1 data/mllib/sample_svm_data.txt + */ +object CosineSimilarity { + case class Params(inputFile: String = null, threshold: Double = 0.1) + extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("CosineSimilarity") { + head("CosineSimilarity: an example app.") + opt[Double]("threshold") + .required() + .text(s"threshold similarity: to tradeoff computation vs quality estimate") + .action((x, c) => c.copy(threshold = x)) + arg[String]("") + .required() + .text(s"input file, one row per line, space-separated") + .action((x, c) => c.copy(inputFile = x)) + note( + """ + |For example, the following command runs this app on a dataset: + | + | ./bin/spark-submit --class org.apache.spark.examples.mllib.CosineSimilarity \ + | examplesjar.jar \ + | --threshold 0.1 data/mllib/sample_svm_data.txt + """.stripMargin) + } + + parser.parse(args, defaultParams).map { params => + run(params) + } getOrElse { + System.exit(1) + } + } + + def run(params: Params) { + val conf = new SparkConf().setAppName("CosineSimilarity") + val sc = new SparkContext(conf) + + // Load and parse the data file. + val rows = sc.textFile(params.inputFile).map { line => + val values = line.split(' ').map(_.toDouble) + Vectors.dense(values) + }.cache() + val mat = new RowMatrix(rows) + + // Compute similar columns perfectly, with brute force. + val exact = mat.columnSimilarities() + + // Compute similar columns with estimation using DIMSUM + val approx = mat.columnSimilarities(params.threshold) + + val exactEntries = exact.entries.map { case MatrixEntry(i, j, u) => ((i, j), u) } + val approxEntries = approx.entries.map { case MatrixEntry(i, j, v) => ((i, j), v) } + val MAE = exactEntries.leftOuterJoin(approxEntries).values.map { + case (u, Some(v)) => + math.abs(u - v) + case (u, None) => + math.abs(u) + }.mean() + + println(s"Average absolute error in estimate is: $MAE") + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala new file mode 100644 index 0000000000000..f8d83f4ec7327 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.mllib + +import java.io.File + +import com.google.common.io.Files +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{Row, SQLContext, SchemaRDD} + +/** + * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with + * {{{ + * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object DatasetExample { + + case class Params( + input: String = "data/mllib/sample_libsvm_data.txt", + dataFormat: String = "libsvm") extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("DatasetExample") { + head("Dataset: an example app using SchemaRDD as a Dataset for ML.") + opt[String]("input") + .text(s"input path to dataset") + .action((x, c) => c.copy(input = x)) + opt[String]("dataFormat") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(input = x)) + checkConfig { params => + success + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"DatasetExample with $params") + val sc = new SparkContext(conf) + val sqlContext = new SQLContext(sc) + import sqlContext._ // for implicit conversions + + // Load input data + val origData: RDD[LabeledPoint] = params.dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, params.input) + case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input) + } + println(s"Loaded ${origData.count()} instances from file: ${params.input}") + + // Convert input data to SchemaRDD explicitly. + val schemaRDD: SchemaRDD = origData + println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}") + println(s"Converted to SchemaRDD with ${schemaRDD.count()} records") + + // Select columns, using implicit conversion to SchemaRDD. 
+ val labelsSchemaRDD: SchemaRDD = origData.select('label) + val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v } + val numLabels = labels.count() + val meanLabel = labels.fold(0.0)(_ + _) / numLabels + println(s"Selected label column with average value $meanLabel") + + val featuresSchemaRDD: SchemaRDD = origData.select('features) + val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v } + val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( + (summary, feat) => summary.add(feat), + (sum1, sum2) => sum1.merge(sum2)) + println(s"Selected features column with average values:\n ${featureSummary.mean.toString}") + + val tmpDir = Files.createTempDir() + tmpDir.deleteOnExit() + val outputDir = new File(tmpDir, "dataset").toString + println(s"Saving to $outputDir as Parquet file.") + schemaRDD.saveAsParquetFile(outputDir) + + println(s"Loading Parquet file with UDT from $outputDir.") + val newDataset = sqlContext.parquetFile(outputDir) + + println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") + val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v } + val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( + (summary, feat) => summary.add(feat), + (sum1, sum2) => sum1.merge(sum2)) + println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}") + + sc.stop() + } + +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 4adc91d2fbe65..98f9d1689c8e7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -22,11 +22,11 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity} +import org.apache.spark.mllib.tree.{DecisionTree, RandomForest, impurity} import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils @@ -62,7 +62,10 @@ object DecisionTreeRunner { minInfoGain: Double = 0.0, numTrees: Int = 1, featureSubsetStrategy: String = "auto", - fracTest: Double = 0.2) + fracTest: Double = 0.2, + useNodeIdCache: Boolean = false, + checkpointDir: Option[String] = None, + checkpointInterval: Int = 10) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() @@ -102,6 +105,21 @@ object DecisionTreeRunner { .text(s"fraction of data to hold out for testing. If given option testInput, " + s"this option is ignored. 
default: ${defaultParams.fracTest}") .action((x, c) => c.copy(fracTest = x)) + opt[Boolean]("useNodeIdCache") + .text(s"whether to use node Id cache during training, " + + s"default: ${defaultParams.useNodeIdCache}") + .action((x, c) => c.copy(useNodeIdCache = x)) + opt[String]("checkpointDir") + .text(s"checkpoint directory where intermediate node Id caches will be stored, " + + s"default: ${defaultParams.checkpointDir match { + case Some(strVal) => strVal + case None => "None" + }}") + .action((x, c) => c.copy(checkpointDir = Some(x))) + opt[Int]("checkpointInterval") + .text(s"how often to checkpoint the node Id cache, " + + s"default: ${defaultParams.checkpointInterval}") + .action((x, c) => c.copy(checkpointInterval = x)) opt[String]("testInput") .text(s"input path to test dataset. If given, option fracTest is ignored." + s" default: ${defaultParams.testInput}") @@ -136,18 +154,30 @@ object DecisionTreeRunner { } } - def run(params: Params) { - - val conf = new SparkConf().setAppName("DecisionTreeRunner") - val sc = new SparkContext(conf) - + /** + * Load training and test data from files. + * @param input Path to input dataset. + * @param dataFormat "libsvm" or "dense" + * @param testInput Path to test dataset. + * @param algo Classification or Regression + * @param fracTest Fraction of input data to hold out for testing. Ignored if testInput given. + * @return (training dataset, test dataset, number of classes), + * where the number of classes is inferred from data (and set to 0 for Regression) + */ + private[mllib] def loadDatasets( + sc: SparkContext, + input: String, + dataFormat: String, + testInput: String, + algo: Algo, + fracTest: Double): (RDD[LabeledPoint], RDD[LabeledPoint], Int) = { // Load training data and cache it. - val origExamples = params.dataFormat match { - case "dense" => MLUtils.loadLabeledPoints(sc, params.input).cache() - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input).cache() + val origExamples = dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, input).cache() + case "libsvm" => MLUtils.loadLibSVMFile(sc, input).cache() } // For classification, re-index classes if needed. - val (examples, classIndexMap, numClasses) = params.algo match { + val (examples, classIndexMap, numClasses) = algo match { case Classification => { // classCounts: class --> # examples in class val classCounts = origExamples.map(_.label).countByValue() @@ -185,13 +215,14 @@ object DecisionTreeRunner { } // Create training, test sets. - val splits = if (params.testInput != "") { + val splits = if (testInput != "") { // Load testInput. - val origTestExamples = params.dataFormat match { - case "dense" => MLUtils.loadLabeledPoints(sc, params.testInput) - case "libsvm" => MLUtils.loadLibSVMFile(sc, params.testInput) + val numFeatures = examples.take(1)(0).features.size + val origTestExamples = dataFormat match { + case "dense" => MLUtils.loadLabeledPoints(sc, testInput) + case "libsvm" => MLUtils.loadLibSVMFile(sc, testInput, numFeatures) } - params.algo match { + algo match { case Classification => { // classCounts: class --> # examples in class val testExamples = { @@ -208,17 +239,31 @@ object DecisionTreeRunner { } } else { // Split input into training, test. 
- examples.randomSplit(Array(1.0 - params.fracTest, params.fracTest)) + examples.randomSplit(Array(1.0 - fracTest, fracTest)) } val training = splits(0).cache() val test = splits(1).cache() + val numTraining = training.count() val numTest = test.count() - println(s"numTraining = $numTraining, numTest = $numTest.") examples.unpersist(blocking = false) + (training, test, numClasses) + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"DecisionTreeRunner with $params") + val sc = new SparkContext(conf) + + println(s"DecisionTreeRunner with parameters:\n$params") + + // Load training and test data and cache it. + val (training, test, numClasses) = loadDatasets(sc, params.input, params.dataFormat, + params.testInput, params.algo, params.fracTest) + val impurityCalculator = params.impurity match { case Gini => impurity.Gini case Entropy => impurity.Entropy @@ -233,9 +278,15 @@ object DecisionTreeRunner { maxBins = params.maxBins, numClassesForClassification = numClasses, minInstancesPerNode = params.minInstancesPerNode, - minInfoGain = params.minInfoGain) + minInfoGain = params.minInfoGain, + useNodeIdCache = params.useNodeIdCache, + checkpointDir = params.checkpointDir, + checkpointInterval = params.checkpointInterval) if (params.numTrees == 1) { + val startTime = System.nanoTime() val model = DecisionTree.train(training, strategy) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") if (model.numNodes < 20) { println(model.toDebugString) // Print full model. } else { @@ -259,8 +310,11 @@ object DecisionTreeRunner { } else { val randomSeed = Utils.random.nextInt() if (params.algo == Classification) { + val startTime = System.nanoTime() val model = RandomForest.trainClassifier(training, strategy, params.numTrees, params.featureSubsetStrategy, randomSeed) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") if (model.totalNumNodes < 30) { println(model.toDebugString) // Print full model. } else { @@ -275,8 +329,11 @@ object DecisionTreeRunner { println(s"Test accuracy = $testAccuracy") } if (params.algo == Regression) { + val startTime = System.nanoTime() val model = RandomForest.trainRegressor(training, strategy, params.numTrees, params.featureSubsetStrategy, randomSeed) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") if (model.totalNumNodes < 30) { println(model.toDebugString) // Print full model. } else { @@ -295,19 +352,11 @@ object DecisionTreeRunner { /** * Calculates the mean squared error for regression. */ - private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { - data.map { y => - val err = tree.predict(y.features) - y.label - err * err - }.mean() - } - - /** - * Calculates the mean squared error for regression. 
- */ - private def meanSquaredError(tree: RandomForestModel, data: RDD[LabeledPoint]): Double = { + private[mllib] def meanSquaredError( + model: { def predict(features: Vector): Double }, + data: RDD[LabeledPoint]): Double = { data.map { y => - val err = tree.predict(y.features) - y.label + val err = model.predict(y.features) - y.label err * err }.mean() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala index 89dfa26c2299c..11e35598baf50 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala @@ -44,7 +44,7 @@ object DenseKMeans { input: String = null, k: Int = -1, numIterations: Int = 10, - initializationMode: InitializationMode = Parallel) + initializationMode: InitializationMode = Parallel) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala new file mode 100644 index 0000000000000..1def8b45a230c --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import scopt.OptionParser + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.evaluation.MulticlassMetrics +import org.apache.spark.mllib.tree.GradientBoostedTrees +import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo} +import org.apache.spark.util.Utils + +/** + * An example runner for Gradient Boosting using decision trees as weak learners. Run with + * {{{ + * ./bin/run-example mllib.GradientBoostedTreesRunner [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + * + * Note: This script treats all features as real-valued (not categorical). + * To include categorical features, modify categoricalFeaturesInfo. 
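One detail of the DecisionTreeRunner change above: meanSquaredError now takes its model as the structural type { def predict(features: Vector): Double }, so a single helper covers the decision-tree, random-forest and (below) gradient-boosted models without requiring a shared trait. A self-contained sketch of that helper, following the pattern in the patch:

    import scala.language.reflectiveCalls

    import org.apache.spark.SparkContext._
    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.rdd.RDD

    object MeanSquaredErrorSketch {
      // Any object exposing a matching predict method can be passed in; the call is dispatched
      // reflectively, which is a minor cost next to the RDD pass itself.
      def meanSquaredError(
          model: { def predict(features: Vector): Double },
          data: RDD[LabeledPoint]): Double = {
        data.map { point =>
          val err = model.predict(point.features) - point.label
          err * err
        }.mean()
      }
    }

GradientBoostedTreesRunner below reuses exactly this helper via DecisionTreeRunner.meanSquaredError.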
+ */ +object GradientBoostedTreesRunner { + + case class Params( + input: String = null, + testInput: String = "", + dataFormat: String = "libsvm", + algo: String = "Classification", + maxDepth: Int = 5, + numIterations: Int = 10, + fracTest: Double = 0.2) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("GradientBoostedTrees") { + head("GradientBoostedTrees: an example decision tree app.") + opt[String]("algo") + .text(s"algorithm (${Algo.values.mkString(",")}), default: ${defaultParams.algo}") + .action((x, c) => c.copy(algo = x)) + opt[Int]("maxDepth") + .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") + .action((x, c) => c.copy(maxDepth = x)) + opt[Int]("numIterations") + .text(s"number of iterations of boosting," + s" default: ${defaultParams.numIterations}") + .action((x, c) => c.copy(numIterations = x)) + opt[Double]("fracTest") + .text(s"fraction of data to hold out for testing. If given option testInput, " + + s"this option is ignored. default: ${defaultParams.fracTest}") + .action((x, c) => c.copy(fracTest = x)) + opt[String]("testInput") + .text(s"input path to test dataset. If given, option fracTest is ignored." + + s" default: ${defaultParams.testInput}") + .action((x, c) => c.copy(testInput = x)) + opt[String]("") + .text("data format: libsvm (default), dense (deprecated in Spark v1.1)") + .action((x, c) => c.copy(dataFormat = x)) + arg[String]("") + .text("input path to labeled examples") + .required() + .action((x, c) => c.copy(input = x)) + checkConfig { params => + if (params.fracTest < 0 || params.fracTest > 1) { + failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].") + } else { + success + } + } + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + sys.exit(1) + } + } + + def run(params: Params) { + + val conf = new SparkConf().setAppName(s"GradientBoostedTreesRunner with $params") + val sc = new SparkContext(conf) + + println(s"GradientBoostedTreesRunner with parameters:\n$params") + + // Load training and test data and cache it. + val (training, test, numClasses) = DecisionTreeRunner.loadDatasets(sc, params.input, + params.dataFormat, params.testInput, Algo.withName(params.algo), params.fracTest) + + val boostingStrategy = BoostingStrategy.defaultParams(params.algo) + boostingStrategy.treeStrategy.numClassesForClassification = numClasses + boostingStrategy.numIterations = params.numIterations + boostingStrategy.treeStrategy.maxDepth = params.maxDepth + + val randomSeed = Utils.random.nextInt() + if (params.algo == "Classification") { + val startTime = System.nanoTime() + val model = GradientBoostedTrees.train(training, boostingStrategy) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + if (model.totalNumNodes < 30) { + println(model.toDebugString) // Print full model. + } else { + println(model) // Print model summary. 
+ } + val trainAccuracy = + new MulticlassMetrics(training.map(lp => (model.predict(lp.features), lp.label))) + .precision + println(s"Train accuracy = $trainAccuracy") + val testAccuracy = + new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision + println(s"Test accuracy = $testAccuracy") + } else if (params.algo == "Regression") { + val startTime = System.nanoTime() + val model = GradientBoostedTrees.train(training, boostingStrategy) + val elapsedTime = (System.nanoTime() - startTime) / 1e9 + println(s"Training time: $elapsedTime seconds") + if (model.totalNumNodes < 30) { + println(model.toDebugString) // Print full model. + } else { + println(model) // Print model summary. + } + val trainMSE = DecisionTreeRunner.meanSquaredError(model, training) + println(s"Train mean squared error = $trainMSE") + val testMSE = DecisionTreeRunner.meanSquaredError(model, test) + println(s"Test mean squared error = $testMSE") + } + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala index 05b7d66f8dffd..6a456ba7ec07b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.optimization.{SimpleUpdater, SquaredL2Updater, L1U * A synthetic dataset can be found at `data/mllib/sample_linear_regression_data.txt`. * If you use it as a template to create your own app, please use `spark-submit` to submit your app. */ -object LinearRegression extends App { +object LinearRegression { object RegType extends Enumeration { type RegType = Value @@ -47,42 +47,44 @@ object LinearRegression extends App { numIterations: Int = 100, stepSize: Double = 1.0, regType: RegType = L2, - regParam: Double = 0.1) - - val defaultParams = Params() - - val parser = new OptionParser[Params]("LinearRegression") { - head("LinearRegression: an example app for linear regression.") - opt[Int]("numIterations") - .text("number of iterations") - .action((x, c) => c.copy(numIterations = x)) - opt[Double]("stepSize") - .text(s"initial step size, default: ${defaultParams.stepSize}") - .action((x, c) => c.copy(stepSize = x)) - opt[String]("regType") - .text(s"regularization type (${RegType.values.mkString(",")}), " + - s"default: ${defaultParams.regType}") - .action((x, c) => c.copy(regType = RegType.withName(x))) - opt[Double]("regParam") - .text(s"regularization parameter, default: ${defaultParams.regParam}") - arg[String]("") - .required() - .text("input paths to labeled examples in LIBSVM format") - .action((x, c) => c.copy(input = x)) - note( - """ - |For example, the following command runs this app on a synthetic dataset: - | - | bin/spark-submit --class org.apache.spark.examples.mllib.LinearRegression \ - | examples/target/scala-*/spark-examples-*.jar \ - | data/mllib/sample_linear_regression_data.txt - """.stripMargin) - } + regParam: Double = 0.01) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("LinearRegression") { + head("LinearRegression: an example app for linear regression.") + opt[Int]("numIterations") + .text("number of iterations") + .action((x, c) => c.copy(numIterations = x)) + opt[Double]("stepSize") + .text(s"initial step size, default: ${defaultParams.stepSize}") + .action((x, c) => c.copy(stepSize = 
x)) + opt[String]("regType") + .text(s"regularization type (${RegType.values.mkString(",")}), " + + s"default: ${defaultParams.regType}") + .action((x, c) => c.copy(regType = RegType.withName(x))) + opt[Double]("regParam") + .text(s"regularization parameter, default: ${defaultParams.regParam}") + arg[String]("") + .required() + .text("input paths to labeled examples in LIBSVM format") + .action((x, c) => c.copy(input = x)) + note( + """ + |For example, the following command runs this app on a synthetic dataset: + | + | bin/spark-submit --class org.apache.spark.examples.mllib.LinearRegression \ + | examples/target/scala-*/spark-examples-*.jar \ + | data/mllib/sample_linear_regression_data.txt + """.stripMargin) + } - parser.parse(args, defaultParams).map { params => - run(params) - } getOrElse { - sys.exit(1) + parser.parse(args, defaultParams).map { params => + run(params) + } getOrElse { + sys.exit(1) + } } def run(params: Params) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala index 98aaedb9d7dc9..91a0a860d6c71 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala @@ -19,7 +19,6 @@ package org.apache.spark.examples.mllib import scala.collection.mutable -import com.esotericsoftware.kryo.Kryo import org.apache.log4j.{Level, Logger} import scopt.OptionParser @@ -27,7 +26,6 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ import org.apache.spark.mllib.recommendation.{ALS, MatrixFactorizationModel, Rating} import org.apache.spark.rdd.RDD -import org.apache.spark.serializer.{KryoSerializer, KryoRegistrator} /** * An example app for ALS on MovieLens data (http://grouplens.org/datasets/movielens/). 
@@ -40,13 +38,6 @@ import org.apache.spark.serializer.{KryoSerializer, KryoRegistrator} */ object MovieLensALS { - class ALSRegistrator extends KryoRegistrator { - override def registerClasses(kryo: Kryo) { - kryo.register(classOf[Rating]) - kryo.register(classOf[mutable.BitSet]) - } - } - case class Params( input: String = null, kryo: Boolean = false, @@ -55,7 +46,7 @@ object MovieLensALS { rank: Int = 10, numUserBlocks: Int = -1, numProductBlocks: Int = -1, - implicitPrefs: Boolean = false) + implicitPrefs: Boolean = false) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() @@ -108,17 +99,18 @@ object MovieLensALS { def run(params: Params) { val conf = new SparkConf().setAppName(s"MovieLensALS with $params") if (params.kryo) { - conf.set("spark.serializer", classOf[KryoSerializer].getName) - .set("spark.kryo.registrator", classOf[ALSRegistrator].getName) + conf.registerKryoClasses(Array(classOf[mutable.BitSet], classOf[Rating])) .set("spark.kryoserializer.buffer.mb", "8") } val sc = new SparkContext(conf) Logger.getRootLogger.setLevel(Level.WARN) + val implicitPrefs = params.implicitPrefs + val ratings = sc.textFile(params.input).map { line => val fields = line.split("::") - if (params.implicitPrefs) { + if (implicitPrefs) { /* * MovieLens ratings are on a scale of 1-5: * 5: Must see diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala index 4532512c01f84..6e4e2d07f284b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala @@ -36,6 +36,7 @@ import org.apache.spark.{SparkConf, SparkContext} object MultivariateSummarizer { case class Params(input: String = "data/mllib/sample_linear_regression_data.txt") + extends AbstractParams[Params] def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala index f01b8266e3fe3..663c12734af68 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala @@ -33,6 +33,7 @@ import org.apache.spark.SparkContext._ object SampledRDDs { case class Params(input: String = "data/mllib/sample_binary_classification_data.txt") + extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala index 952fa2a5109a4..f1ff4e6911f5e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SparseNaiveBayes.scala @@ -37,7 +37,7 @@ object SparseNaiveBayes { input: String = null, minPartitions: Int = 0, numFeatures: Int = -1, - lambda: Double = 1.0) + lambda: Double = 1.0) extends AbstractParams[Params] def main(args: Array[String]) { val defaultParams = Params() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala new file mode 100644 index 0000000000000..33e5760aed997 --- /dev/null +++ 
b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.clustering.StreamingKMeans +import org.apache.spark.SparkConf +import org.apache.spark.streaming.{Seconds, StreamingContext} + +/** + * Estimate clusters on one stream of data and make predictions + * on another stream, where the data streams arrive as text files + * into two different directories. + * + * The rows of the training text files must be vector data in the form + * `[x1,x2,x3,...,xn]` + * Where n is the number of dimensions. + * + * The rows of the test text files must be labeled data in the form + * `(y,[x1,x2,x3,...,xn])` + * Where y is some identifier. n must be the same for train and test. + * + * Usage: StreamingKmeans + * + * To run on your local machine using the two directories `trainingDir` and `testDir`, + * with updates every 5 seconds, 2 dimensions per data point, and 3 clusters, call: + * $ bin/run-example \ + * org.apache.spark.examples.mllib.StreamingKMeans trainingDir testDir 5 3 2 + * + * As you add text files to `trainingDir` the clusters will continuously update. + * Anytime you add text files to `testDir`, you'll see predicted labels using the current model. 
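For concreteness, two illustrative rows in the formats described above, parsed with the same calls the example's main method uses below (the literal values here are made up):

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint

    val trainingRow = Vectors.parse("[1.0,2.0]")       // a bare 2-dimensional vector, for trainingDir
    val testRow = LabeledPoint.parse("(7,[1.5,2.5])")  // for testDir; the label 7 is only an identifier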
+ * + */ +object StreamingKMeans { + + def main(args: Array[String]) { + if (args.length != 5) { + System.err.println( + "Usage: StreamingKMeans " + + " ") + System.exit(1) + } + + val conf = new SparkConf().setMaster("local").setAppName("StreamingLinearRegression") + val ssc = new StreamingContext(conf, Seconds(args(2).toLong)) + + val trainingData = ssc.textFileStream(args(0)).map(Vectors.parse) + val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse) + + val model = new StreamingKMeans() + .setK(args(3).toInt) + .setDecayFactor(1.0) + .setRandomCenters(args(4).toInt, 0.0) + + model.trainOn(trainingData) + model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print() + + ssc.start() + ssc.awaitTermination() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index e26f213e8afa8..227acc117502d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -27,15 +27,16 @@ object HiveFromSpark { def main(args: Array[String]) { val sparkConf = new SparkConf().setAppName("HiveFromSpark") val sc = new SparkContext(sparkConf) + val path = s"${System.getenv("SPARK_HOME")}/examples/src/main/resources/kv1.txt" - // A local hive context creates an instance of the Hive Metastore in process, storing the + // A local hive context creates an instance of the Hive Metastore in process, storing // the warehouse data in the current directory. This location can be overridden by // specifying a second parameter to the constructor. val hiveContext = new HiveContext(sc) import hiveContext._ sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - sql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src") + sql(s"LOAD DATA LOCAL INPATH '$path' INTO TABLE src") // Queries are expressed in HiveQL println("Result of 'SELECT *': ") diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 6af3a0f33efc2..19427e629f76d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -31,15 +31,13 @@ import org.apache.spark.util.IntParam /** * Counts words in text encoded with UTF8 received from the network every second. * - * Usage: NetworkWordCount + * Usage: RecoverableNetworkWordCount * and describe the TCP server that Spark Streaming would connect to receive * data. directory to HDFS-compatible file system which checkpoint data * file to which the word counts will be appended * - * In local mode, should be 'local[n]' with n > 1 * and must be absolute paths * - * * To run this on your local machine, you need to first run a Netcat server * * `$ nc -lk 9999` @@ -54,22 +52,11 @@ import org.apache.spark.util.IntParam * checkpoint data exists in ~/checkpoint/, then it will create StreamingContext from * the checkpoint data. 
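The recovery behaviour described above comes from StreamingContext.getOrCreate: it rebuilds the context from checkpoint data when present and only otherwise invokes the factory. A condensed sketch of the pattern (batch interval and paths are illustrative, and the DStream wiring is elided):

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}

    def createContext(checkpointDirectory: String): StreamingContext = {
      val conf = new SparkConf().setAppName("RecoverableNetworkWordCount")
      val ssc = new StreamingContext(conf, Seconds(1))
      ssc.checkpoint(checkpointDirectory)  // set inside the factory, as this patch now does
      // ... define the socket stream and word counts here ...
      ssc
    }

    val checkpointDirectory = "/tmp/checkpoint"  // illustrative
    val ssc = StreamingContext.getOrCreate(checkpointDirectory,
      () => createContext(checkpointDirectory))
    ssc.start()
    ssc.awaitTermination()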
* - * To run this example in a local standalone cluster with automatic driver recovery, - * - * `$ bin/spark-class org.apache.spark.deploy.Client -s launch \ - * \ - * org.apache.spark.examples.streaming.RecoverableNetworkWordCount \ - * localhost 9999 ~/checkpoint ~/out` - * - * would typically be - * /examples/target/scala-XX/spark-examples....jar - * * Refer to the online documentation for more details. */ - object RecoverableNetworkWordCount { - def createContext(ip: String, port: Int, outputPath: String) = { + def createContext(ip: String, port: Int, outputPath: String, checkpointDirectory: String) = { // If you do not see this printed, that means the StreamingContext has been loaded // from the new checkpoint @@ -79,6 +66,7 @@ object RecoverableNetworkWordCount { val sparkConf = new SparkConf().setAppName("RecoverableNetworkWordCount") // Create the context with a 1 second batch size val ssc = new StreamingContext(sparkConf, Seconds(1)) + ssc.checkpoint(checkpointDirectory) // Create a socket stream on target ip:port and count the // words in input stream of \n delimited text (eg. generated by 'nc') @@ -114,7 +102,7 @@ object RecoverableNetworkWordCount { val Array(ip, IntParam(port), checkpointDirectory, outputPath) = args val ssc = StreamingContext.getOrCreate(checkpointDirectory, () => { - createContext(ip, port, outputPath) + createContext(ip, port, outputPath, checkpointDirectory) }) ssc.start() ssc.awaitTermination() diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index a4d159bf38377..ed186ea5650c4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -18,12 +18,13 @@ package org.apache.spark.examples.streaming import org.apache.spark.SparkConf +import org.apache.spark.HashPartitioner import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ /** * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every - * second. + * second starting with initial value of word count. * Usage: StatefulNetworkWordCount * and describe the TCP server that Spark Streaming would connect to receive * data. @@ -51,12 +52,19 @@ object StatefulNetworkWordCount { Some(currentCount + previousCount) } + val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) => { + iterator.flatMap(t => updateFunc(t._2, t._3).map(s => (t._1, s))) + } + val sparkConf = new SparkConf().setAppName("StatefulNetworkWordCount") // Create the context with a 1 second batch size val ssc = new StreamingContext(sparkConf, Seconds(1)) ssc.checkpoint(".") - // Create a NetworkInputDStream on target ip:port and count the + // Initial RDD input to updateStateByKey + val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1))) + + // Create a ReceiverInputDStream on target ip:port and count the // words in input stream of \n delimited test (eg. 
generated by 'nc') val lines = ssc.socketTextStream(args(0), args(1).toInt) val words = lines.flatMap(_.split(" ")) @@ -64,7 +72,8 @@ object StatefulNetworkWordCount { // Update the cumulative count using updateStateByKey // This will give a Dstream made of state (which is the cumulative count of the words) - val stateDstream = wordDstream.updateStateByKey[Int](updateFunc) + val stateDstream = wordDstream.updateStateByKey[Int](newUpdateFunc, + new HashPartitioner (ssc.sparkContext.defaultParallelism), true, initialRDD) stateDstream.print() ssc.start() ssc.awaitTermination() diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index d9b886eff77cc..55226c0a6df60 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -50,7 +50,7 @@ object PageViewStream { val ssc = new StreamingContext("local[2]", "PageViewStream", Seconds(1), System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass).toSeq) - // Create a NetworkInputDStream on target host:port and convert each line to a PageView + // Create a ReceiverInputDStream on target host:port and convert each line to a PageView val pageViews = ssc.socketTextStream(host, port) .flatMap(_.split("\n")) .map(PageView.fromString(_)) diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index ac291bd4fde20..72618b6515f83 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml diff --git a/external/flume-sink/src/test/resources/log4j.properties b/external/flume-sink/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..4411d6e20c52a --- /dev/null +++ b/external/flume-sink/src/test/resources/log4j.properties @@ -0,0 +1,29 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
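Spelled out, the StatefulNetworkWordCount change above uses the four-argument overload of updateStateByKey: a per-partition update function, an explicit partitioner, rememberPartitioner = true, and an initial state RDD that seeds the counts before the first batch. A condensed sketch, assuming ssc and wordDstream are set up as in that example:

    import org.apache.spark.HashPartitioner

    // State that exists before any data arrives.
    val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1)))

    val updateFunc = (values: Seq[Int], state: Option[Int]) =>
      Some(values.sum + state.getOrElse(0))

    // Per-partition form expected by this overload:
    // (key, new values, old state) triples in, (key, new state) pairs out.
    val newUpdateFunc = (iterator: Iterator[(String, Seq[Int], Option[Int])]) =>
      iterator.flatMap { case (word, values, state) => updateFunc(values, state).map(s => (word, s)) }

    val stateDstream = wordDstream.updateStateByKey[Int](
      newUpdateFunc, new HashPartitioner(ssc.sparkContext.defaultParallelism), true, initialRDD)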
+# + +# Set everything to be logged to the file streaming/target/unit-tests.log +log4j.rootCategory=INFO, file +# log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=false +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.eclipse.jetty=WARN + diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index a2b2cc6149d95..650b2fbe1c142 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -159,6 +159,7 @@ class SparkSinkSuite extends FunSuite { channelContext.put("transactionCapacity", 1000.toString) channelContext.put("keep-alive", 0.toString) channelContext.putAll(overrides) + channel.setName(scala.util.Random.nextString(10)) channel.configure(channelContext) val sink = new SparkSink() diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 7d31e32283d88..a682f0e8471d8 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml @@ -39,19 +39,13 @@ org.apache.spark spark-streaming_${scala.binary.version} ${project.version} + provided
    org.apache.spark spark-streaming-flume-sink_${scala.binary.version} ${project.version} - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - test-jar - test - org.apache.flume flume-ng-sdk diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 4b2ea45fb81d0..2de2a7926bfd1 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -66,7 +66,7 @@ class SparkFlumeEvent() extends Externalizable { var event : AvroFlumeEvent = new AvroFlumeEvent() /* De-serialize from bytes. */ - def readExternal(in: ObjectInput) { + def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { val bodyLength = in.readInt() val bodyBuff = new Array[Byte](bodyLength) in.readFully(bodyBuff) @@ -93,7 +93,7 @@ class SparkFlumeEvent() extends Externalizable { } /* Serialize to bytes. */ - def writeExternal(out: ObjectOutput) { + def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { val body = event.getBody.array() out.writeInt(body.length) out.write(body) diff --git a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java new file mode 100644 index 0000000000000..6e1f01900071b --- /dev/null +++ b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming; + +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.junit.After; +import org.junit.Before; + +public abstract class LocalJavaStreamingContext { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } +} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala b/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala new file mode 100644 index 0000000000000..1a900007b696b --- /dev/null +++ b/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
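Several Externalizable implementations touched by this patch (SparkFlumeEvent above, among others) now route readExternal/writeExternal through Utils.tryOrIOException. The rough idea, as a sketch rather than Spark's exact implementation: these methods are declared to throw only IOException (plus ClassNotFoundException on the read side), so any other failure during (de)serialization is rewrapped instead of escaping as an unrelated error:

    import java.io.IOException
    import scala.util.control.NonFatal

    def tryOrIOException(block: => Unit): Unit = {
      try {
        block
      } catch {
        case e: IOException => throw e                 // already the declared exception type
        case NonFatal(t) => throw new IOException(t)   // wrap anything else
      }
    }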
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import java.io.{IOException, ObjectInputStream} + +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.dstream.{DStream, ForEachDStream} +import org.apache.spark.util.Utils + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + +/** + * This is a output stream just for the testsuites. All the output is collected into a + * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. + * + * The buffer contains a sequence of RDD's, each containing a sequence of items + */ +class TestOutputStream[T: ClassTag](parent: DStream[T], + val output: ArrayBuffer[Seq[T]] = ArrayBuffer[Seq[T]]()) + extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { + val collected = rdd.collect() + output += collected + }) { + + // This is to clear the output buffer every it is read from a checkpoint + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { + ois.defaultReadObject() + output.clear() + } +} diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index 32a19787a28e1..b57a1c71e35b9 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -20,9 +20,6 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress import java.util.concurrent.{Callable, ExecutorCompletionService, Executors} -import java.util.Random - -import org.apache.spark.TestUtils import scala.collection.JavaConversions._ import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} @@ -32,20 +29,35 @@ import org.apache.flume.channel.MemoryChannel import org.apache.flume.conf.Configurables import org.apache.flume.event.EventBuilder +import org.scalatest.{BeforeAndAfter, FunSuite} + +import org.apache.spark.{SparkConf, Logging} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream import org.apache.spark.streaming.util.ManualClock -import org.apache.spark.streaming.{TestSuiteBase, TestOutputStream, StreamingContext} +import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} import org.apache.spark.streaming.flume.sink._ import org.apache.spark.util.Utils -class FlumePollingStreamSuite extends TestSuiteBase { +class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging { val batchCount = 5 val eventsPerBatch = 100 val totalEventsPerChannel = batchCount * eventsPerBatch val channelCapacity = 5000 val maxAttempts = 5 + val batchDuration = Seconds(1) + + val conf = new SparkConf() + 
.setMaster("local[2]") + .setAppName(this.getClass.getSimpleName) + + def beforeFunction() { + logInfo("Using manual clock") + conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") + } + + before(beforeFunction()) test("flume polling test") { testMultipleTimes(testFlumePolling) @@ -145,11 +157,16 @@ class FlumePollingStreamSuite extends TestSuiteBase { outputStream.register() ssc.start() - writeAndVerify(Seq(channel, channel2), ssc, outputBuffer) - assertChannelIsEmpty(channel) - assertChannelIsEmpty(channel2) - sink.stop() - channel.stop() + try { + writeAndVerify(Seq(channel, channel2), ssc, outputBuffer) + assertChannelIsEmpty(channel) + assertChannelIsEmpty(channel2) + } finally { + sink.stop() + sink2.stop() + channel.stop() + channel2.stop() + } } def writeAndVerify(channels: Seq[MemoryChannel], ssc: StreamingContext, @@ -224,4 +241,5 @@ class FlumePollingStreamSuite extends TestSuiteBase { null } } + } diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 33235d150b4a5..13943ed5442b9 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -17,103 +17,141 @@ package org.apache.spark.streaming.flume -import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} - -import java.net.InetSocketAddress +import java.net.{InetSocketAddress, ServerSocket} import java.nio.ByteBuffer import java.nio.charset.Charset +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} +import scala.concurrent.duration._ +import scala.language.postfixOps + import org.apache.avro.ipc.NettyTransceiver import org.apache.avro.ipc.specific.SpecificRequestor +import org.apache.flume.source.avro import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol} +import org.jboss.netty.channel.ChannelPipeline +import org.jboss.netty.channel.socket.SocketChannel +import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory +import org.jboss.netty.handler.codec.compression._ +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} +import org.scalatest.concurrent.Eventually._ +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{TestOutputStream, StreamingContext, TestSuiteBase} -import org.apache.spark.streaming.util.ManualClock +import org.apache.spark.streaming.{Milliseconds, StreamingContext, TestOutputStream} +import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerReceiverStarted} import org.apache.spark.util.Utils -import org.jboss.netty.channel.ChannelPipeline -import org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory -import org.jboss.netty.channel.socket.SocketChannel -import org.jboss.netty.handler.codec.compression._ +class FlumeStreamSuite extends FunSuite with BeforeAndAfter with Matchers with Logging { + val conf = new SparkConf().setMaster("local[4]").setAppName("FlumeStreamSuite") + + var ssc: StreamingContext = null + var transceiver: NettyTransceiver = null -class FlumeStreamSuite extends TestSuiteBase { + after { + if (ssc != null) { + ssc.stop() + } + if (transceiver != null) { + transceiver.close() + } + } test("flume input stream") { - runFlumeStreamTest(false) + 
testFlumeStream(testCompression = false) } test("flume input compressed stream") { - runFlumeStreamTest(true) + testFlumeStream(testCompression = true) + } + + /** Run test on flume stream */ + private def testFlumeStream(testCompression: Boolean): Unit = { + val input = (1 to 100).map { _.toString } + val testPort = findFreePort() + val outputBuffer = startContext(testPort, testCompression) + writeAndVerify(input, testPort, outputBuffer, testCompression) + } + + /** Find a free port */ + private def findFreePort(): Int = { + Utils.startServiceOnPort(23456, (trialPort: Int) => { + val socket = new ServerSocket(trialPort) + socket.close() + (null, trialPort) + })._2 } - - def runFlumeStreamTest(enableDecompression: Boolean) { - // Set up the streaming context and input streams - val ssc = new StreamingContext(conf, batchDuration) - val (flumeStream, testPort) = - Utils.startServiceOnPort(9997, (trialPort: Int) => { - val dstream = FlumeUtils.createStream( - ssc, "localhost", trialPort, StorageLevel.MEMORY_AND_DISK, enableDecompression) - (dstream, trialPort) - }) + /** Setup and start the streaming context */ + private def startContext( + testPort: Int, testCompression: Boolean): (ArrayBuffer[Seq[SparkFlumeEvent]]) = { + ssc = new StreamingContext(conf, Milliseconds(200)) + val flumeStream = FlumeUtils.createStream( + ssc, "localhost", testPort, StorageLevel.MEMORY_AND_DISK, testCompression) val outputBuffer = new ArrayBuffer[Seq[SparkFlumeEvent]] with SynchronizedBuffer[Seq[SparkFlumeEvent]] val outputStream = new TestOutputStream(flumeStream, outputBuffer) outputStream.register() ssc.start() + outputBuffer + } - val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val input = Seq(1, 2, 3, 4, 5) - Thread.sleep(1000) - val transceiver = new NettyTransceiver(new InetSocketAddress("localhost", testPort)) - var client: AvroSourceProtocol = null - - if (enableDecompression) { - client = SpecificRequestor.getClient( - classOf[AvroSourceProtocol], - new NettyTransceiver(new InetSocketAddress("localhost", testPort), - new CompressionChannelFactory(6))) - } else { - client = SpecificRequestor.getClient( - classOf[AvroSourceProtocol], transceiver) - } + /** Send data to the flume receiver and verify whether the data was received */ + private def writeAndVerify( + input: Seq[String], + testPort: Int, + outputBuffer: ArrayBuffer[Seq[SparkFlumeEvent]], + enableCompression: Boolean + ) { + val testAddress = new InetSocketAddress("localhost", testPort) - for (i <- 0 until input.size) { + val inputEvents = input.map { item => val event = new AvroFlumeEvent - event.setBody(ByteBuffer.wrap(input(i).toString.getBytes("utf-8"))) + event.setBody(ByteBuffer.wrap(item.getBytes("UTF-8"))) event.setHeaders(Map[CharSequence, CharSequence]("test" -> "header")) - client.append(event) - Thread.sleep(500) - clock.addToTime(batchDuration.milliseconds) + event } - Thread.sleep(1000) - - val startTime = System.currentTimeMillis() - while (outputBuffer.size < input.size && System.currentTimeMillis() - startTime < maxWaitTimeMillis) { - logInfo("output.size = " + outputBuffer.size + ", input.size = " + input.size) - Thread.sleep(100) + eventually(timeout(10 seconds), interval(100 milliseconds)) { + // if last attempted transceiver had succeeded, close it + if (transceiver != null) { + transceiver.close() + transceiver = null + } + + // Create transceiver + transceiver = { + if (enableCompression) { + new NettyTransceiver(testAddress, new CompressionChannelFactory(6)) + } else { + new NettyTransceiver(testAddress) + } 
+ } + + // Create Avro client with the transceiver + val client = SpecificRequestor.getClient(classOf[AvroSourceProtocol], transceiver) + client should not be null + + // Send data + val status = client.appendBatch(inputEvents.toList) + status should be (avro.Status.OK) } - Thread.sleep(1000) - val timeTaken = System.currentTimeMillis() - startTime - assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") - logInfo("Stopping context") - ssc.stop() - - val decoder = Charset.forName("UTF-8").newDecoder() - - assert(outputBuffer.size === input.length) - for (i <- 0 until outputBuffer.size) { - assert(outputBuffer(i).size === 1) - val str = decoder.decode(outputBuffer(i).head.event.getBody) - assert(str.toString === input(i).toString) - assert(outputBuffer(i).head.event.getHeaders.get("test") === "header") + + val decoder = Charset.forName("UTF-8").newDecoder() + eventually(timeout(10 seconds), interval(100 milliseconds)) { + val outputEvents = outputBuffer.flatten.map { _.event } + outputEvents.foreach { + event => + event.getHeaders.get("test") should be("header") + } + val output = outputEvents.map(event => decoder.decode(event.getBody()).toString) + output should be (input) } } - class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { + /** Class to create socket channel with compression */ + private class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { override def newChannel(pipeline: ChannelPipeline): SocketChannel = { val encoder = new ZlibEncoder(compressionLevel) pipeline.addFirst("deflater", encoder) diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 2067c473f0e3f..b3f44471cd326 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml @@ -39,13 +39,7 @@ org.apache.spark spark-streaming_${scala.binary.version} ${project.version} - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - test-jar - test + provided org.apache.kafka diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index e20e2c8f26991..4d26b640e8d74 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -17,23 +17,21 @@ package org.apache.spark.streaming.kafka +import java.util.Properties + import scala.collection.Map import scala.reflect.{classTag, ClassTag} -import java.util.Properties -import java.util.concurrent.Executors - -import kafka.consumer._ +import kafka.consumer.{KafkaStream, Consumer, ConsumerConfig, ConsumerConnector} import kafka.serializer.Decoder import kafka.utils.VerifiableProperties -import kafka.utils.ZKStringSerializer -import org.I0Itec.zkclient._ import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.util.Utils /** * Input stream that pulls messages from a Kafka Broker. 
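The rewritten Flume suites replace sleep-and-poll loops with ScalaTest's Eventually, as used above: the assertion block is re-evaluated on an interval until it passes or the timeout expires, and the last failure is rethrown if it never does. A minimal sketch, with outputBuffer and input assumed from the test:

    import org.scalatest.concurrent.Eventually._
    import scala.concurrent.duration._
    import scala.language.postfixOps

    eventually(timeout(10 seconds), interval(100 milliseconds)) {
      // Retried every 100 ms for up to 10 seconds.
      assert(outputBuffer.flatten.size == input.size)
    }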
@@ -53,12 +51,16 @@ class KafkaInputDStream[ @transient ssc_ : StreamingContext, kafkaParams: Map[String, String], topics: Map[String, Int], + useReliableReceiver: Boolean, storageLevel: StorageLevel ) extends ReceiverInputDStream[(K, V)](ssc_) with Logging { def getReceiver(): Receiver[(K, V)] = { - new KafkaReceiver[K, V, U, T](kafkaParams, topics, storageLevel) - .asInstanceOf[Receiver[(K, V)]] + if (!useReliableReceiver) { + new KafkaReceiver[K, V, U, T](kafkaParams, topics, storageLevel) + } else { + new ReliableKafkaReceiver[K, V, U, T](kafkaParams, topics, storageLevel) + } } } @@ -71,14 +73,15 @@ class KafkaReceiver[ kafkaParams: Map[String, String], topics: Map[String, Int], storageLevel: StorageLevel - ) extends Receiver[Any](storageLevel) with Logging { + ) extends Receiver[(K, V)](storageLevel) with Logging { // Connection to Kafka - var consumerConnector : ConsumerConnector = null + var consumerConnector: ConsumerConnector = null def onStop() { if (consumerConnector != null) { consumerConnector.shutdown() + consumerConnector = null } } @@ -97,12 +100,6 @@ class KafkaReceiver[ consumerConnector = Consumer.create(consumerConfig) logInfo("Connected to " + zkConnect) - // When auto.offset.reset is defined, it is our responsibility to try and whack the - // consumer group zk node. - if (kafkaParams.contains("auto.offset.reset")) { - tryZookeeperConsumerGroupCleanup(zkConnect, kafkaParams("group.id")) - } - val keyDecoder = classTag[U].runtimeClass.getConstructor(classOf[VerifiableProperties]) .newInstance(consumerConfig.props) .asInstanceOf[Decoder[K]] @@ -110,11 +107,11 @@ class KafkaReceiver[ .newInstance(consumerConfig.props) .asInstanceOf[Decoder[V]] - // Create Threads for each Topic/Message Stream we are listening + // Create threads for each topic/message Stream we are listening val topicMessageStreams = consumerConnector.createMessageStreams( topics, keyDecoder, valueDecoder) - val executorPool = Executors.newFixedThreadPool(topics.values.sum) + val executorPool = Utils.newDaemonFixedThreadPool(topics.values.sum, "KafkaMessageHandler") try { // Start the messages handler for each partition topicMessageStreams.values.foreach { streams => @@ -125,13 +122,15 @@ class KafkaReceiver[ } } - // Handles Kafka Messages - private class MessageHandler[K: ClassTag, V: ClassTag](stream: KafkaStream[K, V]) + // Handles Kafka messages + private class MessageHandler(stream: KafkaStream[K, V]) extends Runnable { def run() { logInfo("Starting MessageHandler.") try { - for (msgAndMetadata <- stream) { + val streamIterator = stream.iterator() + while (streamIterator.hasNext()) { + val msgAndMetadata = streamIterator.next() store((msgAndMetadata.key, msgAndMetadata.message)) } } catch { @@ -139,26 +138,4 @@ class KafkaReceiver[ } } } - - // It is our responsibility to delete the consumer group when specifying auto.offset.reset. This - // is because Kafka 0.7.2 only honors this param when the group is not in zookeeper. - // - // The kafka high level consumer doesn't expose setting offsets currently, this is a trick copied - // from Kafka's ConsoleConsumer. 
See code related to 'auto.offset.reset' when it is set to - // 'smallest'/'largest': - // scalastyle:off - // https://github.com/apache/kafka/blob/0.7.2/core/src/main/scala/kafka/consumer/ConsoleConsumer.scala - // scalastyle:on - private def tryZookeeperConsumerGroupCleanup(zkUrl: String, groupId: String) { - val dir = "/consumers/" + groupId - logInfo("Cleaning up temporary Zookeeper data under " + dir + ".") - val zk = new ZkClient(zkUrl, 30*1000, 30*1000, ZKStringSerializer) - try { - zk.deleteRecursive(dir) - } catch { - case e: Throwable => logWarning("Error cleaning up temporary Zookeeper data", e) - } finally { - zk.close() - } - } } diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index 48668f763e41e..b4ac929e0c070 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -17,19 +17,18 @@ package org.apache.spark.streaming.kafka -import scala.reflect.ClassTag -import scala.collection.JavaConversions._ - import java.lang.{Integer => JInt} import java.util.{Map => JMap} +import scala.reflect.ClassTag +import scala.collection.JavaConversions._ + import kafka.serializer.{Decoder, StringDecoder} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaPairReceiverInputDStream, JavaStreamingContext, JavaPairDStream} -import org.apache.spark.streaming.dstream.{ReceiverInputDStream, DStream} - +import org.apache.spark.streaming.api.java.{JavaPairReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.dstream.ReceiverInputDStream object KafkaUtils { /** @@ -71,7 +70,8 @@ object KafkaUtils { topics: Map[String, Int], storageLevel: StorageLevel ): ReceiverInputDStream[(K, V)] = { - new KafkaInputDStream[K, V, U, T](ssc, kafkaParams, topics, storageLevel) + val walEnabled = ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false) + new KafkaInputDStream[K, V, U, T](ssc, kafkaParams, topics, walEnabled, storageLevel) } /** @@ -100,7 +100,6 @@ object KafkaUtils { * @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed * in its own thread. * @param storageLevel RDD storage level. - * */ def createStream( jssc: JavaStreamingContext, diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala new file mode 100644 index 0000000000000..be734b80272d1 --- /dev/null +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala @@ -0,0 +1,282 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + +import java.util.Properties +import java.util.concurrent.{ThreadPoolExecutor, ConcurrentHashMap} + +import scala.collection.{Map, mutable} +import scala.reflect.{ClassTag, classTag} + +import kafka.common.TopicAndPartition +import kafka.consumer.{Consumer, ConsumerConfig, ConsumerConnector, KafkaStream} +import kafka.message.MessageAndMetadata +import kafka.serializer.Decoder +import kafka.utils.{VerifiableProperties, ZKGroupTopicDirs, ZKStringSerializer, ZkUtils} +import org.I0Itec.zkclient.ZkClient + +import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.storage.{StorageLevel, StreamBlockId} +import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver} +import org.apache.spark.util.Utils + +/** + * ReliableKafkaReceiver offers the ability to reliably store data into BlockManager without loss. + * It is turned off by default and will be enabled when + * spark.streaming.receiver.writeAheadLog.enable is true. The difference compared to KafkaReceiver + * is that this receiver manages topic-partition/offset itself and updates the offset information + * after data is reliably stored as write-ahead log. Offsets will only be updated when data is + * reliably stored, so the potential data loss problem of KafkaReceiver can be eliminated. + * + * Note: ReliableKafkaReceiver will set auto.commit.enable to false to turn off automatic offset + * commit mechanism in Kafka consumer. So setting this configuration manually within kafkaParams + * will not take effect. + */ +private[streaming] +class ReliableKafkaReceiver[ + K: ClassTag, + V: ClassTag, + U <: Decoder[_]: ClassTag, + T <: Decoder[_]: ClassTag]( + kafkaParams: Map[String, String], + topics: Map[String, Int], + storageLevel: StorageLevel) + extends Receiver[(K, V)](storageLevel) with Logging { + + private val groupId = kafkaParams("group.id") + private val AUTO_OFFSET_COMMIT = "auto.commit.enable" + private def conf = SparkEnv.get.conf + + /** High level consumer to connect to Kafka. */ + private var consumerConnector: ConsumerConnector = null + + /** zkClient to connect to Zookeeper to commit the offsets. */ + private var zkClient: ZkClient = null + + /** + * A HashMap to manage the offset for each topic/partition, this HashMap is called in + * synchronized block, so mutable HashMap will not meet concurrency issue. + */ + private var topicPartitionOffsetMap: mutable.HashMap[TopicAndPartition, Long] = null + + /** A concurrent HashMap to store the stream block id and related offset snapshot. */ + private var blockOffsetMap: ConcurrentHashMap[StreamBlockId, Map[TopicAndPartition, Long]] = null + + /** + * Manage the BlockGenerator in receiver itself for better managing block store and offset + * commit. + */ + private var blockGenerator: BlockGenerator = null + + /** Thread pool running the handlers for receiving message from multiple topics and partitions. 
*/ + private var messageHandlerThreadPool: ThreadPoolExecutor = null + + override def onStart(): Unit = { + logInfo(s"Starting Kafka Consumer Stream with group: $groupId") + + // Initialize the topic-partition / offset hash map. + topicPartitionOffsetMap = new mutable.HashMap[TopicAndPartition, Long] + + // Initialize the stream block id / offset snapshot hash map. + blockOffsetMap = new ConcurrentHashMap[StreamBlockId, Map[TopicAndPartition, Long]]() + + // Initialize the block generator for storing Kafka message. + blockGenerator = new BlockGenerator(new GeneratedBlockHandler, streamId, conf) + + if (kafkaParams.contains(AUTO_OFFSET_COMMIT) && kafkaParams(AUTO_OFFSET_COMMIT) == "true") { + logWarning(s"$AUTO_OFFSET_COMMIT should be set to false in ReliableKafkaReceiver, " + + "otherwise we will manually set it to false to turn off auto offset commit in Kafka") + } + + val props = new Properties() + kafkaParams.foreach(param => props.put(param._1, param._2)) + // Manually set "auto.commit.enable" to "false" no matter user explicitly set it to true, + // we have to make sure this property is set to false to turn off auto commit mechanism in + // Kafka. + props.setProperty(AUTO_OFFSET_COMMIT, "false") + + val consumerConfig = new ConsumerConfig(props) + + assert(!consumerConfig.autoCommitEnable) + + logInfo(s"Connecting to Zookeeper: ${consumerConfig.zkConnect}") + consumerConnector = Consumer.create(consumerConfig) + logInfo(s"Connected to Zookeeper: ${consumerConfig.zkConnect}") + + zkClient = new ZkClient(consumerConfig.zkConnect, consumerConfig.zkSessionTimeoutMs, + consumerConfig.zkConnectionTimeoutMs, ZKStringSerializer) + + messageHandlerThreadPool = Utils.newDaemonFixedThreadPool( + topics.values.sum, "KafkaMessageHandler") + + blockGenerator.start() + + val keyDecoder = classTag[U].runtimeClass.getConstructor(classOf[VerifiableProperties]) + .newInstance(consumerConfig.props) + .asInstanceOf[Decoder[K]] + + val valueDecoder = classTag[T].runtimeClass.getConstructor(classOf[VerifiableProperties]) + .newInstance(consumerConfig.props) + .asInstanceOf[Decoder[V]] + + val topicMessageStreams = consumerConnector.createMessageStreams( + topics, keyDecoder, valueDecoder) + + topicMessageStreams.values.foreach { streams => + streams.foreach { stream => + messageHandlerThreadPool.submit(new MessageHandler(stream)) + } + } + } + + override def onStop(): Unit = { + if (messageHandlerThreadPool != null) { + messageHandlerThreadPool.shutdown() + messageHandlerThreadPool = null + } + + if (consumerConnector != null) { + consumerConnector.shutdown() + consumerConnector = null + } + + if (zkClient != null) { + zkClient.close() + zkClient = null + } + + if (blockGenerator != null) { + blockGenerator.stop() + blockGenerator = null + } + + if (topicPartitionOffsetMap != null) { + topicPartitionOffsetMap.clear() + topicPartitionOffsetMap = null + } + + if (blockOffsetMap != null) { + blockOffsetMap.clear() + blockOffsetMap = null + } + } + + /** Store a Kafka message and the associated metadata as a tuple. 
*/ + private def storeMessageAndMetadata( + msgAndMetadata: MessageAndMetadata[K, V]): Unit = { + val topicAndPartition = TopicAndPartition(msgAndMetadata.topic, msgAndMetadata.partition) + val data = (msgAndMetadata.key, msgAndMetadata.message) + val metadata = (topicAndPartition, msgAndMetadata.offset) + blockGenerator.addDataWithCallback(data, metadata) + } + + /** Update stored offset */ + private def updateOffset(topicAndPartition: TopicAndPartition, offset: Long): Unit = { + topicPartitionOffsetMap.put(topicAndPartition, offset) + } + + /** + * Remember the current offsets for each topic and partition. This is called when a block is + * generated. + */ + private def rememberBlockOffsets(blockId: StreamBlockId): Unit = { + // Get a snapshot of current offset map and store with related block id. + val offsetSnapshot = topicPartitionOffsetMap.toMap + blockOffsetMap.put(blockId, offsetSnapshot) + topicPartitionOffsetMap.clear() + } + + /** Store the ready-to-be-stored block and commit the related offsets to zookeeper. */ + private def storeBlockAndCommitOffset( + blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { + store(arrayBuffer.asInstanceOf[mutable.ArrayBuffer[(K, V)]]) + Option(blockOffsetMap.get(blockId)).foreach(commitOffset) + blockOffsetMap.remove(blockId) + } + + /** + * Commit the offset of Kafka's topic/partition, the commit mechanism follow Kafka 0.8.x's + * metadata schema in Zookeeper. + */ + private def commitOffset(offsetMap: Map[TopicAndPartition, Long]): Unit = { + if (zkClient == null) { + val thrown = new IllegalStateException("Zookeeper client is unexpectedly null") + stop("Zookeeper client is not initialized before commit offsets to ZK", thrown) + return + } + + for ((topicAndPart, offset) <- offsetMap) { + try { + val topicDirs = new ZKGroupTopicDirs(groupId, topicAndPart.topic) + val zkPath = s"${topicDirs.consumerOffsetDir}/${topicAndPart.partition}" + + ZkUtils.updatePersistentPath(zkClient, zkPath, offset.toString) + } catch { + case e: Exception => + logWarning(s"Exception during commit offset $offset for topic" + + s"${topicAndPart.topic}, partition ${topicAndPart.partition}", e) + } + + logInfo(s"Committed offset $offset for topic ${topicAndPart.topic}, " + + s"partition ${topicAndPart.partition}") + } + } + + /** Class to handle received Kafka message. */ + private final class MessageHandler(stream: KafkaStream[K, V]) extends Runnable { + override def run(): Unit = { + while (!isStopped) { + try { + val streamIterator = stream.iterator() + while (streamIterator.hasNext) { + storeMessageAndMetadata(streamIterator.next) + } + } catch { + case e: Exception => + logError("Error handling message", e) + } + } + } + } + + /** Class to handle blocks generated by the block generator. 
*/ + private final class GeneratedBlockHandler extends BlockGeneratorListener { + + def onAddData(data: Any, metadata: Any): Unit = { + // Update the offset of the data that was added to the generator + if (metadata != null) { + val (topicAndPartition, offset) = metadata.asInstanceOf[(TopicAndPartition, Long)] + updateOffset(topicAndPartition, offset) + } + } + + def onGenerateBlock(blockId: StreamBlockId): Unit = { + // Remember the offsets of topics/partitions when a block has been generated + rememberBlockOffsets(blockId) + } + + def onPushBlock(blockId: StreamBlockId, arrayBuffer: mutable.ArrayBuffer[_]): Unit = { + // Store block and commit the blocks offset + storeBlockAndCommitOffset(blockId, arrayBuffer) + } + + def onError(message: String, throwable: Throwable): Unit = { + reportError(message, throwable) + } + } +} diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index efb0099c7c850..6e1abf3f385ee 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -20,7 +20,10 @@ import java.io.Serializable; import java.util.HashMap; import java.util.List; +import java.util.Random; +import org.apache.spark.SparkConf; +import org.apache.spark.streaming.Duration; import scala.Predef; import scala.Tuple2; import scala.collection.JavaConverters; @@ -32,8 +35,6 @@ import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.Function; import org.apache.spark.storage.StorageLevel; -import org.apache.spark.streaming.Duration; -import org.apache.spark.streaming.LocalJavaStreamingContext; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; @@ -42,25 +43,27 @@ import org.junit.After; import org.junit.Before; -public class JavaKafkaStreamSuite extends LocalJavaStreamingContext implements Serializable { - private transient KafkaStreamSuite testSuite = new KafkaStreamSuite(); +public class JavaKafkaStreamSuite implements Serializable { + private transient JavaStreamingContext ssc = null; + private transient Random random = new Random(); + private transient KafkaStreamSuiteBase suiteBase = null; @Before - @Override public void setUp() { - testSuite.beforeFunction(); + suiteBase = new KafkaStreamSuiteBase() { }; + suiteBase.setupKafka(); System.clearProperty("spark.driver.port"); - //System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock"); - ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + SparkConf sparkConf = new SparkConf() + .setMaster("local[4]").setAppName(this.getClass().getSimpleName()); + ssc = new JavaStreamingContext(sparkConf, new Duration(500)); } @After - @Override public void tearDown() { ssc.stop(); ssc = null; System.clearProperty("spark.driver.port"); - testSuite.afterFunction(); + suiteBase.tearDownKafka(); } @Test @@ -74,15 +77,15 @@ public void testKafkaStream() throws InterruptedException { sent.put("b", 3); sent.put("c", 10); - testSuite.createTopic(topic); + suiteBase.createTopic(topic); HashMap tmp = new HashMap(sent); - testSuite.produceAndSendMessage(topic, - JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap( - Predef.>conforms())); + suiteBase.produceAndSendMessage(topic, + 
JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap( + Predef.>conforms())); HashMap kafkaParams = new HashMap(); - kafkaParams.put("zookeeper.connect", testSuite.zkHost() + ":" + testSuite.zkPort()); - kafkaParams.put("group.id", "test-consumer-" + KafkaTestUtils.random().nextInt(10000)); + kafkaParams.put("zookeeper.connect", suiteBase.zkAddress()); + kafkaParams.put("group.id", "test-consumer-" + random.nextInt(10000)); kafkaParams.put("auto.offset.reset", "smallest"); JavaPairDStream stream = KafkaUtils.createStream(ssc, @@ -124,11 +127,16 @@ public Void call(JavaPairRDD rdd) throws Exception { ); ssc.start(); - ssc.awaitTermination(3000); - + long startTime = System.currentTimeMillis(); + boolean sizeMatches = false; + while (!sizeMatches && System.currentTimeMillis() - startTime < 20000) { + sizeMatches = sent.size() == result.size(); + Thread.sleep(200); + } Assert.assertEquals(sent.size(), result.size()); for (String k : sent.keySet()) { Assert.assertEquals(sent.get(k).intValue(), result.get(k).intValue()); } + ssc.stop(); } } diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala index 6943326eb750e..b19c053ebfc44 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala @@ -19,51 +19,57 @@ package org.apache.spark.streaming.kafka import java.io.File import java.net.InetSocketAddress -import java.util.{Properties, Random} +import java.util.Properties import scala.collection.mutable +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.util.Random import kafka.admin.CreateTopicCommand import kafka.common.{KafkaException, TopicAndPartition} -import kafka.producer.{KeyedMessage, ProducerConfig, Producer} -import kafka.utils.ZKStringSerializer +import kafka.producer.{KeyedMessage, Producer, ProducerConfig} import kafka.serializer.{StringDecoder, StringEncoder} import kafka.server.{KafkaConfig, KafkaServer} - +import kafka.utils.ZKStringSerializer import org.I0Itec.zkclient.ZkClient +import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} +import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.concurrent.Eventually -import org.apache.zookeeper.server.ZooKeeperServer -import org.apache.zookeeper.server.NIOServerCnxnFactory - -import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Milliseconds, StreamingContext} import org.apache.spark.util.Utils -class KafkaStreamSuite extends TestSuiteBase { - import KafkaTestUtils._ - - val zkHost = "localhost" - var zkPort: Int = 0 - val zkConnectionTimeout = 6000 - val zkSessionTimeout = 6000 - - protected var brokerPort = 9092 - protected var brokerConf: KafkaConfig = _ - protected var zookeeper: EmbeddedZookeeper = _ - protected var zkClient: ZkClient = _ - protected var server: KafkaServer = _ - protected var producer: Producer[String, String] = _ - - override def useManualClock = false - - override def beforeFunction() { +/** + * This is an abstract base class for Kafka testsuites. This has the functionality to set up + * and tear down local Kafka servers, and to push data using Kafka producers. 
+ */ +abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Logging { + + var zkAddress: String = _ + var zkClient: ZkClient = _ + + private val zkHost = "localhost" + private val zkConnectionTimeout = 6000 + private val zkSessionTimeout = 6000 + private var zookeeper: EmbeddedZookeeper = _ + private var zkPort: Int = 0 + private var brokerPort = 9092 + private var brokerConf: KafkaConfig = _ + private var server: KafkaServer = _ + private var producer: Producer[String, String] = _ + + def setupKafka() { // Zookeeper server startup zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort") // Get the actual zookeeper binding port zkPort = zookeeper.actualPort + zkAddress = s"$zkHost:$zkPort" logInfo("==================== 0 ====================") - zkClient = new ZkClient(s"$zkHost:$zkPort", zkSessionTimeout, zkConnectionTimeout, + zkClient = new ZkClient(zkAddress, zkSessionTimeout, zkConnectionTimeout, ZKStringSerializer) logInfo("==================== 1 ====================") @@ -71,7 +77,7 @@ class KafkaStreamSuite extends TestSuiteBase { var bindSuccess: Boolean = false while(!bindSuccess) { try { - val brokerProps = getBrokerConfig(brokerPort, s"$zkHost:$zkPort") + val brokerProps = getBrokerConfig() brokerConf = new KafkaConfig(brokerProps) server = new KafkaServer(brokerConf) logInfo("==================== 2 ====================") @@ -89,53 +95,30 @@ class KafkaStreamSuite extends TestSuiteBase { Thread.sleep(2000) logInfo("==================== 4 ====================") - super.beforeFunction() } - override def afterFunction() { - producer.close() - server.shutdown() - brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } - - zkClient.close() - zookeeper.shutdown() - - super.afterFunction() - } - - test("Kafka input stream") { - val ssc = new StreamingContext(master, framework, batchDuration) - val topic = "topic1" - val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) - createTopic(topic) - produceAndSendMessage(topic, sent) + def tearDownKafka() { + if (producer != null) { + producer.close() + producer = null + } - val kafkaParams = Map("zookeeper.connect" -> s"$zkHost:$zkPort", - "group.id" -> s"test-consumer-${random.nextInt(10000)}", - "auto.offset.reset" -> "smallest") + if (server != null) { + server.shutdown() + server = null + } - val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( - ssc, - kafkaParams, - Map(topic -> 1), - StorageLevel.MEMORY_ONLY) - val result = new mutable.HashMap[String, Long]() - stream.map { case (k, v) => v } - .countByValue() - .foreachRDD { r => - val ret = r.collect() - ret.toMap.foreach { kv => - val count = result.getOrElseUpdate(kv._1, 0) + kv._2 - result.put(kv._1, count) - } - } - ssc.start() - ssc.awaitTermination(3000) + brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) } - assert(sent.size === result.size) - sent.keys.foreach { k => assert(sent(k) === result(k).toInt) } + if (zkClient != null) { + zkClient.close() + zkClient = null + } - ssc.stop() + if (zookeeper != null) { + zookeeper.shutdown() + zookeeper = null + } } private def createTestMessage(topic: String, sent: Map[String, Int]) @@ -150,58 +133,43 @@ class KafkaStreamSuite extends TestSuiteBase { CreateTopicCommand.createTopic(zkClient, topic, 1, 1, "0") logInfo("==================== 5 ====================") // wait until metadata is propagated - waitUntilMetadataIsPropagated(Seq(server), topic, 0, 1000) + waitUntilMetadataIsPropagated(topic, 0) } def produceAndSendMessage(topic: String, sent: 
Map[String, Int]) { - val brokerAddr = brokerConf.hostName + ":" + brokerConf.port - producer = new Producer[String, String](new ProducerConfig(getProducerConfig(brokerAddr))) + producer = new Producer[String, String](new ProducerConfig(getProducerConfig())) producer.send(createTestMessage(topic, sent): _*) + producer.close() logInfo("==================== 6 ====================") } -} - -object KafkaTestUtils { - val random = new Random() - def getBrokerConfig(port: Int, zkConnect: String): Properties = { + private def getBrokerConfig(): Properties = { val props = new Properties() props.put("broker.id", "0") props.put("host.name", "localhost") - props.put("port", port.toString) + props.put("port", brokerPort.toString) props.put("log.dir", Utils.createTempDir().getAbsolutePath) - props.put("zookeeper.connect", zkConnect) + props.put("zookeeper.connect", zkAddress) props.put("log.flush.interval.messages", "1") props.put("replica.socket.timeout.ms", "1500") props } - def getProducerConfig(brokerList: String): Properties = { + private def getProducerConfig(): Properties = { + val brokerAddr = brokerConf.hostName + ":" + brokerConf.port val props = new Properties() - props.put("metadata.broker.list", brokerList) + props.put("metadata.broker.list", brokerAddr) props.put("serializer.class", classOf[StringEncoder].getName) props } - def waitUntilTrue(condition: () => Boolean, waitTime: Long): Boolean = { - val startTime = System.currentTimeMillis() - while (true) { - if (condition()) - return true - if (System.currentTimeMillis() > startTime + waitTime) - return false - Thread.sleep(waitTime.min(100L)) + private def waitUntilMetadataIsPropagated(topic: String, partition: Int) { + eventually(timeout(1000 milliseconds), interval(100 milliseconds)) { + assert( + server.apis.leaderCache.keySet.contains(TopicAndPartition(topic, partition)), + s"Partition [$topic, $partition] metadata not propagated after timeout" + ) } - // Should never go to here - throw new RuntimeException("unexpected error") - } - - def waitUntilMetadataIsPropagated(servers: Seq[KafkaServer], topic: String, partition: Int, - timeout: Long) { - assert(waitUntilTrue(() => - servers.foldLeft(true)(_ && _.apis.leaderCache.keySet.contains( - TopicAndPartition(topic, partition))), timeout), - s"Partition [$topic, $partition] metadata not propagated after timeout") } class EmbeddedZookeeper(val zkConnect: String) { @@ -227,3 +195,53 @@ object KafkaTestUtils { } } } + + +class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter { + var ssc: StreamingContext = _ + + before { + setupKafka() + } + + after { + if (ssc != null) { + ssc.stop() + ssc = null + } + tearDownKafka() + } + + test("Kafka input stream") { + val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName) + ssc = new StreamingContext(sparkConf, Milliseconds(500)) + val topic = "topic1" + val sent = Map("a" -> 5, "b" -> 3, "c" -> 10) + createTopic(topic) + produceAndSendMessage(topic, sent) + + val kafkaParams = Map("zookeeper.connect" -> zkAddress, + "group.id" -> s"test-consumer-${Random.nextInt(10000)}", + "auto.offset.reset" -> "smallest") + + val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, Map(topic -> 1), StorageLevel.MEMORY_ONLY) + val result = new mutable.HashMap[String, Long]() + stream.map(_._2).countByValue().foreachRDD { r => + val ret = r.collect() + ret.toMap.foreach { kv => + val count = result.getOrElseUpdate(kv._1, 0) + kv._2 + result.put(kv._1, count) + } + 
} + ssc.start() + eventually(timeout(10000 milliseconds), interval(100 milliseconds)) { + assert(sent.size === result.size) + sent.keys.foreach { k => + assert(sent(k) === result(k).toInt) + } + } + ssc.stop() + } +} + diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala new file mode 100644 index 0000000000000..64ccc92c81fa9 --- /dev/null +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kafka + + +import java.io.File + +import scala.collection.mutable +import scala.concurrent.duration._ +import scala.language.postfixOps +import scala.util.Random + +import com.google.common.io.Files +import kafka.serializer.StringDecoder +import kafka.utils.{ZKGroupTopicDirs, ZkUtils} +import org.apache.commons.io.FileUtils +import org.scalatest.BeforeAndAfter +import org.scalatest.concurrent.Eventually + +import org.apache.spark.SparkConf +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Milliseconds, StreamingContext} + +class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter with Eventually { + + val sparkConf = new SparkConf() + .setMaster("local[4]") + .setAppName(this.getClass.getSimpleName) + .set("spark.streaming.receiver.writeAheadLog.enable", "true") + val data = Map("a" -> 10, "b" -> 10, "c" -> 10) + + + var groupId: String = _ + var kafkaParams: Map[String, String] = _ + var ssc: StreamingContext = _ + var tempDirectory: File = null + + before { + setupKafka() + groupId = s"test-consumer-${Random.nextInt(10000)}" + kafkaParams = Map( + "zookeeper.connect" -> zkAddress, + "group.id" -> groupId, + "auto.offset.reset" -> "smallest" + ) + + ssc = new StreamingContext(sparkConf, Milliseconds(500)) + tempDirectory = Files.createTempDir() + ssc.checkpoint(tempDirectory.getAbsolutePath) + } + + after { + if (ssc != null) { + ssc.stop() + } + if (tempDirectory != null && tempDirectory.exists()) { + FileUtils.deleteDirectory(tempDirectory) + tempDirectory = null + } + tearDownKafka() + } + + + test("Reliable Kafka input stream with single topic") { + var topic = "test-topic" + createTopic(topic) + produceAndSendMessage(topic, data) + + // Verify whether the offset of this group/topic/partition is 0 before starting. 
+ assert(getCommitOffset(groupId, topic, 0) === None) + + val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, Map(topic -> 1), StorageLevel.MEMORY_ONLY) + val result = new mutable.HashMap[String, Long]() + stream.map { case (k, v) => v }.foreachRDD { r => + val ret = r.collect() + ret.foreach { v => + val count = result.getOrElseUpdate(v, 0) + 1 + result.put(v, count) + } + } + ssc.start() + eventually(timeout(20000 milliseconds), interval(200 milliseconds)) { + // A basic process verification for ReliableKafkaReceiver. + // Verify whether received message number is equal to the sent message number. + assert(data.size === result.size) + // Verify whether each message is the same as the data to be verified. + data.keys.foreach { k => assert(data(k) === result(k).toInt) } + // Verify the offset number whether it is equal to the total message number. + assert(getCommitOffset(groupId, topic, 0) === Some(29L)) + } + ssc.stop() + } + + test("Reliable Kafka input stream with multiple topics") { + val topics = Map("topic1" -> 1, "topic2" -> 1, "topic3" -> 1) + topics.foreach { case (t, _) => + createTopic(t) + produceAndSendMessage(t, data) + } + + // Before started, verify all the group/topic/partition offsets are 0. + topics.foreach { case (t, _) => assert(getCommitOffset(groupId, t, 0) === None) } + + // Consuming all the data sent to the broker which will potential commit the offsets internally. + val stream = KafkaUtils.createStream[String, String, StringDecoder, StringDecoder]( + ssc, kafkaParams, topics, StorageLevel.MEMORY_ONLY) + stream.foreachRDD(_ => Unit) + ssc.start() + eventually(timeout(20000 milliseconds), interval(100 milliseconds)) { + // Verify the offset for each group/topic to see whether they are equal to the expected one. + topics.foreach { case (t, _) => assert(getCommitOffset(groupId, t, 0) === Some(29L)) } + } + ssc.stop() + } + + + /** Getting partition offset from Zookeeper. */ + private def getCommitOffset(groupId: String, topic: String, partition: Int): Option[Long] = { + assert(zkClient != null, "Zookeeper client is not initialized") + val topicDirs = new ZKGroupTopicDirs(groupId, topic) + val zkPath = s"${topicDirs.consumerOffsetDir}/$partition" + ZkUtils.readDataMaybeNull(zkClient, zkPath)._1.map(_.toLong) + } +} diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 371f1f1e9d39a..703806735b3ff 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml @@ -39,24 +39,13 @@ org.apache.spark spark-streaming_${scala.binary.version} ${project.version} - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - test-jar - test + provided org.eclipse.paho mqtt-client 0.4.0 - - ${akka.group} - akka-zeromq_${scala.binary.version} - ${akka.version} - org.scalatest scalatest_${scala.binary.version} diff --git a/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java new file mode 100644 index 0000000000000..6e1f01900071b --- /dev/null +++ b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming; + +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.junit.After; +import org.junit.Before; + +public abstract class LocalJavaStreamingContext { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } +} diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index 467fd263e2d64..84595acf45ccb 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -17,11 +17,19 @@ package org.apache.spark.streaming.mqtt -import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} +import org.scalatest.FunSuite + +import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream -class MQTTStreamSuite extends TestSuiteBase { +class MQTTStreamSuite extends FunSuite { + + val batchDuration = Seconds(1) + + private val master: String = "local[2]" + + private val framework: String = this.getClass.getSimpleName test("mqtt input stream") { val ssc = new StreamingContext(master, framework, batchDuration) diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 1d7dd49d15c22..000ace1446e5e 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml @@ -39,13 +39,7 @@ org.apache.spark spark-streaming_${scala.binary.version} ${project.version} - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - test-jar - test + provided org.twitter4j diff --git a/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java new file mode 100644 index 0000000000000..6e1f01900071b --- /dev/null +++ b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming; + +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.junit.After; +import org.junit.Before; + +public abstract class LocalJavaStreamingContext { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } +} diff --git a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala index 93741e0375164..9ee57d7581d85 100644 --- a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala +++ b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala @@ -17,13 +17,23 @@ package org.apache.spark.streaming.twitter -import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} -import org.apache.spark.storage.StorageLevel + +import org.scalatest.{BeforeAndAfter, FunSuite} +import twitter4j.Status import twitter4j.auth.{NullAuthorization, Authorization} + +import org.apache.spark.Logging +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.ReceiverInputDStream -import twitter4j.Status -class TwitterStreamSuite extends TestSuiteBase { +class TwitterStreamSuite extends FunSuite with BeforeAndAfter with Logging { + + val batchDuration = Seconds(1) + + private val master: String = "local[2]" + + private val framework: String = this.getClass.getSimpleName test("twitter input stream") { val ssc = new StreamingContext(master, framework, batchDuration) diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 7e48968feb3bc..29c452093502e 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml @@ -39,13 +39,7 @@ org.apache.spark spark-streaming_${scala.binary.version} ${project.version} - - - org.apache.spark - spark-streaming_${scala.binary.version} - ${project.version} - test-jar - test + provided ${akka.group} diff --git a/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java new file mode 100644 index 0000000000000..6e1f01900071b --- /dev/null +++ b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming; + +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.junit.After; +import org.junit.Before; + +public abstract class LocalJavaStreamingContext { + + protected transient JavaStreamingContext ssc; + + @Before + public void setUp() { + System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock"); + ssc = new JavaStreamingContext("local[2]", "test", new Duration(1000)); + ssc.checkpoint("checkpoint"); + } + + @After + public void tearDown() { + ssc.stop(); + ssc = null; + } +} diff --git a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala index cc10ff6ae03cd..a7566e733d891 100644 --- a/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala +++ b/external/zeromq/src/test/scala/org/apache/spark/streaming/zeromq/ZeroMQStreamSuite.scala @@ -20,12 +20,19 @@ package org.apache.spark.streaming.zeromq import akka.actor.SupervisorStrategy import akka.util.ByteString import akka.zeromq.Subscribe +import org.scalatest.FunSuite import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} +import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream -class ZeroMQStreamSuite extends TestSuiteBase { +class ZeroMQStreamSuite extends FunSuite { + + val batchDuration = Seconds(1) + + private val master: String = "local[2]" + + private val framework: String = this.getClass.getSimpleName test("zeromq input stream") { val ssc = new StreamingContext(master, framework, batchDuration) diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 7e478bed62da7..c8477a6566311 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 560244ad93369..c0d3a61119113 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index 71a078d58a8d8..d1427f6a0c6e9 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 3f49b1d63b6e1..9982b36f9b62f 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala 
b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala new file mode 100644 index 0000000000000..f70715fca6eea --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeContext.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx + +/** + * Represents an edge along with its neighboring vertices and allows sending messages along the + * edge. Used in [[Graph#aggregateMessages]]. + */ +abstract class EdgeContext[VD, ED, A] { + /** The vertex id of the edge's source vertex. */ + def srcId: VertexId + /** The vertex id of the edge's destination vertex. */ + def dstId: VertexId + /** The vertex attribute of the edge's source vertex. */ + def srcAttr: VD + /** The vertex attribute of the edge's destination vertex. */ + def dstAttr: VD + /** The attribute associated with the edge. */ + def attr: ED + + /** Sends a message to the source vertex. */ + def sendToSrc(msg: A): Unit + /** Sends a message to the destination vertex. */ + def sendToDst(msg: A): Unit + + /** Converts the edge and vertex properties into an [[EdgeTriplet]] for convenience. */ + def toEdgeTriplet: EdgeTriplet[VD, ED] = { + val et = new EdgeTriplet[VD, ED] + et.srcId = srcId + et.srcAttr = srcAttr + et.dstId = dstId + et.dstAttr = dstAttr + et.attr = attr + et + } +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index 5bcb96b136ed7..cc70b396a8dd4 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -17,14 +17,19 @@ package org.apache.spark.graphx -import scala.reflect.{classTag, ClassTag} +import scala.language.existentials +import scala.reflect.ClassTag -import org.apache.spark.{OneToOneDependency, Partition, Partitioner, TaskContext} +import org.apache.spark.Dependency +import org.apache.spark.Partition +import org.apache.spark.SparkContext +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.graphx.impl.EdgePartition import org.apache.spark.graphx.impl.EdgePartitionBuilder +import org.apache.spark.graphx.impl.EdgeRDDImpl /** * `EdgeRDD[ED, VD]` extends `RDD[Edge[ED]]` by storing the edges in columnar format on each @@ -32,33 +37,16 @@ import org.apache.spark.graphx.impl.EdgePartitionBuilder * edge to provide the triplet view. Shipping of the vertex attributes is managed by * `impl.ReplicatedVertexView`. 
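// Editor's note: a small, self-contained sketch, not part of this patch, illustrating the
// EdgeContext-based aggregateMessages API introduced above. The toy graph, the "age" vertex
// attributes and the object name are made up for illustration; TripletFields.Src declares that
// sendMsg only reads the source vertex attribute, so destination attributes need not be shipped.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx._

object AggregateMessagesExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("AggregateMessagesExample"))
    // Build a toy graph and treat each vertex attribute as an "age".
    val graph: Graph[Double, Int] = Graph
      .fromEdgeTuples(sc.parallelize(Seq((1L, 2L), (2L, 3L), (3L, 1L), (1L, 3L))), 1)
      .mapVertices((id, _) => id.toDouble * 10)
    // For every vertex, count its older followers and sum their ages.
    val olderFollowers: VertexRDD[(Int, Double)] = graph.aggregateMessages[(Int, Double)](
      ctx => if (ctx.srcAttr > ctx.dstAttr) ctx.sendToDst((1, ctx.srcAttr)),
      (a, b) => (a._1 + b._1, a._2 + b._2),
      TripletFields.Src)
    olderFollowers.collect().foreach(println)
    sc.stop()
  }
}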
*/ -class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag]( - val partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])], - val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) - extends RDD[Edge[ED]](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) { +abstract class EdgeRDD[ED]( + @transient sc: SparkContext, + @transient deps: Seq[Dependency[_]]) extends RDD[Edge[ED]](sc, deps) { - override def setName(_name: String): this.type = { - if (partitionsRDD.name != null) { - partitionsRDD.setName(partitionsRDD.name + ", " + _name) - } else { - partitionsRDD.setName(_name) - } - this - } - setName("EdgeRDD") + private[graphx] def partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])] forSome { type VD } override protected def getPartitions: Array[Partition] = partitionsRDD.partitions - /** - * If `partitionsRDD` already has a partitioner, use it. Otherwise assume that the - * [[PartitionID]]s in `partitionsRDD` correspond to the actual partitions and create a new - * partitioner that allows co-partitioning with `partitionsRDD`. - */ - override val partitioner = - partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD))) - override def compute(part: Partition, context: TaskContext): Iterator[Edge[ED]] = { - val p = firstParent[(PartitionID, EdgePartition[ED, VD])].iterator(part, context) + val p = firstParent[(PartitionID, EdgePartition[ED, _])].iterator(part, context) if (p.hasNext) { p.next._2.iterator.map(_.copy()) } else { @@ -66,40 +54,6 @@ class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag]( } } - override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect() - - /** - * Persists the edge partitions at the specified storage level, ignoring any existing target - * storage level. - */ - override def persist(newLevel: StorageLevel): this.type = { - partitionsRDD.persist(newLevel) - this - } - - override def unpersist(blocking: Boolean = true): this.type = { - partitionsRDD.unpersist(blocking) - this - } - - /** Persists the vertex partitions using `targetStorageLevel`, which defaults to MEMORY_ONLY. */ - override def cache(): this.type = { - partitionsRDD.persist(targetStorageLevel) - this - } - - private[graphx] def mapEdgePartitions[ED2: ClassTag, VD2: ClassTag]( - f: (PartitionID, EdgePartition[ED, VD]) => EdgePartition[ED2, VD2]): EdgeRDD[ED2, VD2] = { - this.withPartitionsRDD[ED2, VD2](partitionsRDD.mapPartitions({ iter => - if (iter.hasNext) { - val (pid, ep) = iter.next() - Iterator(Tuple2(pid, f(pid, ep))) - } else { - Iterator.empty - } - }, preservesPartitioning = true)) - } - /** * Map the values in an edge partitioning preserving the structure but changing the values. * @@ -107,22 +61,14 @@ class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag]( * @param f the function from an edge to a new edge value * @return a new EdgeRDD containing the new edge values */ - def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDD[ED2, VD] = - mapEdgePartitions((pid, part) => part.map(f)) + def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDD[ED2] /** * Reverse all the edges in this RDD. * * @return a new EdgeRDD containing all the edges reversed */ - def reverse: EdgeRDD[ED, VD] = mapEdgePartitions((pid, part) => part.reverse) - - /** Removes all edges but those matching `epred` and where both vertices match `vpred`. 
*/ - def filter( - epred: EdgeTriplet[VD, ED] => Boolean, - vpred: (VertexId, VD) => Boolean): EdgeRDD[ED, VD] = { - mapEdgePartitions((pid, part) => part.filter(epred, vpred)) - } + def reverse: EdgeRDD[ED] /** * Inner joins this EdgeRDD with another EdgeRDD, assuming both are partitioned using the same @@ -134,23 +80,8 @@ class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag]( * with values supplied by `f` */ def innerJoin[ED2: ClassTag, ED3: ClassTag] - (other: EdgeRDD[ED2, _]) - (f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3, VD] = { - val ed2Tag = classTag[ED2] - val ed3Tag = classTag[ED3] - this.withPartitionsRDD[ED3, VD](partitionsRDD.zipPartitions(other.partitionsRDD, true) { - (thisIter, otherIter) => - val (pid, thisEPart) = thisIter.next() - val (_, otherEPart) = otherIter.next() - Iterator(Tuple2(pid, thisEPart.innerJoin(otherEPart)(f)(ed2Tag, ed3Tag))) - }) - } - - /** Replaces the vertex partitions while preserving all other properties of the VertexRDD. */ - private[graphx] def withPartitionsRDD[ED2: ClassTag, VD2: ClassTag]( - partitionsRDD: RDD[(PartitionID, EdgePartition[ED2, VD2])]): EdgeRDD[ED2, VD2] = { - new EdgeRDD(partitionsRDD, this.targetStorageLevel) - } + (other: EdgeRDD[ED2]) + (f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDD[ED3] /** * Changes the target storage level while preserving all other properties of the @@ -159,11 +90,7 @@ class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag]( * This does not actually trigger a cache; to do this, call * [[org.apache.spark.graphx.EdgeRDD#cache]] on the returned EdgeRDD. */ - private[graphx] def withTargetStorageLevel( - targetStorageLevel: StorageLevel): EdgeRDD[ED, VD] = { - new EdgeRDD(this.partitionsRDD, targetStorageLevel) - } - + private[graphx] def withTargetStorageLevel(targetStorageLevel: StorageLevel): EdgeRDD[ED] } object EdgeRDD { @@ -173,7 +100,7 @@ object EdgeRDD { * @tparam ED the edge attribute type * @tparam VD the type of the vertex attributes that may be joined with the returned EdgeRDD */ - def fromEdges[ED: ClassTag, VD: ClassTag](edges: RDD[Edge[ED]]): EdgeRDD[ED, VD] = { + def fromEdges[ED: ClassTag, VD: ClassTag](edges: RDD[Edge[ED]]): EdgeRDDImpl[ED, VD] = { val edgePartitions = edges.mapPartitionsWithIndex { (pid, iter) => val builder = new EdgePartitionBuilder[ED, VD] iter.foreach { e => @@ -190,8 +117,8 @@ object EdgeRDD { * @tparam ED the edge attribute type * @tparam VD the type of the vertex attributes that may be joined with the returned EdgeRDD */ - def fromEdgePartitions[ED: ClassTag, VD: ClassTag]( - edgePartitions: RDD[(Int, EdgePartition[ED, VD])]): EdgeRDD[ED, VD] = { - new EdgeRDD(edgePartitions) + private[graphx] def fromEdgePartitions[ED: ClassTag, VD: ClassTag]( + edgePartitions: RDD[(Int, EdgePartition[ED, VD])]): EdgeRDDImpl[ED, VD] = { + new EdgeRDDImpl(edgePartitions) } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index fa4b891754c40..637791543514c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -59,7 +59,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * along with their vertex data. 
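// Editor's note: a brief, self-contained sketch, not part of this patch, of the public entry
// point after this refactoring: callers build edges through EdgeRDD.fromEdges and only see the
// single-type-parameter EdgeRDD[ED], while the concrete implementation lives in
// graphx.impl.EdgeRDDImpl. The object name and edge attributes are made up for illustration.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx.{Edge, EdgeRDD}

object EdgeRDDExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("EdgeRDDExample"))
    val edges = sc.parallelize(Seq(Edge(1L, 2L, "follows"), Edge(2L, 3L, "mentions")))
    // ED = String, VD = Int (the vertex attribute type that may later be joined in).
    val edgeRdd: EdgeRDD[String] = EdgeRDD.fromEdges[String, Int](edges)
    println(edgeRdd.mapValues(e => e.attr.length).count())
    sc.stop()
  }
}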
* */ - @transient val edges: EdgeRDD[ED, VD] + @transient val edges: EdgeRDD[ED] /** * An RDD containing the edge triplets, which are edges along with the vertex data associated with @@ -208,7 +208,37 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * */ def mapTriplets[ED2: ClassTag](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] = { - mapTriplets((pid, iter) => iter.map(map)) + mapTriplets((pid, iter) => iter.map(map), TripletFields.All) + } + + /** + * Transforms each edge attribute using the map function, passing it the adjacent vertex + * attributes as well. If adjacent vertex values are not required, + * consider using `mapEdges` instead. + * + * @note This does not change the structure of the + * graph or modify the values of this graph. As a consequence + * the underlying index structures can be reused. + * + * @param map the function from an edge object to a new edge value. + * @param tripletFields which fields should be included in the edge triplet passed to the map + * function. If not all fields are needed, specifying this can improve performance. + * + * @tparam ED2 the new edge data type + * + * @example This function might be used to initialize edge + * attributes based on the attributes associated with each vertex. + * {{{ + * val rawGraph: Graph[Int, Int] = someLoadFunction() + * val graph = rawGraph.mapTriplets[Int]( edge => + * edge.src.data - edge.dst.data) + * }}} + * + */ + def mapTriplets[ED2: ClassTag]( + map: EdgeTriplet[VD, ED] => ED2, + tripletFields: TripletFields): Graph[VD, ED2] = { + mapTriplets((pid, iter) => iter.map(map), tripletFields) } /** @@ -223,12 +253,15 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * the underlying index structures can be reused. * * @param map the iterator transform + * @param tripletFields which fields should be included in the edge triplet passed to the map + * function. If not all fields are needed, specifying this can improve performance. * * @tparam ED2 the new edge data type * */ - def mapTriplets[ED2: ClassTag](map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2]) - : Graph[VD, ED2] + def mapTriplets[ED2: ClassTag]( + map: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2], + tripletFields: TripletFields): Graph[VD, ED2] /** * Reverses all edges in the graph. If this graph contains an edge from a to b then the returned @@ -287,6 +320,8 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * "sent" to either vertex in the edge. The `reduceFunc` is then used to combine the output of * the map phase destined to each vertex. * + * This function is deprecated in 1.2.0 because of SPARK-3936. Use aggregateMessages instead. + * * @tparam A the type of "message" to be sent to each vertex * * @param mapFunc the user defined map function which returns 0 or @@ -296,13 +331,15 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * be commutative and associative and is used to combine the output * of the map phase * - * @param activeSetOpt optionally, a set of "active" vertices and a direction of edges to - * consider when running `mapFunc`. If the direction is `In`, `mapFunc` will only be run on - * edges with destination in the active set. If the direction is `Out`, - * `mapFunc` will only be run on edges originating from vertices in the active set. If the - * direction is `Either`, `mapFunc` will be run on edges with *either* vertex in the active set - * . 
If the direction is `Both`, `mapFunc` will be run on edges with *both* vertices in the - * active set. The active set must have the same index as the graph's vertices. + * @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if + * desired. This is done by specifying a set of "active" vertices and an edge direction. The + * `sendMsg` function will then run only on edges connected to active vertices by edges in the + * specified direction. If the direction is `In`, `sendMsg` will only be run on edges with + * destination in the active set. If the direction is `Out`, `sendMsg` will only be run on edges + * originating from vertices in the active set. If the direction is `Either`, `sendMsg` will be + * run on edges with *either* vertex in the active set. If the direction is `Both`, `sendMsg` + * will be run on edges with *both* vertices in the active set. The active set must have the + * same index as the graph's vertices. * * @example We can use this function to compute the in-degree of each * vertex @@ -319,6 +356,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * predicate or implement PageRank. * */ + @deprecated("use aggregateMessages", "1.2.0") def mapReduceTriplets[A: ClassTag]( mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], reduceFunc: (A, A) => A, @@ -326,8 +364,80 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab : VertexRDD[A] /** - * Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`. The - * input table should contain at most one entry for each vertex. If no entry in `other` is + * Aggregates values from the neighboring edges and vertices of each vertex. The user-supplied + * `sendMsg` function is invoked on each edge of the graph, generating 0 or more messages to be + * sent to either vertex in the edge. The `mergeMsg` function is then used to combine all messages + * destined to the same vertex. + * + * @tparam A the type of message to be sent to each vertex + * + * @param sendMsg runs on each edge, sending messages to neighboring vertices using the + * [[EdgeContext]]. + * @param mergeMsg used to combine messages from `sendMsg` destined to the same vertex. This + * combiner should be commutative and associative. + * @param tripletFields which fields should be included in the [[EdgeContext]] passed to the + * `sendMsg` function. If not all fields are needed, specifying this can improve performance. + * + * @example We can use this function to compute the in-degree of each + * vertex + * {{{ + * val rawGraph: Graph[_, _] = Graph.textFile("twittergraph") + * val inDeg: RDD[(VertexId, Int)] = + * aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _) + * }}} + * + * @note By expressing computation at the edge level we achieve + * maximum parallelism. This is one of the core functions in the + * Graph API in that enables neighborhood level computation. For + * example this function can be used to count neighbors satisfying a + * predicate or implement PageRank. + * + */ + def aggregateMessages[A: ClassTag]( + sendMsg: EdgeContext[VD, ED, A] => Unit, + mergeMsg: (A, A) => A, + tripletFields: TripletFields = TripletFields.All) + : VertexRDD[A] = { + aggregateMessagesWithActiveSet(sendMsg, mergeMsg, tripletFields, None) + } + + /** + * Aggregates values from the neighboring edges and vertices of each vertex. 
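// Editor's note: a short migration sketch, not part of this patch. The in-degree computation
// written against the now-deprecated mapReduceTriplets,
//   graph.mapReduceTriplets[Int](et => Iterator((et.dstId, 1)), _ + _)
// becomes, with the API added here (and with no vertex or edge attributes shipped at all),
//   graph.aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _, TripletFields.None)
// where `graph` stands for any Graph[VD, ED]; GraphOps.degreesRDD below performs this same rewrite.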
The user-supplied + * `sendMsg` function is invoked on each edge of the graph, generating 0 or more messages to be + * sent to either vertex in the edge. The `mergeMsg` function is then used to combine all messages + * destined to the same vertex. + * + * This variant can take an active set to restrict the computation and is intended for internal + * use only. + * + * @tparam A the type of message to be sent to each vertex + * + * @param sendMsg runs on each edge, sending messages to neighboring vertices using the + * [[EdgeContext]]. + * @param mergeMsg used to combine messages from `sendMsg` destined to the same vertex. This + * combiner should be commutative and associative. + * @param tripletFields which fields should be included in the [[EdgeContext]] passed to the + * `sendMsg` function. If not all fields are needed, specifying this can improve performance. + * @param activeSetOpt an efficient way to run the aggregation on a subset of the edges if + * desired. This is done by specifying a set of "active" vertices and an edge direction. The + * `sendMsg` function will then run on only edges connected to active vertices by edges in the + * specified direction. If the direction is `In`, `sendMsg` will only be run on edges with + * destination in the active set. If the direction is `Out`, `sendMsg` will only be run on edges + * originating from vertices in the active set. If the direction is `Either`, `sendMsg` will be + * run on edges with *either* vertex in the active set. If the direction is `Both`, `sendMsg` + * will be run on edges with *both* vertices in the active set. The active set must have the + * same index as the graph's vertices. + */ + private[graphx] def aggregateMessagesWithActiveSet[A: ClassTag]( + sendMsg: EdgeContext[VD, ED, A] => Unit, + mergeMsg: (A, A) => A, + tripletFields: TripletFields, + activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]) + : VertexRDD[A] + + /** + * Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`. + * The input table should contain at most one entry for each vertex. If no entry in `other` is * provided for a particular vertex in the graph, the map function receives `None`. * * @tparam U the type of entry in the table of updates diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala index 1948c978c30bf..563c948957ecf 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala @@ -27,10 +27,10 @@ import org.apache.spark.graphx.impl._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap import org.apache.spark.util.collection.OpenHashSet - /** * Registers GraphX classes with Kryo for improved performance. 
*/ +@deprecated("Register GraphX classes with Kryo using GraphXUtils.registerKryoClasses", "1.2.0") class GraphKryoRegistrator extends KryoRegistrator { def registerClasses(kryo: Kryo) { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala index f4c79365b16da..4933aecba1286 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala @@ -48,7 +48,8 @@ object GraphLoader extends Logging { * @param path the path to the file (e.g., /home/data/file or hdfs://file) * @param canonicalOrientation whether to orient edges in the positive * direction - * @param minEdgePartitions the number of partitions for the edge RDD + * @param numEdgePartitions the number of partitions for the edge RDD + * Setting this value to -1 will use the default parallelism. * @param edgeStorageLevel the desired storage level for the edge partitions * @param vertexStorageLevel the desired storage level for the vertex partitions */ @@ -56,7 +57,7 @@ object GraphLoader extends Logging { sc: SparkContext, path: String, canonicalOrientation: Boolean = false, - minEdgePartitions: Int = 1, + numEdgePartitions: Int = -1, edgeStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY, vertexStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) : Graph[Int, Int] = @@ -64,7 +65,12 @@ object GraphLoader extends Logging { val startTime = System.currentTimeMillis // Parse the edge data table directly into edge partitions - val lines = sc.textFile(path, minEdgePartitions).coalesce(minEdgePartitions) + val lines = + if (numEdgePartitions > 0) { + sc.textFile(path, numEdgePartitions).coalesce(numEdgePartitions) + } else { + sc.textFile(path) + } val edges = lines.mapPartitionsWithIndex { (pid, iter) => val builder = new EdgePartitionBuilder[Int, Int] iter.foreach { line => diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index d0dd45dba618e..116d1ea700175 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -69,11 +69,12 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali */ private def degreesRDD(edgeDirection: EdgeDirection): VertexRDD[Int] = { if (edgeDirection == EdgeDirection.In) { - graph.mapReduceTriplets(et => Iterator((et.dstId,1)), _ + _) + graph.aggregateMessages(_.sendToDst(1), _ + _, TripletFields.None) } else if (edgeDirection == EdgeDirection.Out) { - graph.mapReduceTriplets(et => Iterator((et.srcId,1)), _ + _) + graph.aggregateMessages(_.sendToSrc(1), _ + _, TripletFields.None) } else { // EdgeDirection.Either - graph.mapReduceTriplets(et => Iterator((et.srcId,1), (et.dstId,1)), _ + _) + graph.aggregateMessages(ctx => { ctx.sendToSrc(1); ctx.sendToDst(1) }, _ + _, + TripletFields.None) } } @@ -88,18 +89,17 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexId]] = { val nbrs = if (edgeDirection == EdgeDirection.Either) { - graph.mapReduceTriplets[Array[VertexId]]( - mapFunc = et => Iterator((et.srcId, Array(et.dstId)), (et.dstId, Array(et.srcId))), - reduceFunc = _ ++ _ - ) + graph.aggregateMessages[Array[VertexId]]( + ctx => { ctx.sendToSrc(Array(ctx.dstId)); ctx.sendToDst(Array(ctx.srcId)) }, + _ ++ _, TripletFields.None) } else if 
(edgeDirection == EdgeDirection.Out) { - graph.mapReduceTriplets[Array[VertexId]]( - mapFunc = et => Iterator((et.srcId, Array(et.dstId))), - reduceFunc = _ ++ _) + graph.aggregateMessages[Array[VertexId]]( + ctx => ctx.sendToSrc(Array(ctx.dstId)), + _ ++ _, TripletFields.None) } else if (edgeDirection == EdgeDirection.In) { - graph.mapReduceTriplets[Array[VertexId]]( - mapFunc = et => Iterator((et.dstId, Array(et.srcId))), - reduceFunc = _ ++ _) + graph.aggregateMessages[Array[VertexId]]( + ctx => ctx.sendToDst(Array(ctx.srcId)), + _ ++ _, TripletFields.None) } else { throw new SparkException("It doesn't make sense to collect neighbor ids without a " + "direction. (EdgeDirection.Both is not supported; use EdgeDirection.Either instead.)") @@ -122,22 +122,27 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali * @return the vertex set of neighboring vertex attributes for each vertex */ def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexId, VD)]] = { - val nbrs = graph.mapReduceTriplets[Array[(VertexId,VD)]]( - edge => { - val msgToSrc = (edge.srcId, Array((edge.dstId, edge.dstAttr))) - val msgToDst = (edge.dstId, Array((edge.srcId, edge.srcAttr))) - edgeDirection match { - case EdgeDirection.Either => Iterator(msgToSrc, msgToDst) - case EdgeDirection.In => Iterator(msgToDst) - case EdgeDirection.Out => Iterator(msgToSrc) - case EdgeDirection.Both => - throw new SparkException("collectNeighbors does not support EdgeDirection.Both. Use" + - "EdgeDirection.Either instead.") - } - }, - (a, b) => a ++ b) - - graph.vertices.leftZipJoin(nbrs) { (vid, vdata, nbrsOpt) => + val nbrs = edgeDirection match { + case EdgeDirection.Either => + graph.aggregateMessages[Array[(VertexId,VD)]]( + ctx => { + ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))) + ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))) + }, + (a, b) => a ++ b, TripletFields.All) + case EdgeDirection.In => + graph.aggregateMessages[Array[(VertexId,VD)]]( + ctx => ctx.sendToDst(Array((ctx.srcId, ctx.srcAttr))), + (a, b) => a ++ b, TripletFields.Src) + case EdgeDirection.Out => + graph.aggregateMessages[Array[(VertexId,VD)]]( + ctx => ctx.sendToSrc(Array((ctx.dstId, ctx.dstAttr))), + (a, b) => a ++ b, TripletFields.Dst) + case EdgeDirection.Both => + throw new SparkException("collectEdges does not support EdgeDirection.Both. 
Use" + + "EdgeDirection.Either instead.") + } + graph.vertices.leftJoin(nbrs) { (vid, vdata, nbrsOpt) => nbrsOpt.getOrElse(Array.empty[(VertexId, VD)]) } } // end of collectNeighbor @@ -160,18 +165,20 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali def collectEdges(edgeDirection: EdgeDirection): VertexRDD[Array[Edge[ED]]] = { edgeDirection match { case EdgeDirection.Either => - graph.mapReduceTriplets[Array[Edge[ED]]]( - edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr))), - (edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))), - (a, b) => a ++ b) + graph.aggregateMessages[Array[Edge[ED]]]( + ctx => { + ctx.sendToSrc(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))) + ctx.sendToDst(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))) + }, + (a, b) => a ++ b, TripletFields.EdgeOnly) case EdgeDirection.In => - graph.mapReduceTriplets[Array[Edge[ED]]]( - edge => Iterator((edge.dstId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))), - (a, b) => a ++ b) + graph.aggregateMessages[Array[Edge[ED]]]( + ctx => ctx.sendToDst(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))), + (a, b) => a ++ b, TripletFields.EdgeOnly) case EdgeDirection.Out => - graph.mapReduceTriplets[Array[Edge[ED]]]( - edge => Iterator((edge.srcId, Array(new Edge(edge.srcId, edge.dstId, edge.attr)))), - (a, b) => a ++ b) + graph.aggregateMessages[Array[Edge[ED]]]( + ctx => ctx.sendToSrc(Array(new Edge(ctx.srcId, ctx.dstId, ctx.attr))), + (a, b) => a ++ b, TripletFields.EdgeOnly) case EdgeDirection.Both => throw new SparkException("collectEdges does not support EdgeDirection.Both. Use" + "EdgeDirection.Either instead.") diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala new file mode 100644 index 0000000000000..2cb07937eaa2a --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx + +import org.apache.spark.SparkConf + +import org.apache.spark.graphx.impl._ +import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap + +import org.apache.spark.util.collection.{OpenHashSet, BitSet} +import org.apache.spark.util.BoundedPriorityQueue + +object GraphXUtils { + /** + * Registers classes that GraphX uses with Kryo. 
+ */ + def registerKryoClasses(conf: SparkConf) { + conf.registerKryoClasses(Array( + classOf[Edge[Object]], + classOf[(VertexId, Object)], + classOf[EdgePartition[Object, Object]], + classOf[BitSet], + classOf[VertexIdToIndexMap], + classOf[VertexAttributeBlock[Object]], + classOf[PartitionStrategy], + classOf[BoundedPriorityQueue[Object]], + classOf[EdgeDirection], + classOf[GraphXPrimitiveKeyOpenHashMap[VertexId, Int]], + classOf[OpenHashSet[Int]], + classOf[OpenHashSet[Long]])) + } +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java b/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java new file mode 100644 index 0000000000000..7eb4ae0f44602 --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/TripletFields.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx; + +import java.io.Serializable; + +/** + * Represents a subset of the fields of an [[EdgeTriplet]] or [[EdgeContext]]. This allows the + * system to populate only those fields for efficiency. + */ +public class TripletFields implements Serializable { + + /** Indicates whether the source vertex attribute is included. */ + public final boolean useSrc; + + /** Indicates whether the destination vertex attribute is included. */ + public final boolean useDst; + + /** Indicates whether the edge attribute is included. */ + public final boolean useEdge; + + /** Constructs a default TripletFields in which all fields are included. */ + public TripletFields() { + this(true, true, true); + } + + public TripletFields(boolean useSrc, boolean useDst, boolean useEdge) { + this.useSrc = useSrc; + this.useDst = useDst; + this.useEdge = useEdge; + } + + /** + * None of the triplet fields are exposed. + */ + public static final TripletFields None = new TripletFields(false, false, false); + + /** + * Expose only the edge field and not the source or destination field. + */ + public static final TripletFields EdgeOnly = new TripletFields(false, false, true); + + /** + * Expose the source and edge fields but not the destination field. (Same as Src) + */ + public static final TripletFields Src = new TripletFields(true, false, true); + + /** + * Expose the destination and edge fields but not the source field. (Same as Dst) + */ + public static final TripletFields Dst = new TripletFields(false, true, true); + + /** + * Expose all the fields (source, edge, and destination). 
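// --- Editor's usage sketch (Scala caller side, not part of the patch): passing one of the
// --- TripletFields presets above to aggregateMessages so only the fields the send function
// --- actually reads are shipped to edge partitions. Mirrors the in-degree computation that
// --- GraphOps.degreesRDD switches to earlier in this patch.
import org.apache.spark.graphx._

def inDegrees[VD, ED](graph: Graph[VD, ED]): VertexRDD[Int] =
  graph.aggregateMessages[Int](
    ctx => ctx.sendToDst(1), // one message per incoming edge
    _ + _,                   // merge messages destined for the same vertex
    TripletFields.None)      // no vertex or edge attributes are needed for counting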
+ */ + public static final TripletFields All = new TripletFields(true, true, true); +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index 2c8b245955d12..1db3df03c8052 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -27,8 +27,7 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.graphx.impl.RoutingTablePartition import org.apache.spark.graphx.impl.ShippableVertexPartition import org.apache.spark.graphx.impl.VertexAttributeBlock -import org.apache.spark.graphx.impl.RoutingTableMessageRDDFunctions._ -import org.apache.spark.graphx.impl.VertexRDDFunctions._ +import org.apache.spark.graphx.impl.VertexRDDImpl /** * Extends `RDD[(VertexId, VD)]` by ensuring that there is only one entry for each vertex and by @@ -55,62 +54,16 @@ import org.apache.spark.graphx.impl.VertexRDDFunctions._ * * @tparam VD the vertex attribute associated with each vertex in the set. */ -class VertexRDD[@specialized VD: ClassTag]( - val partitionsRDD: RDD[ShippableVertexPartition[VD]], - val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) - extends RDD[(VertexId, VD)](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) { +abstract class VertexRDD[VD]( + @transient sc: SparkContext, + @transient deps: Seq[Dependency[_]]) extends RDD[(VertexId, VD)](sc, deps) { - require(partitionsRDD.partitioner.isDefined) + implicit protected def vdTag: ClassTag[VD] - /** - * Construct a new VertexRDD that is indexed by only the visible vertices. The resulting - * VertexRDD will be based on a different index and can no longer be quickly joined with this - * RDD. - */ - def reindex(): VertexRDD[VD] = this.withPartitionsRDD(partitionsRDD.map(_.reindex())) - - override val partitioner = partitionsRDD.partitioner + private[graphx] def partitionsRDD: RDD[ShippableVertexPartition[VD]] override protected def getPartitions: Array[Partition] = partitionsRDD.partitions - override protected def getPreferredLocations(s: Partition): Seq[String] = - partitionsRDD.preferredLocations(s) - - override def setName(_name: String): this.type = { - if (partitionsRDD.name != null) { - partitionsRDD.setName(partitionsRDD.name + ", " + _name) - } else { - partitionsRDD.setName(_name) - } - this - } - setName("VertexRDD") - - /** - * Persists the vertex partitions at the specified storage level, ignoring any existing target - * storage level. - */ - override def persist(newLevel: StorageLevel): this.type = { - partitionsRDD.persist(newLevel) - this - } - - override def unpersist(blocking: Boolean = true): this.type = { - partitionsRDD.unpersist(blocking) - this - } - - /** Persists the vertex partitions at `targetStorageLevel`, which defaults to MEMORY_ONLY. */ - override def cache(): this.type = { - partitionsRDD.persist(targetStorageLevel) - this - } - - /** The number of vertices in the RDD. */ - override def count(): Long = { - partitionsRDD.map(_.size.toLong).reduce(_ + _) - } - /** * Provides the `RDD[(VertexId, VD)]` equivalent output. */ @@ -118,22 +71,28 @@ class VertexRDD[@specialized VD: ClassTag]( firstParent[ShippableVertexPartition[VD]].iterator(part, context).next.iterator } + /** + * Construct a new VertexRDD that is indexed by only the visible vertices. The resulting + * VertexRDD will be based on a different index and can no longer be quickly joined with this + * RDD. 
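// --- Editor's usage sketch (not part of the patch): the companion object further down still
// --- hands callers a concrete vertex set (now a VertexRDDImpl) even though VertexRDD itself
// --- becomes an abstract interface here. `sc` is an assumed SparkContext.
import org.apache.spark.SparkContext
import org.apache.spark.graphx._

def vertexRddExample(sc: SparkContext): Unit = {
  val raw = sc.parallelize(Seq((1L, "alice"), (2L, "bob")))
  val verts: VertexRDD[String] = VertexRDD(raw)               // built via VertexRDDImpl internally
  val nameLengths: VertexRDD[Int] = verts.mapValues(_.length) // index-preserving transformation
  val ages = sc.parallelize(Seq((1L, 30)))
  // leftJoin upgrades to the faster leftZipJoin when the other RDD is itself a VertexRDD
  val joined: VertexRDD[(String, Int)] =
    verts.leftJoin(ages) { (id, name, ageOpt) => (name, ageOpt.getOrElse(-1)) }
}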
+ */ + def reindex(): VertexRDD[VD] + /** * Applies a function to each `VertexPartition` of this RDD and returns a new VertexRDD. */ private[graphx] def mapVertexPartitions[VD2: ClassTag]( f: ShippableVertexPartition[VD] => ShippableVertexPartition[VD2]) - : VertexRDD[VD2] = { - val newPartitionsRDD = partitionsRDD.mapPartitions(_.map(f), preservesPartitioning = true) - this.withPartitionsRDD(newPartitionsRDD) - } - + : VertexRDD[VD2] /** * Restricts the vertex set to the set of vertices satisfying the given predicate. This operation * preserves the index for efficient joins with the original RDD, and it sets bits in the bitmask * rather than allocating new memory. * + * It is declared and defined here to allow refining the return type from `RDD[(VertexId, VD)]` to + * `VertexRDD[VD]`. + * * @param pred the user defined predicate, which takes a tuple to conform to the * `RDD[(VertexId, VD)]` interface */ @@ -149,8 +108,7 @@ class VertexRDD[@specialized VD: ClassTag]( * @return a new VertexRDD with values obtained by applying `f` to each of the entries in the * original VertexRDD */ - def mapValues[VD2: ClassTag](f: VD => VD2): VertexRDD[VD2] = - this.mapVertexPartitions(_.map((vid, attr) => f(attr))) + def mapValues[VD2: ClassTag](f: VD => VD2): VertexRDD[VD2] /** * Maps each vertex attribute, additionally supplying the vertex ID. @@ -161,23 +119,13 @@ class VertexRDD[@specialized VD: ClassTag]( * @return a new VertexRDD with values obtained by applying `f` to each of the entries in the * original VertexRDD. The resulting VertexRDD retains the same index. */ - def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2] = - this.mapVertexPartitions(_.map(f)) + def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2] /** * Hides vertices that are the same between `this` and `other`; for vertices that are different, * keeps the values from `other`. */ - def diff(other: VertexRDD[VD]): VertexRDD[VD] = { - val newPartitionsRDD = partitionsRDD.zipPartitions( - other.partitionsRDD, preservesPartitioning = true - ) { (thisIter, otherIter) => - val thisPart = thisIter.next() - val otherPart = otherIter.next() - Iterator(thisPart.diff(otherPart)) - } - this.withPartitionsRDD(newPartitionsRDD) - } + def diff(other: VertexRDD[VD]): VertexRDD[VD] /** * Left joins this RDD with another VertexRDD with the same index. This function will fail if @@ -194,16 +142,7 @@ class VertexRDD[@specialized VD: ClassTag]( * @return a VertexRDD containing the results of `f` */ def leftZipJoin[VD2: ClassTag, VD3: ClassTag] - (other: VertexRDD[VD2])(f: (VertexId, VD, Option[VD2]) => VD3): VertexRDD[VD3] = { - val newPartitionsRDD = partitionsRDD.zipPartitions( - other.partitionsRDD, preservesPartitioning = true - ) { (thisIter, otherIter) => - val thisPart = thisIter.next() - val otherPart = otherIter.next() - Iterator(thisPart.leftJoin(otherPart)(f)) - } - this.withPartitionsRDD(newPartitionsRDD) - } + (other: VertexRDD[VD2])(f: (VertexId, VD, Option[VD2]) => VD3): VertexRDD[VD3] /** * Left joins this VertexRDD with an RDD containing vertex attribute pairs. If the other RDD is @@ -224,37 +163,14 @@ class VertexRDD[@specialized VD: ClassTag]( def leftJoin[VD2: ClassTag, VD3: ClassTag] (other: RDD[(VertexId, VD2)]) (f: (VertexId, VD, Option[VD2]) => VD3) - : VertexRDD[VD3] = { - // Test if the other vertex is a VertexRDD to choose the optimal join strategy. 
- // If the other set is a VertexRDD then we use the much more efficient leftZipJoin - other match { - case other: VertexRDD[_] => - leftZipJoin(other)(f) - case _ => - this.withPartitionsRDD[VD3]( - partitionsRDD.zipPartitions( - other.copartitionWithVertices(this.partitioner.get), preservesPartitioning = true) { - (partIter, msgs) => partIter.map(_.leftJoin(msgs)(f)) - } - ) - } - } + : VertexRDD[VD3] /** * Efficiently inner joins this VertexRDD with another VertexRDD sharing the same index. See * [[innerJoin]] for the behavior of the join. */ def innerZipJoin[U: ClassTag, VD2: ClassTag](other: VertexRDD[U]) - (f: (VertexId, VD, U) => VD2): VertexRDD[VD2] = { - val newPartitionsRDD = partitionsRDD.zipPartitions( - other.partitionsRDD, preservesPartitioning = true - ) { (thisIter, otherIter) => - val thisPart = thisIter.next() - val otherPart = otherIter.next() - Iterator(thisPart.innerJoin(otherPart)(f)) - } - this.withPartitionsRDD(newPartitionsRDD) - } + (f: (VertexId, VD, U) => VD2): VertexRDD[VD2] /** * Inner joins this VertexRDD with an RDD containing vertex attribute pairs. If the other RDD is @@ -268,21 +184,7 @@ class VertexRDD[@specialized VD: ClassTag]( * `this` and `other`, with values supplied by `f` */ def innerJoin[U: ClassTag, VD2: ClassTag](other: RDD[(VertexId, U)]) - (f: (VertexId, VD, U) => VD2): VertexRDD[VD2] = { - // Test if the other vertex is a VertexRDD to choose the optimal join strategy. - // If the other set is a VertexRDD then we use the much more efficient innerZipJoin - other match { - case other: VertexRDD[_] => - innerZipJoin(other)(f) - case _ => - this.withPartitionsRDD( - partitionsRDD.zipPartitions( - other.copartitionWithVertices(this.partitioner.get), preservesPartitioning = true) { - (partIter, msgs) => partIter.map(_.innerJoin(msgs)(f)) - } - ) - } - } + (f: (VertexId, VD, U) => VD2): VertexRDD[VD2] /** * Aggregates vertices in `messages` that have the same ids using `reduceFunc`, returning a @@ -296,38 +198,20 @@ class VertexRDD[@specialized VD: ClassTag]( * messages. */ def aggregateUsingIndex[VD2: ClassTag]( - messages: RDD[(VertexId, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2] = { - val shuffled = messages.copartitionWithVertices(this.partitioner.get) - val parts = partitionsRDD.zipPartitions(shuffled, true) { (thisIter, msgIter) => - thisIter.map(_.aggregateUsingIndex(msgIter, reduceFunc)) - } - this.withPartitionsRDD[VD2](parts) - } + messages: RDD[(VertexId, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2] /** * Returns a new `VertexRDD` reflecting a reversal of all edge directions in the corresponding * [[EdgeRDD]]. */ - def reverseRoutingTables(): VertexRDD[VD] = - this.mapVertexPartitions(vPart => vPart.withRoutingTable(vPart.routingTable.reverse)) + def reverseRoutingTables(): VertexRDD[VD] /** Prepares this VertexRDD for efficient joins with the given EdgeRDD. */ - def withEdges(edges: EdgeRDD[_, _]): VertexRDD[VD] = { - val routingTables = VertexRDD.createRoutingTables(edges, this.partitioner.get) - val vertexPartitions = partitionsRDD.zipPartitions(routingTables, true) { - (partIter, routingTableIter) => - val routingTable = - if (routingTableIter.hasNext) routingTableIter.next() else RoutingTablePartition.empty - partIter.map(_.withRoutingTable(routingTable)) - } - this.withPartitionsRDD(vertexPartitions) - } + def withEdges(edges: EdgeRDD[_]): VertexRDD[VD] /** Replaces the vertex partitions while preserving all other properties of the VertexRDD. 
*/ private[graphx] def withPartitionsRDD[VD2: ClassTag]( - partitionsRDD: RDD[ShippableVertexPartition[VD2]]): VertexRDD[VD2] = { - new VertexRDD(partitionsRDD, this.targetStorageLevel) - } + partitionsRDD: RDD[ShippableVertexPartition[VD2]]): VertexRDD[VD2] /** * Changes the target storage level while preserving all other properties of the @@ -337,20 +221,14 @@ class VertexRDD[@specialized VD: ClassTag]( * [[org.apache.spark.graphx.VertexRDD#cache]] on the returned VertexRDD. */ private[graphx] def withTargetStorageLevel( - targetStorageLevel: StorageLevel): VertexRDD[VD] = { - new VertexRDD(this.partitionsRDD, targetStorageLevel) - } + targetStorageLevel: StorageLevel): VertexRDD[VD] /** Generates an RDD of vertex attributes suitable for shipping to the edge partitions. */ private[graphx] def shipVertexAttributes( - shipSrc: Boolean, shipDst: Boolean): RDD[(PartitionID, VertexAttributeBlock[VD])] = { - partitionsRDD.mapPartitions(_.flatMap(_.shipVertexAttributes(shipSrc, shipDst))) - } + shipSrc: Boolean, shipDst: Boolean): RDD[(PartitionID, VertexAttributeBlock[VD])] /** Generates an RDD of vertex IDs suitable for shipping to the edge partitions. */ - private[graphx] def shipVertexIds(): RDD[(PartitionID, Array[VertexId])] = { - partitionsRDD.mapPartitions(_.flatMap(_.shipVertexIds())) - } + private[graphx] def shipVertexIds(): RDD[(PartitionID, Array[VertexId])] } // end of VertexRDD @@ -371,12 +249,12 @@ object VertexRDD { def apply[VD: ClassTag](vertices: RDD[(VertexId, VD)]): VertexRDD[VD] = { val vPartitioned: RDD[(VertexId, VD)] = vertices.partitioner match { case Some(p) => vertices - case None => vertices.copartitionWithVertices(new HashPartitioner(vertices.partitions.size)) + case None => vertices.partitionBy(new HashPartitioner(vertices.partitions.size)) } val vertexPartitions = vPartitioned.mapPartitions( iter => Iterator(ShippableVertexPartition(iter)), preservesPartitioning = true) - new VertexRDD(vertexPartitions) + new VertexRDDImpl(vertexPartitions) } /** @@ -391,7 +269,7 @@ object VertexRDD { * @param defaultVal the vertex attribute to use when creating missing vertices */ def apply[VD: ClassTag]( - vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_, _], defaultVal: VD): VertexRDD[VD] = { + vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_], defaultVal: VD): VertexRDD[VD] = { VertexRDD(vertices, edges, defaultVal, (a, b) => a) } @@ -408,11 +286,11 @@ object VertexRDD { * @param mergeFunc the commutative, associative duplicate vertex attribute merge function */ def apply[VD: ClassTag]( - vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_, _], defaultVal: VD, mergeFunc: (VD, VD) => VD + vertices: RDD[(VertexId, VD)], edges: EdgeRDD[_], defaultVal: VD, mergeFunc: (VD, VD) => VD ): VertexRDD[VD] = { val vPartitioned: RDD[(VertexId, VD)] = vertices.partitioner match { case Some(p) => vertices - case None => vertices.copartitionWithVertices(new HashPartitioner(vertices.partitions.size)) + case None => vertices.partitionBy(new HashPartitioner(vertices.partitions.size)) } val routingTables = createRoutingTables(edges, vPartitioned.partitioner.get) val vertexPartitions = vPartitioned.zipPartitions(routingTables, preservesPartitioning = true) { @@ -421,7 +299,7 @@ object VertexRDD { if (routingTableIter.hasNext) routingTableIter.next() else RoutingTablePartition.empty Iterator(ShippableVertexPartition(vertexIter, routingTable, defaultVal, mergeFunc)) } - new VertexRDD(vertexPartitions) + new VertexRDDImpl(vertexPartitions) } /** @@ -436,25 +314,25 @@ object VertexRDD { * @param 
defaultVal the vertex attribute to use when creating missing vertices */ def fromEdges[VD: ClassTag]( - edges: EdgeRDD[_, _], numPartitions: Int, defaultVal: VD): VertexRDD[VD] = { + edges: EdgeRDD[_], numPartitions: Int, defaultVal: VD): VertexRDD[VD] = { val routingTables = createRoutingTables(edges, new HashPartitioner(numPartitions)) val vertexPartitions = routingTables.mapPartitions({ routingTableIter => val routingTable = if (routingTableIter.hasNext) routingTableIter.next() else RoutingTablePartition.empty Iterator(ShippableVertexPartition(Iterator.empty, routingTable, defaultVal)) }, preservesPartitioning = true) - new VertexRDD(vertexPartitions) + new VertexRDDImpl(vertexPartitions) } - private def createRoutingTables( - edges: EdgeRDD[_, _], vertexPartitioner: Partitioner): RDD[RoutingTablePartition] = { + private[graphx] def createRoutingTables( + edges: EdgeRDD[_], vertexPartitioner: Partitioner): RDD[RoutingTablePartition] = { // Determine which vertices each edge partition needs by creating a mapping from vid to pid. val vid2pid = edges.partitionsRDD.mapPartitions(_.flatMap( Function.tupled(RoutingTablePartition.edgePartitionToMsgs))) .setName("VertexRDD.createRoutingTables - vid2pid (aggregation)") val numEdgePartitions = edges.partitions.size - vid2pid.copartitionWithVertices(vertexPartitioner).mapPartitions( + vid2pid.partitionBy(vertexPartitioner).mapPartitions( iter => Iterator(RoutingTablePartition.fromMsgs(numEdgePartitions, iter)), preservesPartitioning = true) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeActiveness.java b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeActiveness.java new file mode 100644 index 0000000000000..377ae849f045c --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeActiveness.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx.impl; + +/** + * Criteria for filtering edges based on activeness. For internal use only. + */ +public enum EdgeActiveness { + /** Neither the source vertex nor the destination vertex need be active. */ + Neither, + /** The source vertex must be active. */ + SrcOnly, + /** The destination vertex must be active. */ + DstOnly, + /** Both vertices must be active. */ + Both, + /** At least one vertex must be active. 
*/ + Either +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index a5c9cd1f8b4e6..373af75448374 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -21,63 +21,94 @@ import scala.reflect.{classTag, ClassTag} import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap +import org.apache.spark.util.collection.BitSet /** - * A collection of edges stored in columnar format, along with any vertex attributes referenced. The - * edges are stored in 3 large columnar arrays (src, dst, attribute). The arrays are clustered by - * src. There is an optional active vertex set for filtering computation on the edges. + * A collection of edges, along with referenced vertex attributes and an optional active vertex set + * for filtering computation on the edges. + * + * The edges are stored in columnar format in `localSrcIds`, `localDstIds`, and `data`. All + * referenced global vertex ids are mapped to a compact set of local vertex ids according to the + * `global2local` map. Each local vertex id is a valid index into `vertexAttrs`, which stores the + * corresponding vertex attribute, and `local2global`, which stores the reverse mapping to global + * vertex id. The global vertex ids that are active are optionally stored in `activeSet`. + * + * The edges are clustered by source vertex id, and the mapping from global vertex id to the index + * of the corresponding edge cluster is stored in `index`. * * @tparam ED the edge attribute type * @tparam VD the vertex attribute type * - * @param srcIds the source vertex id of each edge - * @param dstIds the destination vertex id of each edge + * @param localSrcIds the local source vertex id of each edge as an index into `local2global` and + * `vertexAttrs` + * @param localDstIds the local destination vertex id of each edge as an index into `local2global` + * and `vertexAttrs` * @param data the attribute associated with each edge - * @param index a clustered index on source vertex id - * @param vertices a map from referenced vertex ids to their corresponding attributes. Must - * contain all vertex ids from `srcIds` and `dstIds`, though not necessarily valid attributes for - * those vertex ids. The mask is not used. 
+ * @param index a clustered index on source vertex id as a map from each global source vertex id to + * the offset in the edge arrays where the cluster for that vertex id begins + * @param global2local a map from referenced vertex ids to local ids which index into vertexAttrs + * @param local2global an array of global vertex ids where the offsets are local vertex ids + * @param vertexAttrs an array of vertex attributes where the offsets are local vertex ids * @param activeSet an optional active vertex set for filtering computation on the edges */ private[graphx] class EdgePartition[ @specialized(Char, Int, Boolean, Byte, Long, Float, Double) ED: ClassTag, VD: ClassTag]( - val srcIds: Array[VertexId] = null, - val dstIds: Array[VertexId] = null, - val data: Array[ED] = null, - val index: GraphXPrimitiveKeyOpenHashMap[VertexId, Int] = null, - val vertices: VertexPartition[VD] = null, - val activeSet: Option[VertexSet] = None - ) extends Serializable { + localSrcIds: Array[Int], + localDstIds: Array[Int], + data: Array[ED], + index: GraphXPrimitiveKeyOpenHashMap[VertexId, Int], + global2local: GraphXPrimitiveKeyOpenHashMap[VertexId, Int], + local2global: Array[VertexId], + vertexAttrs: Array[VD], + activeSet: Option[VertexSet]) + extends Serializable { - /** Return a new `EdgePartition` with the specified edge data. */ - def withData[ED2: ClassTag](data_ : Array[ED2]): EdgePartition[ED2, VD] = { - new EdgePartition(srcIds, dstIds, data_, index, vertices, activeSet) - } + /** No-arg constructor for serialization. */ + private def this() = this(null, null, null, null, null, null, null, null) - /** Return a new `EdgePartition` with the specified vertex partition. */ - def withVertices[VD2: ClassTag]( - vertices_ : VertexPartition[VD2]): EdgePartition[ED, VD2] = { - new EdgePartition(srcIds, dstIds, data, index, vertices_, activeSet) + /** Return a new `EdgePartition` with the specified edge data. */ + def withData[ED2: ClassTag](data: Array[ED2]): EdgePartition[ED2, VD] = { + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, activeSet) } /** Return a new `EdgePartition` with the specified active set, provided as an iterator. */ def withActiveSet(iter: Iterator[VertexId]): EdgePartition[ED, VD] = { - val newActiveSet = new VertexSet - iter.foreach(newActiveSet.add(_)) - new EdgePartition(srcIds, dstIds, data, index, vertices, Some(newActiveSet)) - } - - /** Return a new `EdgePartition` with the specified active set. */ - def withActiveSet(activeSet_ : Option[VertexSet]): EdgePartition[ED, VD] = { - new EdgePartition(srcIds, dstIds, data, index, vertices, activeSet_) + val activeSet = new VertexSet + while (iter.hasNext) { activeSet.add(iter.next()) } + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, + Some(activeSet)) } /** Return a new `EdgePartition` with updates to vertex attributes specified in `iter`. */ def updateVertices(iter: Iterator[(VertexId, VD)]): EdgePartition[ED, VD] = { - this.withVertices(vertices.innerJoinKeepLeft(iter)) + val newVertexAttrs = new Array[VD](vertexAttrs.length) + System.arraycopy(vertexAttrs, 0, newVertexAttrs, 0, vertexAttrs.length) + while (iter.hasNext) { + val kv = iter.next() + newVertexAttrs(global2local(kv._1)) = kv._2 + } + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global, newVertexAttrs, + activeSet) } + /** Return a new `EdgePartition` without any locally cached vertex attributes. 
*/ + def withoutVertexAttributes[VD2: ClassTag](): EdgePartition[ED, VD2] = { + val newVertexAttrs = new Array[VD2](vertexAttrs.length) + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global, newVertexAttrs, + activeSet) + } + + @inline private def srcIds(pos: Int): VertexId = local2global(localSrcIds(pos)) + + @inline private def dstIds(pos: Int): VertexId = local2global(localDstIds(pos)) + + @inline private def attrs(pos: Int): ED = data(pos) + /** Look up vid in activeSet, throwing an exception if it is None. */ def isActive(vid: VertexId): Boolean = { activeSet.get.contains(vid) @@ -92,11 +123,19 @@ class EdgePartition[ * @return a new edge partition with all edges reversed. */ def reverse: EdgePartition[ED, VD] = { - val builder = new EdgePartitionBuilder(size)(classTag[ED], classTag[VD]) - for (e <- iterator) { - builder.add(e.dstId, e.srcId, e.attr) + val builder = new ExistingEdgePartitionBuilder[ED, VD]( + global2local, local2global, vertexAttrs, activeSet, size) + var i = 0 + while (i < size) { + val localSrcId = localSrcIds(i) + val localDstId = localDstIds(i) + val srcId = local2global(localSrcId) + val dstId = local2global(localDstId) + val attr = data(i) + builder.add(dstId, srcId, localDstId, localSrcId, attr) + i += 1 } - builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet) + builder.toEdgePartition } /** @@ -157,13 +196,25 @@ class EdgePartition[ def filter( epred: EdgeTriplet[VD, ED] => Boolean, vpred: (VertexId, VD) => Boolean): EdgePartition[ED, VD] = { - val filtered = tripletIterator().filter(et => - vpred(et.srcId, et.srcAttr) && vpred(et.dstId, et.dstAttr) && epred(et)) - val builder = new EdgePartitionBuilder[ED, VD] - for (e <- filtered) { - builder.add(e.srcId, e.dstId, e.attr) + val builder = new ExistingEdgePartitionBuilder[ED, VD]( + global2local, local2global, vertexAttrs, activeSet) + var i = 0 + while (i < size) { + // The user sees the EdgeTriplet, so we can't reuse it and must create one per edge. 
+ val localSrcId = localSrcIds(i) + val localDstId = localDstIds(i) + val et = new EdgeTriplet[VD, ED] + et.srcId = local2global(localSrcId) + et.dstId = local2global(localDstId) + et.srcAttr = vertexAttrs(localSrcId) + et.dstAttr = vertexAttrs(localDstId) + et.attr = data(i) + if (vpred(et.srcId, et.srcAttr) && vpred(et.dstId, et.dstAttr) && epred(et)) { + builder.add(et.srcId, et.dstId, localSrcId, localDstId, et.attr) + } + i += 1 } - builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet) + builder.toEdgePartition } /** @@ -183,28 +234,40 @@ class EdgePartition[ * @return a new edge partition without duplicate edges */ def groupEdges(merge: (ED, ED) => ED): EdgePartition[ED, VD] = { - val builder = new EdgePartitionBuilder[ED, VD] + val builder = new ExistingEdgePartitionBuilder[ED, VD]( + global2local, local2global, vertexAttrs, activeSet) var currSrcId: VertexId = null.asInstanceOf[VertexId] var currDstId: VertexId = null.asInstanceOf[VertexId] + var currLocalSrcId = -1 + var currLocalDstId = -1 var currAttr: ED = null.asInstanceOf[ED] + // Iterate through the edges, accumulating runs of identical edges using the curr* variables and + // releasing them to the builder when we see the beginning of the next run var i = 0 while (i < size) { if (i > 0 && currSrcId == srcIds(i) && currDstId == dstIds(i)) { + // This edge should be accumulated into the existing run currAttr = merge(currAttr, data(i)) } else { + // This edge starts a new run of edges if (i > 0) { - builder.add(currSrcId, currDstId, currAttr) + // First release the existing run to the builder + builder.add(currSrcId, currDstId, currLocalSrcId, currLocalDstId, currAttr) } + // Then start accumulating for a new run currSrcId = srcIds(i) currDstId = dstIds(i) + currLocalSrcId = localSrcIds(i) + currLocalDstId = localDstIds(i) currAttr = data(i) } i += 1 } + // Finally, release the last accumulated run if (size > 0) { - builder.add(currSrcId, currDstId, currAttr) + builder.add(currSrcId, currDstId, currLocalSrcId, currLocalDstId, currAttr) } - builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet) + builder.toEdgePartition } /** @@ -220,7 +283,8 @@ class EdgePartition[ def innerJoin[ED2: ClassTag, ED3: ClassTag] (other: EdgePartition[ED2, _]) (f: (VertexId, VertexId, ED, ED2) => ED3): EdgePartition[ED3, VD] = { - val builder = new EdgePartitionBuilder[ED3, VD] + val builder = new ExistingEdgePartitionBuilder[ED3, VD]( + global2local, local2global, vertexAttrs, activeSet) var i = 0 var j = 0 // For i = index of each edge in `this`... @@ -233,12 +297,13 @@ class EdgePartition[ while (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) < dstId) { j += 1 } if (j < other.size && other.srcIds(j) == srcId && other.dstIds(j) == dstId) { // ... run `f` on the matching edge - builder.add(srcId, dstId, f(srcId, dstId, this.data(i), other.data(j))) + builder.add(srcId, dstId, localSrcIds(i), localDstIds(i), + f(srcId, dstId, this.data(i), other.attrs(j))) } } i += 1 } - builder.toEdgePartition.withVertices(vertices).withActiveSet(activeSet) + builder.toEdgePartition } /** @@ -246,7 +311,7 @@ class EdgePartition[ * * @return size of the partition */ - val size: Int = srcIds.size + val size: Int = localSrcIds.size /** The number of unique source vertices in the partition. */ def indexSize: Int = index.size @@ -280,55 +345,198 @@ class EdgePartition[ * It is safe to keep references to the objects from this iterator. 
*/ def tripletIterator( - includeSrc: Boolean = true, includeDst: Boolean = true): Iterator[EdgeTriplet[VD, ED]] = { - new EdgeTripletIterator(this, includeSrc, includeDst) + includeSrc: Boolean = true, includeDst: Boolean = true) + : Iterator[EdgeTriplet[VD, ED]] = new Iterator[EdgeTriplet[VD, ED]] { + private[this] var pos = 0 + + override def hasNext: Boolean = pos < EdgePartition.this.size + + override def next() = { + val triplet = new EdgeTriplet[VD, ED] + val localSrcId = localSrcIds(pos) + val localDstId = localDstIds(pos) + triplet.srcId = local2global(localSrcId) + triplet.dstId = local2global(localDstId) + if (includeSrc) { + triplet.srcAttr = vertexAttrs(localSrcId) + } + if (includeDst) { + triplet.dstAttr = vertexAttrs(localDstId) + } + triplet.attr = data(pos) + pos += 1 + triplet + } } /** - * Upgrade the given edge iterator into a triplet iterator. + * Send messages along edges and aggregate them at the receiving vertices. Implemented by scanning + * all edges sequentially. * - * Be careful not to keep references to the objects from this iterator. To improve GC performance - * the same object is re-used in `next()`. + * @param sendMsg generates messages to neighboring vertices of an edge + * @param mergeMsg the combiner applied to messages destined to the same vertex + * @param tripletFields which triplet fields `sendMsg` uses + * @param activeness criteria for filtering edges based on activeness + * + * @return iterator aggregated messages keyed by the receiving vertex id */ - def upgradeIterator( - edgeIter: Iterator[Edge[ED]], includeSrc: Boolean = true, includeDst: Boolean = true) - : Iterator[EdgeTriplet[VD, ED]] = { - new ReusingEdgeTripletIterator(edgeIter, this, includeSrc, includeDst) + def aggregateMessagesEdgeScan[A: ClassTag]( + sendMsg: EdgeContext[VD, ED, A] => Unit, + mergeMsg: (A, A) => A, + tripletFields: TripletFields, + activeness: EdgeActiveness): Iterator[(VertexId, A)] = { + val aggregates = new Array[A](vertexAttrs.length) + val bitset = new BitSet(vertexAttrs.length) + + var ctx = new AggregatingEdgeContext[VD, ED, A](mergeMsg, aggregates, bitset) + var i = 0 + while (i < size) { + val localSrcId = localSrcIds(i) + val srcId = local2global(localSrcId) + val localDstId = localDstIds(i) + val dstId = local2global(localDstId) + val edgeIsActive = + if (activeness == EdgeActiveness.Neither) true + else if (activeness == EdgeActiveness.SrcOnly) isActive(srcId) + else if (activeness == EdgeActiveness.DstOnly) isActive(dstId) + else if (activeness == EdgeActiveness.Both) isActive(srcId) && isActive(dstId) + else if (activeness == EdgeActiveness.Either) isActive(srcId) || isActive(dstId) + else throw new Exception("unreachable") + if (edgeIsActive) { + val srcAttr = if (tripletFields.useSrc) vertexAttrs(localSrcId) else null.asInstanceOf[VD] + val dstAttr = if (tripletFields.useDst) vertexAttrs(localDstId) else null.asInstanceOf[VD] + ctx.set(srcId, dstId, localSrcId, localDstId, srcAttr, dstAttr, data(i)) + sendMsg(ctx) + } + i += 1 + } + + bitset.iterator.map { localId => (local2global(localId), aggregates(localId)) } } /** - * Get an iterator over the edges in this partition whose source vertex ids match srcIdPred. The - * iterator is generated using an index scan, so it is efficient at skipping edges that don't - * match srcIdPred. + * Send messages along edges and aggregate them at the receiving vertices. Implemented by + * filtering the source vertex index, then scanning each edge cluster. 
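// --- Editor's simplified model (not part of the patch) of the aggregation buffer used by
// --- aggregateMessagesEdgeScan above: messages are merged into an array indexed by local
// --- vertex id, with a presence flag per slot. A Boolean array stands in for Spark's BitSet
// --- so the sketch stays self-contained.
import scala.reflect.ClassTag

class LocalAggregator[A: ClassTag](numLocalVertices: Int, mergeMsg: (A, A) => A) {
  private val aggregates = new Array[A](numLocalVertices)
  private val present = new Array[Boolean](numLocalVertices)

  // Analogous to AggregatingEdgeContext.send: merge if a value is already present.
  def send(localId: Int, msg: A): Unit = {
    if (present(localId)) {
      aggregates(localId) = mergeMsg(aggregates(localId), msg)
    } else {
      aggregates(localId) = msg
      present(localId) = true
    }
  }

  // Like bitset.iterator.map { localId => ... }: emit only the populated slots.
  def result: Iterator[(Int, A)] =
    Iterator.range(0, numLocalVertices).filter(i => present(i)).map(i => (i, aggregates(i)))
}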
* - * Be careful not to keep references to the objects from this iterator. To improve GC performance - * the same object is re-used in `next()`. - */ - def indexIterator(srcIdPred: VertexId => Boolean): Iterator[Edge[ED]] = - index.iterator.filter(kv => srcIdPred(kv._1)).flatMap(Function.tupled(clusterIterator)) - - /** - * Get an iterator over the cluster of edges in this partition with source vertex id `srcId`. The - * cluster must start at position `index`. + * @param sendMsg generates messages to neighboring vertices of an edge + * @param mergeMsg the combiner applied to messages destined to the same vertex + * @param tripletFields which triplet fields `sendMsg` uses + * @param activeness criteria for filtering edges based on activeness * - * Be careful not to keep references to the objects from this iterator. To improve GC performance - * the same object is re-used in `next()`. + * @return iterator aggregated messages keyed by the receiving vertex id */ - private def clusterIterator(srcId: VertexId, index: Int) = new Iterator[Edge[ED]] { - private[this] val edge = new Edge[ED] - private[this] var pos = index + def aggregateMessagesIndexScan[A: ClassTag]( + sendMsg: EdgeContext[VD, ED, A] => Unit, + mergeMsg: (A, A) => A, + tripletFields: TripletFields, + activeness: EdgeActiveness): Iterator[(VertexId, A)] = { + val aggregates = new Array[A](vertexAttrs.length) + val bitset = new BitSet(vertexAttrs.length) + + var ctx = new AggregatingEdgeContext[VD, ED, A](mergeMsg, aggregates, bitset) + index.iterator.foreach { cluster => + val clusterSrcId = cluster._1 + val clusterPos = cluster._2 + val clusterLocalSrcId = localSrcIds(clusterPos) - override def hasNext: Boolean = { - pos >= 0 && pos < EdgePartition.this.size && srcIds(pos) == srcId + val scanCluster = + if (activeness == EdgeActiveness.Neither) true + else if (activeness == EdgeActiveness.SrcOnly) isActive(clusterSrcId) + else if (activeness == EdgeActiveness.DstOnly) true + else if (activeness == EdgeActiveness.Both) isActive(clusterSrcId) + else if (activeness == EdgeActiveness.Either) true + else throw new Exception("unreachable") + + if (scanCluster) { + var pos = clusterPos + val srcAttr = + if (tripletFields.useSrc) vertexAttrs(clusterLocalSrcId) else null.asInstanceOf[VD] + ctx.setSrcOnly(clusterSrcId, clusterLocalSrcId, srcAttr) + while (pos < size && localSrcIds(pos) == clusterLocalSrcId) { + val localDstId = localDstIds(pos) + val dstId = local2global(localDstId) + val edgeIsActive = + if (activeness == EdgeActiveness.Neither) true + else if (activeness == EdgeActiveness.SrcOnly) true + else if (activeness == EdgeActiveness.DstOnly) isActive(dstId) + else if (activeness == EdgeActiveness.Both) isActive(dstId) + else if (activeness == EdgeActiveness.Either) isActive(clusterSrcId) || isActive(dstId) + else throw new Exception("unreachable") + if (edgeIsActive) { + val dstAttr = + if (tripletFields.useDst) vertexAttrs(localDstId) else null.asInstanceOf[VD] + ctx.setRest(dstId, localDstId, dstAttr, data(pos)) + sendMsg(ctx) + } + pos += 1 + } + } } - override def next(): Edge[ED] = { - assert(srcIds(pos) == srcId) - edge.srcId = srcIds(pos) - edge.dstId = dstIds(pos) - edge.attr = data(pos) - pos += 1 - edge + bitset.iterator.map { localId => (local2global(localId), aggregates(localId)) } + } +} + +private class AggregatingEdgeContext[VD, ED, A]( + mergeMsg: (A, A) => A, + aggregates: Array[A], + bitset: BitSet) + extends EdgeContext[VD, ED, A] { + + private[this] var _srcId: VertexId = _ + private[this] var _dstId: VertexId 
= _ + private[this] var _localSrcId: Int = _ + private[this] var _localDstId: Int = _ + private[this] var _srcAttr: VD = _ + private[this] var _dstAttr: VD = _ + private[this] var _attr: ED = _ + + def set( + srcId: VertexId, dstId: VertexId, + localSrcId: Int, localDstId: Int, + srcAttr: VD, dstAttr: VD, + attr: ED) { + _srcId = srcId + _dstId = dstId + _localSrcId = localSrcId + _localDstId = localDstId + _srcAttr = srcAttr + _dstAttr = dstAttr + _attr = attr + } + + def setSrcOnly(srcId: VertexId, localSrcId: Int, srcAttr: VD) { + _srcId = srcId + _localSrcId = localSrcId + _srcAttr = srcAttr + } + + def setRest(dstId: VertexId, localDstId: Int, dstAttr: VD, attr: ED) { + _dstId = dstId + _localDstId = localDstId + _dstAttr = dstAttr + _attr = attr + } + + override def srcId = _srcId + override def dstId = _dstId + override def srcAttr = _srcAttr + override def dstAttr = _dstAttr + override def attr = _attr + + override def sendToSrc(msg: A) { + send(_localSrcId, msg) + } + override def sendToDst(msg: A) { + send(_localDstId, msg) + } + + @inline private def send(localId: Int, msg: A) { + if (bitset.get(localId)) { + aggregates(localId) = mergeMsg(aggregates(localId), msg) + } else { + aggregates(localId) = msg + bitset.set(localId) } } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala index 4520beb991515..b0cb0fe47d461 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala @@ -25,10 +25,11 @@ import org.apache.spark.util.collection.{BitSet, OpenHashSet, PrimitiveVector} import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap +/** Constructs an EdgePartition from scratch. */ private[graphx] class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: ClassTag]( size: Int = 64) { - var edges = new PrimitiveVector[Edge[ED]](size) + private[this] val edges = new PrimitiveVector[Edge[ED]](size) /** Add a new edge to the partition. */ def add(src: VertexId, dst: VertexId, d: ED) { @@ -38,19 +39,78 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla def toEdgePartition: EdgePartition[ED, VD] = { val edgeArray = edges.trim().array Sorting.quickSort(edgeArray)(Edge.lexicographicOrdering) - val srcIds = new Array[VertexId](edgeArray.size) - val dstIds = new Array[VertexId](edgeArray.size) + val localSrcIds = new Array[Int](edgeArray.size) + val localDstIds = new Array[Int](edgeArray.size) + val data = new Array[ED](edgeArray.size) + val index = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int] + val global2local = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int] + val local2global = new PrimitiveVector[VertexId] + var vertexAttrs = Array.empty[VD] + // Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and + // adding them to the index. Also populate a map from vertex id to a sequential local offset. 
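// --- Editor's simplified model (not part of the patch) of the global-to-local id assignment
// --- performed by the loop below: each distinct global vertex id receives the next local
// --- offset the first time it is seen, and local2global records the reverse mapping. Plain
// --- Scala collections stand in for GraphXPrimitiveKeyOpenHashMap and PrimitiveVector.
import scala.collection.mutable

def assignLocalIds(edges: Seq[(Long, Long)]): (Array[Int], Array[Int], Array[Long]) = {
  val global2local = mutable.HashMap.empty[Long, Int]
  val local2global = mutable.ArrayBuffer.empty[Long]
  def localId(vid: Long): Int = global2local.getOrElseUpdate(vid, {
    local2global += vid
    local2global.size - 1
  })
  val localSrcIds = new Array[Int](edges.size)
  val localDstIds = new Array[Int](edges.size)
  for (((src, dst), i) <- edges.zipWithIndex) {
    localSrcIds(i) = localId(src)
    localDstIds(i) = localId(dst)
  }
  (localSrcIds, localDstIds, local2global.toArray)
}
// Example: assignLocalIds(Seq((10L, 20L), (10L, 30L))) yields
// localSrcIds = Array(0, 0), localDstIds = Array(1, 2), local2global = Array(10, 20, 30).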
+ if (edgeArray.length > 0) { + index.update(edgeArray(0).srcId, 0) + var currSrcId: VertexId = edgeArray(0).srcId + var currLocalId = -1 + var i = 0 + while (i < edgeArray.size) { + val srcId = edgeArray(i).srcId + val dstId = edgeArray(i).dstId + localSrcIds(i) = global2local.changeValue(srcId, + { currLocalId += 1; local2global += srcId; currLocalId }, identity) + localDstIds(i) = global2local.changeValue(dstId, + { currLocalId += 1; local2global += dstId; currLocalId }, identity) + data(i) = edgeArray(i).attr + if (srcId != currSrcId) { + currSrcId = srcId + index.update(currSrcId, i) + } + + i += 1 + } + vertexAttrs = new Array[VD](currLocalId + 1) + } + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global.trim().array, vertexAttrs, + None) + } +} + +/** + * Constructs an EdgePartition from an existing EdgePartition with the same vertex set. This enables + * reuse of the local vertex ids. Intended for internal use in EdgePartition only. + */ +private[impl] +class ExistingEdgePartitionBuilder[ + @specialized(Long, Int, Double) ED: ClassTag, VD: ClassTag]( + global2local: GraphXPrimitiveKeyOpenHashMap[VertexId, Int], + local2global: Array[VertexId], + vertexAttrs: Array[VD], + activeSet: Option[VertexSet], + size: Int = 64) { + private[this] val edges = new PrimitiveVector[EdgeWithLocalIds[ED]](size) + + /** Add a new edge to the partition. */ + def add(src: VertexId, dst: VertexId, localSrc: Int, localDst: Int, d: ED) { + edges += EdgeWithLocalIds(src, dst, localSrc, localDst, d) + } + + def toEdgePartition: EdgePartition[ED, VD] = { + val edgeArray = edges.trim().array + Sorting.quickSort(edgeArray)(EdgeWithLocalIds.lexicographicOrdering) + val localSrcIds = new Array[Int](edgeArray.size) + val localDstIds = new Array[Int](edgeArray.size) val data = new Array[ED](edgeArray.size) val index = new GraphXPrimitiveKeyOpenHashMap[VertexId, Int] // Copy edges into columnar structures, tracking the beginnings of source vertex id clusters and // adding them to the index if (edgeArray.length > 0) { - index.update(srcIds(0), 0) - var currSrcId: VertexId = srcIds(0) + index.update(edgeArray(0).srcId, 0) + var currSrcId: VertexId = edgeArray(0).srcId var i = 0 while (i < edgeArray.size) { - srcIds(i) = edgeArray(i).srcId - dstIds(i) = edgeArray(i).dstId + localSrcIds(i) = edgeArray(i).localSrcId + localDstIds(i) = edgeArray(i).localDstId data(i) = edgeArray(i).attr if (edgeArray(i).srcId != currSrcId) { currSrcId = edgeArray(i).srcId @@ -60,13 +120,24 @@ class EdgePartitionBuilder[@specialized(Long, Int, Double) ED: ClassTag, VD: Cla } } - // Create and populate a VertexPartition with vids from the edges, but no attributes - val vidsIter = srcIds.iterator ++ dstIds.iterator - val vertexIds = new OpenHashSet[VertexId] - vidsIter.foreach(vid => vertexIds.add(vid)) - val vertices = new VertexPartition( - vertexIds, new Array[VD](vertexIds.capacity), vertexIds.getBitSet) + new EdgePartition( + localSrcIds, localDstIds, data, index, global2local, local2global, vertexAttrs, activeSet) + } +} - new EdgePartition(srcIds, dstIds, data, index, vertices) +private[impl] case class EdgeWithLocalIds[@specialized ED]( + srcId: VertexId, dstId: VertexId, localSrcId: Int, localDstId: Int, attr: ED) + +private[impl] object EdgeWithLocalIds { + implicit def lexicographicOrdering[ED] = new Ordering[EdgeWithLocalIds[ED]] { + override def compare(a: EdgeWithLocalIds[ED], b: EdgeWithLocalIds[ED]): Int = { + if (a.srcId == b.srcId) { + if (a.dstId == b.dstId) 0 + else if (a.dstId < 
b.dstId) -1 + else 1 + } else if (a.srcId < b.srcId) -1 + else 1 + } } + } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala new file mode 100644 index 0000000000000..a8169613b4fd2 --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx.impl + +import scala.reflect.{classTag, ClassTag} + +import org.apache.spark.{OneToOneDependency, Partition, Partitioner, TaskContext} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + +import org.apache.spark.graphx._ + +class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( + override val partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])], + val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) + extends EdgeRDD[ED](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) { + + override def setName(_name: String): this.type = { + if (partitionsRDD.name != null) { + partitionsRDD.setName(partitionsRDD.name + ", " + _name) + } else { + partitionsRDD.setName(_name) + } + this + } + setName("EdgeRDD") + + /** + * If `partitionsRDD` already has a partitioner, use it. Otherwise assume that the + * [[PartitionID]]s in `partitionsRDD` correspond to the actual partitions and create a new + * partitioner that allows co-partitioning with `partitionsRDD`. + */ + override val partitioner = + partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD))) + + override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect() + + /** + * Persists the edge partitions at the specified storage level, ignoring any existing target + * storage level. + */ + override def persist(newLevel: StorageLevel): this.type = { + partitionsRDD.persist(newLevel) + this + } + + override def unpersist(blocking: Boolean = true): this.type = { + partitionsRDD.unpersist(blocking) + this + } + + /** Persists the edge partitions using `targetStorageLevel`, which defaults to MEMORY_ONLY. */ + override def cache(): this.type = { + partitionsRDD.persist(targetStorageLevel) + this + } + + /** The number of edges in the RDD. 
*/ + override def count(): Long = { + partitionsRDD.map(_._2.size.toLong).reduce(_ + _) + } + + override def mapValues[ED2: ClassTag](f: Edge[ED] => ED2): EdgeRDDImpl[ED2, VD] = + mapEdgePartitions((pid, part) => part.map(f)) + + override def reverse: EdgeRDDImpl[ED, VD] = mapEdgePartitions((pid, part) => part.reverse) + + def filter( + epred: EdgeTriplet[VD, ED] => Boolean, + vpred: (VertexId, VD) => Boolean): EdgeRDDImpl[ED, VD] = { + mapEdgePartitions((pid, part) => part.filter(epred, vpred)) + } + + override def innerJoin[ED2: ClassTag, ED3: ClassTag] + (other: EdgeRDD[ED2]) + (f: (VertexId, VertexId, ED, ED2) => ED3): EdgeRDDImpl[ED3, VD] = { + val ed2Tag = classTag[ED2] + val ed3Tag = classTag[ED3] + this.withPartitionsRDD[ED3, VD](partitionsRDD.zipPartitions(other.partitionsRDD, true) { + (thisIter, otherIter) => + val (pid, thisEPart) = thisIter.next() + val (_, otherEPart) = otherIter.next() + Iterator(Tuple2(pid, thisEPart.innerJoin(otherEPart)(f)(ed2Tag, ed3Tag))) + }) + } + + def mapEdgePartitions[ED2: ClassTag, VD2: ClassTag]( + f: (PartitionID, EdgePartition[ED, VD]) => EdgePartition[ED2, VD2]): EdgeRDDImpl[ED2, VD2] = { + this.withPartitionsRDD[ED2, VD2](partitionsRDD.mapPartitions({ iter => + if (iter.hasNext) { + val (pid, ep) = iter.next() + Iterator(Tuple2(pid, f(pid, ep))) + } else { + Iterator.empty + } + }, preservesPartitioning = true)) + } + + private[graphx] def withPartitionsRDD[ED2: ClassTag, VD2: ClassTag]( + partitionsRDD: RDD[(PartitionID, EdgePartition[ED2, VD2])]): EdgeRDDImpl[ED2, VD2] = { + new EdgeRDDImpl(partitionsRDD, this.targetStorageLevel) + } + + override private[graphx] def withTargetStorageLevel( + targetStorageLevel: StorageLevel): EdgeRDDImpl[ED, VD] = { + new EdgeRDDImpl(this.partitionsRDD, targetStorageLevel) + } + +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala deleted file mode 100644 index 56f79a7097fce..0000000000000 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.graphx.impl - -import scala.reflect.ClassTag - -import org.apache.spark.graphx._ -import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap - -/** - * The Iterator type returned when constructing edge triplets. This could be an anonymous class in - * EdgePartition.tripletIterator, but we name it here explicitly so it is easier to debug / profile. 
- */ -private[impl] -class EdgeTripletIterator[VD: ClassTag, ED: ClassTag]( - val edgePartition: EdgePartition[ED, VD], - val includeSrc: Boolean, - val includeDst: Boolean) - extends Iterator[EdgeTriplet[VD, ED]] { - - // Current position in the array. - private var pos = 0 - - override def hasNext: Boolean = pos < edgePartition.size - - override def next() = { - val triplet = new EdgeTriplet[VD, ED] - triplet.srcId = edgePartition.srcIds(pos) - if (includeSrc) { - triplet.srcAttr = edgePartition.vertices(triplet.srcId) - } - triplet.dstId = edgePartition.dstIds(pos) - if (includeDst) { - triplet.dstAttr = edgePartition.vertices(triplet.dstId) - } - triplet.attr = edgePartition.data(pos) - pos += 1 - triplet - } -} - -/** - * An Iterator type for internal use that reuses EdgeTriplet objects. This could be an anonymous - * class in EdgePartition.upgradeIterator, but we name it here explicitly so it is easier to debug / - * profile. - */ -private[impl] -class ReusingEdgeTripletIterator[VD: ClassTag, ED: ClassTag]( - val edgeIter: Iterator[Edge[ED]], - val edgePartition: EdgePartition[ED, VD], - val includeSrc: Boolean, - val includeDst: Boolean) - extends Iterator[EdgeTriplet[VD, ED]] { - - private val triplet = new EdgeTriplet[VD, ED] - - override def hasNext = edgeIter.hasNext - - override def next() = { - triplet.set(edgeIter.next()) - if (includeSrc) { - triplet.srcAttr = edgePartition.vertices(triplet.srcId) - } - if (includeDst) { - triplet.dstAttr = edgePartition.vertices(triplet.dstId) - } - triplet - } -} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 33f35cfb69a26..0eae2a673874a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -23,7 +23,6 @@ import org.apache.spark.HashPartitioner import org.apache.spark.SparkContext._ import org.apache.spark.rdd.{RDD, ShuffledRDD} import org.apache.spark.storage.StorageLevel - import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl._ import org.apache.spark.graphx.util.BytecodeUtils @@ -44,7 +43,7 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( /** Default constructor is provided to support serialization */ protected def this() = this(null, null) - @transient override val edges: EdgeRDD[ED, VD] = replicatedVertexView.edges + @transient override val edges: EdgeRDDImpl[ED, VD] = replicatedVertexView.edges /** Return a RDD that brings edges together with their source and destination vertices. 
*/ @transient override lazy val triplets: RDD[EdgeTriplet[VD, ED]] = { @@ -127,13 +126,12 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( } override def mapTriplets[ED2: ClassTag]( - f: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2]): Graph[VD, ED2] = { + f: (PartitionID, Iterator[EdgeTriplet[VD, ED]]) => Iterator[ED2], + tripletFields: TripletFields): Graph[VD, ED2] = { vertices.cache() - val mapUsesSrcAttr = accessesVertexAttr(f, "srcAttr") - val mapUsesDstAttr = accessesVertexAttr(f, "dstAttr") - replicatedVertexView.upgrade(vertices, mapUsesSrcAttr, mapUsesDstAttr) + replicatedVertexView.upgrade(vertices, tripletFields.useSrc, tripletFields.useDst) val newEdges = replicatedVertexView.edges.mapEdgePartitions { (pid, part) => - part.map(f(pid, part.tripletIterator(mapUsesSrcAttr, mapUsesDstAttr))) + part.map(f(pid, part.tripletIterator(tripletFields.useSrc, tripletFields.useDst))) } new GraphImpl(vertices, replicatedVertexView.withEdges(newEdges)) } @@ -171,15 +169,38 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( override def mapReduceTriplets[A: ClassTag]( mapFunc: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)], reduceFunc: (A, A) => A, - activeSetOpt: Option[(VertexRDD[_], EdgeDirection)] = None): VertexRDD[A] = { + activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]): VertexRDD[A] = { + + def sendMsg(ctx: EdgeContext[VD, ED, A]) { + mapFunc(ctx.toEdgeTriplet).foreach { kv => + val id = kv._1 + val msg = kv._2 + if (id == ctx.srcId) { + ctx.sendToSrc(msg) + } else { + assert(id == ctx.dstId) + ctx.sendToDst(msg) + } + } + } - vertices.cache() + val mapUsesSrcAttr = accessesVertexAttr(mapFunc, "srcAttr") + val mapUsesDstAttr = accessesVertexAttr(mapFunc, "dstAttr") + val tripletFields = new TripletFields(mapUsesSrcAttr, mapUsesDstAttr, true) + + aggregateMessagesWithActiveSet(sendMsg, reduceFunc, tripletFields, activeSetOpt) + } + + override def aggregateMessagesWithActiveSet[A: ClassTag]( + sendMsg: EdgeContext[VD, ED, A] => Unit, + mergeMsg: (A, A) => A, + tripletFields: TripletFields, + activeSetOpt: Option[(VertexRDD[_], EdgeDirection)]): VertexRDD[A] = { + vertices.cache() // For each vertex, replicate its attribute only to partitions where it is // in the relevant position in an edge. 
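For reference, the user-facing effect of routing mapReduceTriplets through aggregateMessagesWithActiveSet is the switch from an Iterator-returning map function to an EdgeContext callback plus an explicit TripletFields hint. A minimal sketch, assuming a graph: Graph[Double, Int] is in scope (neither the graph nor the (count, sum) statistic is part of this patch):

import org.apache.spark.graphx._

// old API: allocate an Iterator of (vertexId, message) pairs per edge
val oldStyle: VertexRDD[(Int, Double)] =
  graph.mapReduceTriplets[(Int, Double)](
    et => Iterator((et.dstId, (1, et.srcAttr))),
    (a, b) => (a._1 + b._1, a._2 + b._2))

// new API: push messages through the EdgeContext and declare up front that only
// the source attribute is read, so only srcAttr is shipped to edge partitions
val newStyle: VertexRDD[(Int, Double)] =
  graph.aggregateMessages[(Int, Double)](
    ctx => ctx.sendToDst((1, ctx.srcAttr)),
    (a, b) => (a._1 + b._1, a._2 + b._2),
    TripletFields.Src)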
- val mapUsesSrcAttr = accessesVertexAttr(mapFunc, "srcAttr") - val mapUsesDstAttr = accessesVertexAttr(mapFunc, "dstAttr") - replicatedVertexView.upgrade(vertices, mapUsesSrcAttr, mapUsesDstAttr) + replicatedVertexView.upgrade(vertices, tripletFields.useSrc, tripletFields.useDst) val view = activeSetOpt match { case Some((activeSet, _)) => replicatedVertexView.withActiveSet(activeSet) @@ -193,42 +214,40 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( case (pid, edgePartition) => // Choose scan method val activeFraction = edgePartition.numActives.getOrElse(0) / edgePartition.indexSize.toFloat - val edgeIter = activeDirectionOpt match { + activeDirectionOpt match { case Some(EdgeDirection.Both) => if (activeFraction < 0.8) { - edgePartition.indexIterator(srcVertexId => edgePartition.isActive(srcVertexId)) - .filter(e => edgePartition.isActive(e.dstId)) + edgePartition.aggregateMessagesIndexScan(sendMsg, mergeMsg, tripletFields, + EdgeActiveness.Both) } else { - edgePartition.iterator.filter(e => - edgePartition.isActive(e.srcId) && edgePartition.isActive(e.dstId)) + edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields, + EdgeActiveness.Both) } case Some(EdgeDirection.Either) => // TODO: Because we only have a clustered index on the source vertex ID, we can't filter // the index here. Instead we have to scan all edges and then do the filter. - edgePartition.iterator.filter(e => - edgePartition.isActive(e.srcId) || edgePartition.isActive(e.dstId)) + edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields, + EdgeActiveness.Either) case Some(EdgeDirection.Out) => if (activeFraction < 0.8) { - edgePartition.indexIterator(srcVertexId => edgePartition.isActive(srcVertexId)) + edgePartition.aggregateMessagesIndexScan(sendMsg, mergeMsg, tripletFields, + EdgeActiveness.SrcOnly) } else { - edgePartition.iterator.filter(e => edgePartition.isActive(e.srcId)) + edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields, + EdgeActiveness.SrcOnly) } case Some(EdgeDirection.In) => - edgePartition.iterator.filter(e => edgePartition.isActive(e.dstId)) + edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields, + EdgeActiveness.DstOnly) case _ => // None - edgePartition.iterator + edgePartition.aggregateMessagesEdgeScan(sendMsg, mergeMsg, tripletFields, + EdgeActiveness.Neither) } - - // Scan edges and run the map function - val mapOutputs = edgePartition.upgradeIterator(edgeIter, mapUsesSrcAttr, mapUsesDstAttr) - .flatMap(mapFunc(_)) - // Note: This doesn't allow users to send messages to arbitrary vertices. 
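The branches above keep the old Pregel optimisation: when an active set is supplied, a partition uses the clustered index scan (aggregateMessagesIndexScan) only while fewer than 80% of its source-vertex index is active, and scans every edge otherwise. From user code the main lever is TripletFields; a sketch assuming an existing graph (hypothetical, not in this patch) shows a degree count that ships no vertex attributes at all:

import org.apache.spark.graphx._

// TripletFields.None declares that neither srcAttr nor dstAttr is read, so the
// replicated vertex view is not upgraded and no vertex attributes are shuffled
val inDeg: VertexRDD[Int] =
  graph.aggregateMessages[Int](ctx => ctx.sendToDst(1), _ + _, TripletFields.None)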
- edgePartition.vertices.aggregateUsingIndex(mapOutputs, reduceFunc).iterator - }).setName("GraphImpl.mapReduceTriplets - preAgg") + }).setName("GraphImpl.aggregateMessages - preAgg") // do the final reduction reusing the index map - vertices.aggregateUsingIndex(preAgg, reduceFunc) - } // end of mapReduceTriplets + vertices.aggregateUsingIndex(preAgg, mergeMsg) + } override def outerJoinVertices[U: ClassTag, VD2: ClassTag] (other: RDD[(VertexId, U)]) @@ -304,11 +323,10 @@ object GraphImpl { */ def apply[VD: ClassTag, ED: ClassTag]( vertices: VertexRDD[VD], - edges: EdgeRDD[ED, _]): GraphImpl[VD, ED] = { + edges: EdgeRDD[ED]): GraphImpl[VD, ED] = { // Convert the vertex partitions in edges to the correct type - val newEdges = edges.mapEdgePartitions( - (pid, part) => part.withVertices(part.vertices.map( - (vid, attr) => null.asInstanceOf[VD]))) + val newEdges = edges.asInstanceOf[EdgeRDDImpl[ED, _]] + .mapEdgePartitions((pid, part) => part.withoutVertexAttributes[VD]) GraphImpl.fromExistingRDDs(vertices, newEdges) } @@ -319,8 +337,8 @@ object GraphImpl { */ def fromExistingRDDs[VD: ClassTag, ED: ClassTag]( vertices: VertexRDD[VD], - edges: EdgeRDD[ED, VD]): GraphImpl[VD, ED] = { - new GraphImpl(vertices, new ReplicatedVertexView(edges)) + edges: EdgeRDD[ED]): GraphImpl[VD, ED] = { + new GraphImpl(vertices, new ReplicatedVertexView(edges.asInstanceOf[EdgeRDDImpl[ED, VD]])) } /** @@ -328,7 +346,7 @@ object GraphImpl { * `defaultVertexAttr`. The vertices will have the same number of partitions as the EdgeRDD. */ private def fromEdgeRDD[VD: ClassTag, ED: ClassTag]( - edges: EdgeRDD[ED, VD], + edges: EdgeRDDImpl[ED, VD], defaultVertexAttr: VD, edgeStorageLevel: StorageLevel, vertexStorageLevel: StorageLevel): GraphImpl[VD, ED] = { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala deleted file mode 100644 index 714f3b81c9dad..0000000000000 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.graphx.impl - -import scala.language.implicitConversions -import scala.reflect.{classTag, ClassTag} - -import org.apache.spark.Partitioner -import org.apache.spark.graphx.{PartitionID, VertexId} -import org.apache.spark.rdd.{ShuffledRDD, RDD} - - -private[graphx] -class VertexRDDFunctions[VD: ClassTag](self: RDD[(VertexId, VD)]) { - def copartitionWithVertices(partitioner: Partitioner): RDD[(VertexId, VD)] = { - val rdd = new ShuffledRDD[VertexId, VD, VD](self, partitioner) - - // Set a custom serializer if the data is of int or double type. 
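GraphImpl.apply now accepts the vertex-type-erased EdgeRDD[ED] and rebuilds the edge partitions without vertex attributes of the old type, so the usual public construction path is unchanged. A hedged sketch of that path (`sc` is an existing SparkContext; the tiny triangle graph is made up for illustration):

import org.apache.spark.SparkContext
import org.apache.spark.graphx._
import org.apache.spark.rdd.RDD

def tinyGraph(sc: SparkContext): Graph[String, Int] = {
  val vertices: RDD[(VertexId, String)] =
    sc.parallelize(Seq((1L, "a"), (2L, "b"), (3L, "c")))
  val edges: RDD[Edge[Int]] =
    sc.parallelize(Seq(Edge(1L, 2L, 1), Edge(2L, 3L, 1), Edge(3L, 1L, 1)))
  // vertices referenced by edges but missing from the vertex RDD get the default attribute
  Graph(vertices, edges, defaultVertexAttr = "missing")
}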
- if (classTag[VD] == ClassTag.Int) { - rdd.setSerializer(new IntAggMsgSerializer) - } else if (classTag[VD] == ClassTag.Long) { - rdd.setSerializer(new LongAggMsgSerializer) - } else if (classTag[VD] == ClassTag.Double) { - rdd.setSerializer(new DoubleAggMsgSerializer) - } - rdd - } -} - -private[graphx] -object VertexRDDFunctions { - implicit def rdd2VertexRDDFunctions[VD: ClassTag](rdd: RDD[(VertexId, VD)]) = { - new VertexRDDFunctions(rdd) - } -} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala index 86b366eb9202b..8ab255bd4038c 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala @@ -33,7 +33,7 @@ import org.apache.spark.graphx._ */ private[impl] class ReplicatedVertexView[VD: ClassTag, ED: ClassTag]( - var edges: EdgeRDD[ED, VD], + var edges: EdgeRDDImpl[ED, VD], var hasSrcId: Boolean = false, var hasDstId: Boolean = false) { @@ -42,7 +42,7 @@ class ReplicatedVertexView[VD: ClassTag, ED: ClassTag]( * shipping level. */ def withEdges[VD2: ClassTag, ED2: ClassTag]( - edges_ : EdgeRDD[ED2, VD2]): ReplicatedVertexView[VD2, ED2] = { + edges_ : EdgeRDDImpl[ED2, VD2]): ReplicatedVertexView[VD2, ED2] = { new ReplicatedVertexView(edges_, hasSrcId, hasDstId) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index b27485953f719..eb3c997e0f3c0 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -29,24 +29,6 @@ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage -private[graphx] -class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) { - /** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. 
*/ - def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = { - new ShuffledRDD[VertexId, Int, Int]( - self, partitioner).setSerializer(new RoutingTableMessageSerializer) - } -} - -private[graphx] -object RoutingTableMessageRDDFunctions { - import scala.language.implicitConversions - - implicit def rdd2RoutingTableMessageRDDFunctions(rdd: RDD[RoutingTableMessage]) = { - new RoutingTableMessageRDDFunctions(rdd) - } -} - private[graphx] object RoutingTablePartition { /** @@ -74,11 +56,9 @@ object RoutingTablePartition { // Determine which positions each vertex id appears in using a map where the low 2 bits // represent src and dst val map = new GraphXPrimitiveKeyOpenHashMap[VertexId, Byte] - edgePartition.srcIds.iterator.foreach { srcId => - map.changeValue(srcId, 0x1, (b: Byte) => (b | 0x1).toByte) - } - edgePartition.dstIds.iterator.foreach { dstId => - map.changeValue(dstId, 0x2, (b: Byte) => (b | 0x2).toByte) + edgePartition.iterator.foreach { e => + map.changeValue(e.srcId, 0x1, (b: Byte) => (b | 0x1).toByte) + map.changeValue(e.dstId, 0x2, (b: Byte) => (b | 0x2).toByte) } map.iterator.map { vidAndPosition => val vid = vidAndPosition._1 diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala deleted file mode 100644 index 3909efcdfc993..0000000000000 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala +++ /dev/null @@ -1,369 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.graphx.impl - -import scala.language.existentials - -import java.io.{EOFException, InputStream, OutputStream} -import java.nio.ByteBuffer - -import scala.reflect.ClassTag - -import org.apache.spark.serializer._ - -import org.apache.spark.graphx._ -import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage - -private[graphx] -class RoutingTableMessageSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream): SerializationStream = - new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T): SerializationStream = { - val msg = t.asInstanceOf[RoutingTableMessage] - writeVarLong(msg._1, optimizePositive = false) - writeInt(msg._2) - this - } - } - - override def deserializeStream(s: InputStream): DeserializationStream = - new ShuffleDeserializationStream(s) { - override def readObject[T: ClassTag](): T = { - val a = readVarLong(optimizePositive = false) - val b = readInt() - (a, b).asInstanceOf[T] - } - } - } -} - -private[graphx] -class VertexIdMsgSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T) = { - val msg = t.asInstanceOf[(VertexId, _)] - writeVarLong(msg._1, optimizePositive = false) - this - } - } - - override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T: ClassTag](): T = { - (readVarLong(optimizePositive = false), null).asInstanceOf[T] - } - } - } -} - -/** A special shuffle serializer for AggregationMessage[Int]. */ -private[graphx] -class IntAggMsgSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T) = { - val msg = t.asInstanceOf[(VertexId, Int)] - writeVarLong(msg._1, optimizePositive = false) - writeUnsignedVarInt(msg._2) - this - } - } - - override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T: ClassTag](): T = { - val a = readVarLong(optimizePositive = false) - val b = readUnsignedVarInt() - (a, b).asInstanceOf[T] - } - } - } -} - -/** A special shuffle serializer for AggregationMessage[Long]. */ -private[graphx] -class LongAggMsgSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T) = { - val msg = t.asInstanceOf[(VertexId, Long)] - writeVarLong(msg._1, optimizePositive = false) - writeVarLong(msg._2, optimizePositive = true) - this - } - } - - override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - override def readObject[T: ClassTag](): T = { - val a = readVarLong(optimizePositive = false) - val b = readVarLong(optimizePositive = true) - (a, b).asInstanceOf[T] - } - } - } -} - -/** A special shuffle serializer for AggregationMessage[Double]. 
*/ -private[graphx] -class DoubleAggMsgSerializer extends Serializer with Serializable { - override def newInstance(): SerializerInstance = new ShuffleSerializerInstance { - - override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) { - def writeObject[T: ClassTag](t: T) = { - val msg = t.asInstanceOf[(VertexId, Double)] - writeVarLong(msg._1, optimizePositive = false) - writeDouble(msg._2) - this - } - } - - override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) { - def readObject[T: ClassTag](): T = { - val a = readVarLong(optimizePositive = false) - val b = readDouble() - (a, b).asInstanceOf[T] - } - } - } -} - -//////////////////////////////////////////////////////////////////////////////// -// Helper classes to shorten the implementation of those special serializers. -//////////////////////////////////////////////////////////////////////////////// - -private[graphx] -abstract class ShuffleSerializationStream(s: OutputStream) extends SerializationStream { - // The implementation should override this one. - def writeObject[T: ClassTag](t: T): SerializationStream - - def writeInt(v: Int) { - s.write(v >> 24) - s.write(v >> 16) - s.write(v >> 8) - s.write(v) - } - - def writeUnsignedVarInt(value: Int) { - if ((value >>> 7) == 0) { - s.write(value.toInt) - } else if ((value >>> 14) == 0) { - s.write((value & 0x7F) | 0x80) - s.write(value >>> 7) - } else if ((value >>> 21) == 0) { - s.write((value & 0x7F) | 0x80) - s.write(value >>> 7 | 0x80) - s.write(value >>> 14) - } else if ((value >>> 28) == 0) { - s.write((value & 0x7F) | 0x80) - s.write(value >>> 7 | 0x80) - s.write(value >>> 14 | 0x80) - s.write(value >>> 21) - } else { - s.write((value & 0x7F) | 0x80) - s.write(value >>> 7 | 0x80) - s.write(value >>> 14 | 0x80) - s.write(value >>> 21 | 0x80) - s.write(value >>> 28) - } - } - - def writeVarLong(value: Long, optimizePositive: Boolean) { - val v = if (!optimizePositive) (value << 1) ^ (value >> 63) else value - if ((v >>> 7) == 0) { - s.write(v.toInt) - } else if ((v >>> 14) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7).toInt) - } else if ((v >>> 21) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14).toInt) - } else if ((v >>> 28) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 | 0x80).toInt) - s.write((v >>> 21).toInt) - } else if ((v >>> 35) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 | 0x80).toInt) - s.write((v >>> 21 | 0x80).toInt) - s.write((v >>> 28).toInt) - } else if ((v >>> 42) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 | 0x80).toInt) - s.write((v >>> 21 | 0x80).toInt) - s.write((v >>> 28 | 0x80).toInt) - s.write((v >>> 35).toInt) - } else if ((v >>> 49) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 | 0x80).toInt) - s.write((v >>> 21 | 0x80).toInt) - s.write((v >>> 28 | 0x80).toInt) - s.write((v >>> 35 | 0x80).toInt) - s.write((v >>> 42).toInt) - } else if ((v >>> 56) == 0) { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 | 0x80).toInt) - s.write((v >>> 21 | 0x80).toInt) - s.write((v >>> 28 | 0x80).toInt) - s.write((v >>> 35 | 0x80).toInt) - s.write((v >>> 42 | 0x80).toInt) - s.write((v >>> 49).toInt) - } else { - s.write(((v & 0x7F) | 0x80).toInt) - s.write((v >>> 7 | 0x80).toInt) - s.write((v >>> 14 
| 0x80).toInt) - s.write((v >>> 21 | 0x80).toInt) - s.write((v >>> 28 | 0x80).toInt) - s.write((v >>> 35 | 0x80).toInt) - s.write((v >>> 42 | 0x80).toInt) - s.write((v >>> 49 | 0x80).toInt) - s.write((v >>> 56).toInt) - } - } - - def writeLong(v: Long) { - s.write((v >>> 56).toInt) - s.write((v >>> 48).toInt) - s.write((v >>> 40).toInt) - s.write((v >>> 32).toInt) - s.write((v >>> 24).toInt) - s.write((v >>> 16).toInt) - s.write((v >>> 8).toInt) - s.write(v.toInt) - } - - def writeDouble(v: Double): Unit = writeLong(java.lang.Double.doubleToLongBits(v)) - - override def flush(): Unit = s.flush() - - override def close(): Unit = s.close() -} - -private[graphx] -abstract class ShuffleDeserializationStream(s: InputStream) extends DeserializationStream { - // The implementation should override this one. - def readObject[T: ClassTag](): T - - def readInt(): Int = { - val first = s.read() - if (first < 0) throw new EOFException - (first & 0xFF) << 24 | (s.read() & 0xFF) << 16 | (s.read() & 0xFF) << 8 | (s.read() & 0xFF) - } - - def readUnsignedVarInt(): Int = { - var value: Int = 0 - var i: Int = 0 - def readOrThrow(): Int = { - val in = s.read() - if (in < 0) throw new EOFException - in & 0xFF - } - var b: Int = readOrThrow() - while ((b & 0x80) != 0) { - value |= (b & 0x7F) << i - i += 7 - if (i > 35) throw new IllegalArgumentException("Variable length quantity is too long") - b = readOrThrow() - } - value | (b << i) - } - - def readVarLong(optimizePositive: Boolean): Long = { - def readOrThrow(): Int = { - val in = s.read() - if (in < 0) throw new EOFException - in & 0xFF - } - var b = readOrThrow() - var ret: Long = b & 0x7F - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F) << 7 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F) << 14 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F) << 21 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F).toLong << 28 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F).toLong << 35 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F).toLong << 42 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= (b & 0x7F).toLong << 49 - if ((b & 0x80) != 0) { - b = readOrThrow() - ret |= b.toLong << 56 - } - } - } - } - } - } - } - } - if (!optimizePositive) (ret >>> 1) ^ -(ret & 1) else ret - } - - def readLong(): Long = { - val first = s.read() - if (first < 0) throw new EOFException() - (first.toLong << 56) | - (s.read() & 0xFF).toLong << 48 | - (s.read() & 0xFF).toLong << 40 | - (s.read() & 0xFF).toLong << 32 | - (s.read() & 0xFF).toLong << 24 | - (s.read() & 0xFF) << 16 | - (s.read() & 0xFF) << 8 | - (s.read() & 0xFF) - } - - def readDouble(): Double = java.lang.Double.longBitsToDouble(readLong()) - - override def close(): Unit = s.close() -} - -private[graphx] sealed trait ShuffleSerializerInstance extends SerializerInstance { - - override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException - - override def deserialize[T: ClassTag](bytes: ByteBuffer): T = - throw new UnsupportedOperationException - - override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = - throw new UnsupportedOperationException - - // The implementation should override the following two. 
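The deleted serializers encoded vertex ids and messages with a 7-bits-per-byte variable-length scheme, zig-zag-transforming signed longs first so that small negative values also encode into few bytes (the same expressions appear in writeVarLong and readVarLong above). A self-contained sketch of just that transform, written for illustration rather than taken from the patch:

// zig-zag: move the sign bit into bit 0 so small magnitudes stay in the low bits
def zigZagEncode(value: Long): Long = (value << 1) ^ (value >> 63)
def zigZagDecode(encoded: Long): Long = (encoded >>> 1) ^ -(encoded & 1)

// round-trips for 0, -1, 1 and the extremes
assert(Seq(0L, -1L, 1L, Long.MinValue, Long.MaxValue)
  .forall(v => zigZagDecode(zigZagEncode(v)) == v))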
- override def serializeStream(s: OutputStream): SerializationStream - override def deserializeStream(s: InputStream): DeserializationStream -} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala new file mode 100644 index 0000000000000..d92a55a189298 --- /dev/null +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx.impl + +import scala.reflect.ClassTag + +import org.apache.spark._ +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd._ +import org.apache.spark.storage.StorageLevel + +import org.apache.spark.graphx._ + +class VertexRDDImpl[VD] private[graphx] ( + val partitionsRDD: RDD[ShippableVertexPartition[VD]], + val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) + (implicit override protected val vdTag: ClassTag[VD]) + extends VertexRDD[VD](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) { + + require(partitionsRDD.partitioner.isDefined) + + override def reindex(): VertexRDD[VD] = this.withPartitionsRDD(partitionsRDD.map(_.reindex())) + + override val partitioner = partitionsRDD.partitioner + + override protected def getPreferredLocations(s: Partition): Seq[String] = + partitionsRDD.preferredLocations(s) + + override def setName(_name: String): this.type = { + if (partitionsRDD.name != null) { + partitionsRDD.setName(partitionsRDD.name + ", " + _name) + } else { + partitionsRDD.setName(_name) + } + this + } + setName("VertexRDD") + + /** + * Persists the vertex partitions at the specified storage level, ignoring any existing target + * storage level. + */ + override def persist(newLevel: StorageLevel): this.type = { + partitionsRDD.persist(newLevel) + this + } + + override def unpersist(blocking: Boolean = true): this.type = { + partitionsRDD.unpersist(blocking) + this + } + + /** Persists the vertex partitions at `targetStorageLevel`, which defaults to MEMORY_ONLY. */ + override def cache(): this.type = { + partitionsRDD.persist(targetStorageLevel) + this + } + + /** The number of vertices in the RDD. 
*/ + override def count(): Long = { + partitionsRDD.map(_.size).reduce(_ + _) + } + + override private[graphx] def mapVertexPartitions[VD2: ClassTag]( + f: ShippableVertexPartition[VD] => ShippableVertexPartition[VD2]) + : VertexRDD[VD2] = { + val newPartitionsRDD = partitionsRDD.mapPartitions(_.map(f), preservesPartitioning = true) + this.withPartitionsRDD(newPartitionsRDD) + } + + override def mapValues[VD2: ClassTag](f: VD => VD2): VertexRDD[VD2] = + this.mapVertexPartitions(_.map((vid, attr) => f(attr))) + + override def mapValues[VD2: ClassTag](f: (VertexId, VD) => VD2): VertexRDD[VD2] = + this.mapVertexPartitions(_.map(f)) + + override def diff(other: VertexRDD[VD]): VertexRDD[VD] = { + val newPartitionsRDD = partitionsRDD.zipPartitions( + other.partitionsRDD, preservesPartitioning = true + ) { (thisIter, otherIter) => + val thisPart = thisIter.next() + val otherPart = otherIter.next() + Iterator(thisPart.diff(otherPart)) + } + this.withPartitionsRDD(newPartitionsRDD) + } + + override def leftZipJoin[VD2: ClassTag, VD3: ClassTag] + (other: VertexRDD[VD2])(f: (VertexId, VD, Option[VD2]) => VD3): VertexRDD[VD3] = { + val newPartitionsRDD = partitionsRDD.zipPartitions( + other.partitionsRDD, preservesPartitioning = true + ) { (thisIter, otherIter) => + val thisPart = thisIter.next() + val otherPart = otherIter.next() + Iterator(thisPart.leftJoin(otherPart)(f)) + } + this.withPartitionsRDD(newPartitionsRDD) + } + + override def leftJoin[VD2: ClassTag, VD3: ClassTag] + (other: RDD[(VertexId, VD2)]) + (f: (VertexId, VD, Option[VD2]) => VD3) + : VertexRDD[VD3] = { + // Test if the other vertex is a VertexRDD to choose the optimal join strategy. + // If the other set is a VertexRDD then we use the much more efficient leftZipJoin + other match { + case other: VertexRDD[_] => + leftZipJoin(other)(f) + case _ => + this.withPartitionsRDD[VD3]( + partitionsRDD.zipPartitions( + other.partitionBy(this.partitioner.get), preservesPartitioning = true) { + (partIter, msgs) => partIter.map(_.leftJoin(msgs)(f)) + } + ) + } + } + + override def innerZipJoin[U: ClassTag, VD2: ClassTag](other: VertexRDD[U]) + (f: (VertexId, VD, U) => VD2): VertexRDD[VD2] = { + val newPartitionsRDD = partitionsRDD.zipPartitions( + other.partitionsRDD, preservesPartitioning = true + ) { (thisIter, otherIter) => + val thisPart = thisIter.next() + val otherPart = otherIter.next() + Iterator(thisPart.innerJoin(otherPart)(f)) + } + this.withPartitionsRDD(newPartitionsRDD) + } + + override def innerJoin[U: ClassTag, VD2: ClassTag](other: RDD[(VertexId, U)]) + (f: (VertexId, VD, U) => VD2): VertexRDD[VD2] = { + // Test if the other vertex is a VertexRDD to choose the optimal join strategy. 
+ // If the other set is a VertexRDD then we use the much more efficient innerZipJoin + other match { + case other: VertexRDD[_] => + innerZipJoin(other)(f) + case _ => + this.withPartitionsRDD( + partitionsRDD.zipPartitions( + other.partitionBy(this.partitioner.get), preservesPartitioning = true) { + (partIter, msgs) => partIter.map(_.innerJoin(msgs)(f)) + } + ) + } + } + + override def aggregateUsingIndex[VD2: ClassTag]( + messages: RDD[(VertexId, VD2)], reduceFunc: (VD2, VD2) => VD2): VertexRDD[VD2] = { + val shuffled = messages.partitionBy(this.partitioner.get) + val parts = partitionsRDD.zipPartitions(shuffled, true) { (thisIter, msgIter) => + thisIter.map(_.aggregateUsingIndex(msgIter, reduceFunc)) + } + this.withPartitionsRDD[VD2](parts) + } + + override def reverseRoutingTables(): VertexRDD[VD] = + this.mapVertexPartitions(vPart => vPart.withRoutingTable(vPart.routingTable.reverse)) + + override def withEdges(edges: EdgeRDD[_]): VertexRDD[VD] = { + val routingTables = VertexRDD.createRoutingTables(edges, this.partitioner.get) + val vertexPartitions = partitionsRDD.zipPartitions(routingTables, true) { + (partIter, routingTableIter) => + val routingTable = + if (routingTableIter.hasNext) routingTableIter.next() else RoutingTablePartition.empty + partIter.map(_.withRoutingTable(routingTable)) + } + this.withPartitionsRDD(vertexPartitions) + } + + override private[graphx] def withPartitionsRDD[VD2: ClassTag]( + partitionsRDD: RDD[ShippableVertexPartition[VD2]]): VertexRDD[VD2] = { + new VertexRDDImpl(partitionsRDD, this.targetStorageLevel) + } + + override private[graphx] def withTargetStorageLevel( + targetStorageLevel: StorageLevel): VertexRDD[VD] = { + new VertexRDDImpl(this.partitionsRDD, targetStorageLevel) + } + + override private[graphx] def shipVertexAttributes( + shipSrc: Boolean, shipDst: Boolean): RDD[(PartitionID, VertexAttributeBlock[VD])] = { + partitionsRDD.mapPartitions(_.flatMap(_.shipVertexAttributes(shipSrc, shipDst))) + } + + override private[graphx] def shipVertexIds(): RDD[(PartitionID, Array[VertexId])] = { + partitionsRDD.mapPartitions(_.flatMap(_.shipVertexIds())) + } + +} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 257e2f3a36115..e139959c3f5c1 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -85,7 +85,7 @@ object PageRank extends Logging { // Associate the degree with each vertex .outerJoinVertices(graph.outDegrees) { (vid, vdata, deg) => deg.getOrElse(0) } // Set the weight on the edges based on the degree - .mapTriplets( e => 1.0 / e.srcAttr ) + .mapTriplets( e => 1.0 / e.srcAttr, TripletFields.Src ) // Set the vertex attributes to the initial pagerank values .mapVertices( (id, attr) => resetProb ) @@ -96,8 +96,8 @@ object PageRank extends Logging { // Compute the outgoing rank contributions of each vertex, perform local preaggregation, and // do the final aggregation at the receiving vertices. Requires a shuffle for aggregation. - val rankUpdates = rankGraph.mapReduceTriplets[Double]( - e => Iterator((e.dstId, e.srcAttr * e.attr)), _ + _) + val rankUpdates = rankGraph.aggregateMessages[Double]( + ctx => ctx.sendToDst(ctx.srcAttr * ctx.attr), _ + _, TripletFields.Src) // Apply the final rank updates to get the new ranks, using join to preserve ranks of vertices // that didn't receive a message. 
Requires a shuffle for broadcasting updated ranks to the diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index ccd7de537b6e3..f58587e10a820 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -74,9 +74,9 @@ object SVDPlusPlus { var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache() // Calculate initial bias and norm - val t0 = g.mapReduceTriplets( - et => Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr))), - (g1: (Long, Double), g2: (Long, Double)) => (g1._1 + g2._1, g1._2 + g2._2)) + val t0 = g.aggregateMessages[(Long, Double)]( + ctx => { ctx.sendToSrc((1L, ctx.attr)); ctx.sendToDst((1L, ctx.attr)) }, + (g1, g2) => (g1._1 + g2._1, g1._2 + g2._2)) g = g.outerJoinVertices(t0) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), @@ -84,15 +84,17 @@ object SVDPlusPlus { (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1)) } - def mapTrainF(conf: Conf, u: Double) - (et: EdgeTriplet[(DoubleMatrix, DoubleMatrix, Double, Double), Double]) - : Iterator[(VertexId, (DoubleMatrix, DoubleMatrix, Double))] = { - val (usr, itm) = (et.srcAttr, et.dstAttr) + def sendMsgTrainF(conf: Conf, u: Double) + (ctx: EdgeContext[ + (DoubleMatrix, DoubleMatrix, Double, Double), + Double, + (DoubleMatrix, DoubleMatrix, Double)]) { + val (usr, itm) = (ctx.srcAttr, ctx.dstAttr) val (p, q) = (usr._1, itm._1) var pred = u + usr._3 + itm._3 + q.dot(usr._2) pred = math.max(pred, conf.minVal) pred = math.min(pred, conf.maxVal) - val err = et.attr - pred + val err = ctx.attr - pred val updateP = q.mul(err) .subColumnVector(p.mul(conf.gamma7)) .mul(conf.gamma2) @@ -102,16 +104,16 @@ object SVDPlusPlus { val updateY = q.mul(err * usr._4) .subColumnVector(itm._2.mul(conf.gamma7)) .mul(conf.gamma2) - Iterator((et.srcId, (updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1)), - (et.dstId, (updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1))) + ctx.sendToSrc((updateP, updateY, (err - conf.gamma6 * usr._3) * conf.gamma1)) + ctx.sendToDst((updateQ, updateY, (err - conf.gamma6 * itm._3) * conf.gamma1)) } for (i <- 0 until conf.maxIters) { // Phase 1, calculate pu + |N(u)|^(-0.5)*sum(y) for user nodes g.cache() - val t1 = g.mapReduceTriplets( - et => Iterator((et.srcId, et.dstAttr._2)), - (g1: DoubleMatrix, g2: DoubleMatrix) => g1.addColumnVector(g2)) + val t1 = g.aggregateMessages[DoubleMatrix]( + ctx => ctx.sendToSrc(ctx.dstAttr._2), + (g1, g2) => g1.addColumnVector(g2)) g = g.outerJoinVertices(t1) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[DoubleMatrix]) => @@ -121,8 +123,8 @@ object SVDPlusPlus { // Phase 2, update p for user nodes and q, y for item nodes g.cache() - val t2 = g.mapReduceTriplets( - mapTrainF(conf, u), + val t2 = g.aggregateMessages( + sendMsgTrainF(conf, u), (g1: (DoubleMatrix, DoubleMatrix, Double), g2: (DoubleMatrix, DoubleMatrix, Double)) => (g1._1.addColumnVector(g2._1), g1._2.addColumnVector(g2._2), g1._3 + g2._3)) g = g.outerJoinVertices(t2) { @@ -135,20 +137,18 @@ object SVDPlusPlus { } // calculate error on training set - def mapTestF(conf: Conf, u: Double) - (et: EdgeTriplet[(DoubleMatrix, DoubleMatrix, Double, Double), Double]) - : Iterator[(VertexId, Double)] = - { - val (usr, itm) = (et.srcAttr, et.dstAttr) + def sendMsgTestF(conf: Conf, u: Double) + (ctx: EdgeContext[(DoubleMatrix, 
DoubleMatrix, Double, Double), Double, Double]) { + val (usr, itm) = (ctx.srcAttr, ctx.dstAttr) val (p, q) = (usr._1, itm._1) var pred = u + usr._3 + itm._3 + q.dot(usr._2) pred = math.max(pred, conf.minVal) pred = math.min(pred, conf.maxVal) - val err = (et.attr - pred) * (et.attr - pred) - Iterator((et.dstId, err)) + val err = (ctx.attr - pred) * (ctx.attr - pred) + ctx.sendToDst(err) } g.cache() - val t3 = g.mapReduceTriplets(mapTestF(conf, u), (g1: Double, g2: Double) => g1 + g2) + val t3 = g.aggregateMessages[Double](sendMsgTestF(conf, u), _ + _) g = g.outerJoinVertices(t3) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[Double]) => if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala index 7c396e6e66a28..daf162085e3e4 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/TriangleCount.scala @@ -61,26 +61,27 @@ object TriangleCount { (vid, _, optSet) => optSet.getOrElse(null) } // Edge function computes intersection of smaller vertex with larger vertex - def edgeFunc(et: EdgeTriplet[VertexSet, ED]): Iterator[(VertexId, Int)] = { - assert(et.srcAttr != null) - assert(et.dstAttr != null) - val (smallSet, largeSet) = if (et.srcAttr.size < et.dstAttr.size) { - (et.srcAttr, et.dstAttr) + def edgeFunc(ctx: EdgeContext[VertexSet, ED, Int]) { + assert(ctx.srcAttr != null) + assert(ctx.dstAttr != null) + val (smallSet, largeSet) = if (ctx.srcAttr.size < ctx.dstAttr.size) { + (ctx.srcAttr, ctx.dstAttr) } else { - (et.dstAttr, et.srcAttr) + (ctx.dstAttr, ctx.srcAttr) } val iter = smallSet.iterator var counter: Int = 0 while (iter.hasNext) { val vid = iter.next() - if (vid != et.srcId && vid != et.dstId && largeSet.contains(vid)) { + if (vid != ctx.srcId && vid != ctx.dstId && largeSet.contains(vid)) { counter += 1 } } - Iterator((et.srcId, counter), (et.dstId, counter)) + ctx.sendToSrc(counter) + ctx.sendToDst(counter) } // compute the intersection along edges - val counters: VertexRDD[Int] = setGraph.mapReduceTriplets(edgeFunc, _ + _) + val counters: VertexRDD[Int] = setGraph.aggregateMessages(edgeFunc, _ + _) // Merge counters with the graph and divide by two since each triangle is counted twice g.outerJoinVertices(counters) { (vid, _, optCounter: Option[Int]) => diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 6506bac73d71c..a05d1ddb21295 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -118,7 +118,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { // Each vertex should be replicated to at most 2 * sqrt(p) partitions val partitionSets = partitionedGraph.edges.partitionsRDD.mapPartitions { iter => val part = iter.next()._2 - Iterator((part.srcIds ++ part.dstIds).toSet) + Iterator((part.iterator.flatMap(e => Iterator(e.srcId, e.dstId))).toSet) }.collect if (!verts.forall(id => partitionSets.count(_.contains(id)) <= bound)) { val numFailures = verts.count(id => partitionSets.count(_.contains(id)) > bound) @@ -130,7 +130,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { // This should not be true for the default hash partitioning val partitionSetsUnpartitioned = graph.edges.partitionsRDD.mapPartitions { 
iter => val part = iter.next()._2 - Iterator((part.srcIds ++ part.dstIds).toSet) + Iterator((part.iterator.flatMap(e => Iterator(e.srcId, e.dstId))).toSet) }.collect assert(verts.exists(id => partitionSetsUnpartitioned.count(_.contains(id)) > bound)) @@ -318,6 +318,21 @@ class GraphSuite extends FunSuite with LocalSparkContext { } } + test("aggregateMessages") { + withSpark { sc => + val n = 5 + val agg = starGraph(sc, n).aggregateMessages[String]( + ctx => { + if (ctx.dstAttr != null) { + throw new Exception( + "expected ctx.dstAttr to be null due to TripletFields, but it was " + ctx.dstAttr) + } + ctx.sendToDst(ctx.srcAttr) + }, _ + _, TripletFields.Src) + assert(agg.collect().toSet === (1 to n).map(x => (x: VertexId, "v")).toSet) + } + } + test("outerJoinVertices") { withSpark { sc => val n = 5 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala index 47594a800a3b1..a3e28efc75a98 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala @@ -17,9 +17,6 @@ package org.apache.spark.graphx -import org.scalatest.Suite -import org.scalatest.BeforeAndAfterEach - import org.apache.spark.SparkConf import org.apache.spark.SparkContext @@ -31,8 +28,7 @@ trait LocalSparkContext { /** Runs `f` on a new SparkContext and ensures that it is stopped afterwards. */ def withSpark[T](f: SparkContext => T) = { val conf = new SparkConf() - .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator") + GraphXUtils.registerKryoClasses(conf) val sc = new SparkContext("local", "test", conf) try { f(sc) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala deleted file mode 100644 index 864cb1fdf0022..0000000000000 --- a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.graphx - -import java.io.{EOFException, ByteArrayInputStream, ByteArrayOutputStream} - -import scala.util.Random -import scala.reflect.ClassTag - -import org.scalatest.FunSuite - -import org.apache.spark._ -import org.apache.spark.graphx.impl._ -import org.apache.spark.serializer.SerializationStream - - -class SerializerSuite extends FunSuite with LocalSparkContext { - - test("IntAggMsgSerializer") { - val outMsg = (4: VertexId, 5) - val bout = new ByteArrayOutputStream - val outStrm = new IntAggMsgSerializer().newInstance().serializeStream(bout) - outStrm.writeObject(outMsg) - outStrm.writeObject(outMsg) - bout.flush() - val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new IntAggMsgSerializer().newInstance().deserializeStream(bin) - val inMsg1: (VertexId, Int) = inStrm.readObject() - val inMsg2: (VertexId, Int) = inStrm.readObject() - assert(outMsg === inMsg1) - assert(outMsg === inMsg2) - - intercept[EOFException] { - inStrm.readObject() - } - } - - test("LongAggMsgSerializer") { - val outMsg = (4: VertexId, 1L << 32) - val bout = new ByteArrayOutputStream - val outStrm = new LongAggMsgSerializer().newInstance().serializeStream(bout) - outStrm.writeObject(outMsg) - outStrm.writeObject(outMsg) - bout.flush() - val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new LongAggMsgSerializer().newInstance().deserializeStream(bin) - val inMsg1: (VertexId, Long) = inStrm.readObject() - val inMsg2: (VertexId, Long) = inStrm.readObject() - assert(outMsg === inMsg1) - assert(outMsg === inMsg2) - - intercept[EOFException] { - inStrm.readObject() - } - } - - test("DoubleAggMsgSerializer") { - val outMsg = (4: VertexId, 5.0) - val bout = new ByteArrayOutputStream - val outStrm = new DoubleAggMsgSerializer().newInstance().serializeStream(bout) - outStrm.writeObject(outMsg) - outStrm.writeObject(outMsg) - bout.flush() - val bin = new ByteArrayInputStream(bout.toByteArray) - val inStrm = new DoubleAggMsgSerializer().newInstance().deserializeStream(bin) - val inMsg1: (VertexId, Double) = inStrm.readObject() - val inMsg2: (VertexId, Double) = inStrm.readObject() - assert(outMsg === inMsg1) - assert(outMsg === inMsg2) - - intercept[EOFException] { - inStrm.readObject() - } - } - - test("variable long encoding") { - def testVarLongEncoding(v: Long, optimizePositive: Boolean) { - val bout = new ByteArrayOutputStream - val stream = new ShuffleSerializationStream(bout) { - def writeObject[T: ClassTag](t: T): SerializationStream = { - writeVarLong(t.asInstanceOf[Long], optimizePositive = optimizePositive) - this - } - } - stream.writeObject(v) - - val bin = new ByteArrayInputStream(bout.toByteArray) - val dstream = new ShuffleDeserializationStream(bin) { - def readObject[T: ClassTag](): T = { - readVarLong(optimizePositive).asInstanceOf[T] - } - } - val read = dstream.readObject[Long]() - assert(read === v) - } - - // Test all variable encoding code path (each branch uses 7 bits, i.e. 
1L << 7 difference) - val d = Random.nextLong() % 128 - Seq[Long](0, 1L << 0 + d, 1L << 7 + d, 1L << 14 + d, 1L << 21 + d, 1L << 28 + d, 1L << 35 + d, - 1L << 42 + d, 1L << 49 + d, 1L << 56 + d, 1L << 63 + d).foreach { number => - testVarLongEncoding(number, optimizePositive = false) - testVarLongEncoding(number, optimizePositive = true) - testVarLongEncoding(-number, optimizePositive = false) - testVarLongEncoding(-number, optimizePositive = true) - } - } -} diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala index 9d00f76327e4c..515f3a9cd02eb 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala @@ -82,29 +82,6 @@ class EdgePartitionSuite extends FunSuite { assert(edgePartition.groupEdges(_ + _).iterator.map(_.copy()).toList === groupedEdges) } - test("upgradeIterator") { - val edges = List((0, 1, 0), (1, 0, 0)) - val verts = List((0L, 1), (1L, 2)) - val part = makeEdgePartition(edges).updateVertices(verts.iterator) - assert(part.upgradeIterator(part.iterator).map(_.toTuple).toList === - part.tripletIterator().toList.map(_.toTuple)) - } - - test("indexIterator") { - val edgesFrom0 = List(Edge(0, 1, 0)) - val edgesFrom1 = List(Edge(1, 0, 0), Edge(1, 2, 0)) - val sortedEdges = edgesFrom0 ++ edgesFrom1 - val builder = new EdgePartitionBuilder[Int, Nothing] - for (e <- Random.shuffle(sortedEdges)) { - builder.add(e.srcId, e.dstId, e.attr) - } - - val edgePartition = builder.toEdgePartition - assert(edgePartition.iterator.map(_.copy()).toList === sortedEdges) - assert(edgePartition.indexIterator(_ == 0).map(_.copy()).toList === edgesFrom0) - assert(edgePartition.indexIterator(_ == 1).map(_.copy()).toList === edgesFrom1) - } - test("innerJoin") { val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0)) val bList = List((0, 1, 0), (1, 0, 0), (1, 1, 0), (3, 4, 0), (5, 5, 0)) @@ -125,21 +102,27 @@ class EdgePartitionSuite extends FunSuite { assert(ep.numActives == Some(2)) } + test("tripletIterator") { + val builder = new EdgePartitionBuilder[Int, Int] + builder.add(1, 2, 0) + builder.add(1, 3, 0) + builder.add(1, 4, 0) + val ep = builder.toEdgePartition + val result = ep.tripletIterator().toList.map(et => (et.srcId, et.dstId)) + assert(result === Seq((1, 2), (1, 3), (1, 4))) + } + test("serialization") { - val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0)) + val aList = List((0, 1, 1), (1, 0, 2), (1, 2, 3), (5, 4, 4), (5, 5, 5)) val a: EdgePartition[Int, Int] = makeEdgePartition(aList) val javaSer = new JavaSerializer(new SparkConf()) - val kryoSer = new KryoSerializer(new SparkConf() - .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")) + val conf = new SparkConf() + GraphXUtils.registerKryoClasses(conf) + val kryoSer = new KryoSerializer(conf) for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) { val aSer: EdgePartition[Int, Int] = s.deserialize(s.serialize(a)) - assert(aSer.srcIds.toList === a.srcIds.toList) - assert(aSer.dstIds.toList === a.dstIds.toList) - assert(aSer.data.toList === a.data.toList) - assert(aSer.index != null) - assert(aSer.vertices.iterator.toSet === a.vertices.iterator.toSet) + assert(aSer.tripletIterator().toList === a.tripletIterator().toList) } } } diff --git 
a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala deleted file mode 100644 index 49b2704390fea..0000000000000 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.graphx.impl - -import scala.reflect.ClassTag -import scala.util.Random - -import org.scalatest.FunSuite - -import org.apache.spark.graphx._ - -class EdgeTripletIteratorSuite extends FunSuite { - test("iterator.toList") { - val builder = new EdgePartitionBuilder[Int, Int] - builder.add(1, 2, 0) - builder.add(1, 3, 0) - builder.add(1, 4, 0) - val iter = new EdgeTripletIterator[Int, Int](builder.toEdgePartition, true, true) - val result = iter.toList.map(et => (et.srcId, et.dstId)) - assert(result === Seq((1, 2), (1, 3), (1, 4))) - } -} diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala index f9e771a900013..fe8304c1cdc32 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala @@ -125,9 +125,9 @@ class VertexPartitionSuite extends FunSuite { val verts = Set((0L, 1), (1L, 1), (2L, 1)) val vp = VertexPartition(verts.iterator) val javaSer = new JavaSerializer(new SparkConf()) - val kryoSer = new KryoSerializer(new SparkConf() - .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")) + val conf = new SparkConf() + GraphXUtils.registerKryoClasses(conf) + val kryoSer = new KryoSerializer(conf) for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) { val vpSer: VertexPartition[Int] = s.deserialize(s.serialize(vp)) diff --git a/make-distribution.sh b/make-distribution.sh index 0bc839e1dbe4d..7c0fb8992a155 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -59,7 +59,7 @@ while (( "$#" )); do exit_with_usage ;; --with-hive) - echo "Error: '--with-hive' is no longer supported, use Maven option -Phive" + echo "Error: '--with-hive' is no longer supported, use Maven options -Phive and -Phive-thriftserver" exit_with_usage ;; --skip-java-test) @@ -119,7 +119,7 @@ VERSION=$(mvn help:evaluate -Dexpression=project.version 2>/dev/null | grep -v " SPARK_HADOOP_VERSION=$(mvn help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ | grep -v "INFO"\ | tail -n 1) -SPARK_HIVE=$(mvn help:evaluate -Dexpression=project.activeProfiles $@ 2>/dev/null\ +SPARK_HIVE=$(mvn help:evaluate 
-Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\ | grep -v "INFO"\ | fgrep --count "hive";\ # Reset exit status to 0, otherwise the script stops here if the last grep finds nothing\ @@ -181,6 +181,9 @@ echo "Spark $VERSION$GITREVSTRING built for Hadoop $SPARK_HADOOP_VERSION" > "$DI # Copy jars cp "$FWDIR"/assembly/target/scala*/*assembly*hadoop*.jar "$DISTDIR/lib/" cp "$FWDIR"/examples/target/scala*/spark-examples*.jar "$DISTDIR/lib/" +# This will fail if the -Pyarn profile is not provided +# In this case, silence the error and ignore the return code of this command +cp "$FWDIR"/network/yarn/target/scala*/spark-*-yarn-shuffle.jar "$DISTDIR/lib/" &> /dev/null || : # Copy example sources (needed for python and SQL) mkdir -p "$DISTDIR/examples/src/main" diff --git a/mllib/pom.xml b/mllib/pom.xml index cfeabe4025de6..878aff66b3728 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml @@ -45,6 +45,11 @@ spark-streaming_${scala.binary.version} ${project.version} + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + org.eclipse.jetty jetty-server @@ -57,7 +62,7 @@ org.scalanlp breeze_${scala.binary.version} - 0.9 + 0.10 @@ -71,6 +76,10 @@ + + org.apache.commons + commons-math3 + org.scalatest scalatest_${scala.binary.version} @@ -91,6 +100,11 @@ junit-interface test + + org.mockito + mockito-all + test + org.apache.spark spark-streaming_${scala.binary.version} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala new file mode 100644 index 0000000000000..fdbee743e8177 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import scala.annotation.varargs +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param.{ParamMap, ParamPair, Params} +import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.api.java.JavaSchemaRDD + +/** + * :: AlphaComponent :: + * Abstract class for estimators that fit models to data. + */ +@AlphaComponent +abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { + + /** + * Fits a single model to the input data with optional parameters. + * + * @param dataset input dataset + * @param paramPairs optional list of param pairs (overwrite embedded params) + * @return fitted model + */ + @varargs + def fit(dataset: SchemaRDD, paramPairs: ParamPair[_]*): M = { + val map = new ParamMap().put(paramPairs: _*) + fit(dataset, map) + } + + /** + * Fits a single model to the input data with provided parameter map. 
+ * + * @param dataset input dataset + * @param paramMap parameter map + * @return fitted model + */ + def fit(dataset: SchemaRDD, paramMap: ParamMap): M + + /** + * Fits multiple models to the input data with multiple sets of parameters. + * The default implementation uses a for loop on each parameter map. + * Subclasses could overwrite this to optimize multi-model training. + * + * @param dataset input dataset + * @param paramMaps an array of parameter maps + * @return fitted models, matching the input parameter maps + */ + def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = { + paramMaps.map(fit(dataset, _)) + } + + // Java-friendly versions of fit. + + /** + * Fits a single model to the input data with optional parameters. + * + * @param dataset input dataset + * @param paramPairs optional list of param pairs (overwrite embedded params) + * @return fitted model + */ + @varargs + def fit(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): M = { + fit(dataset.schemaRDD, paramPairs: _*) + } + + /** + * Fits a single model to the input data with provided parameter map. + * + * @param dataset input dataset + * @param paramMap parameter map + * @return fitted model + */ + def fit(dataset: JavaSchemaRDD, paramMap: ParamMap): M = { + fit(dataset.schemaRDD, paramMap) + } + + /** + * Fits multiple models to the input data with multiple sets of parameters. + * + * @param dataset input dataset + * @param paramMaps an array of parameter maps + * @return fitted models, matching the input parameter maps + */ + def fit(dataset: JavaSchemaRDD, paramMaps: Array[ParamMap]): java.util.List[M] = { + fit(dataset.schemaRDD, paramMaps).asJava + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala new file mode 100644 index 0000000000000..db563dd550e56 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.sql.SchemaRDD + +/** + * :: AlphaComponent :: + * Abstract class for evaluators that compute metrics from predictions. + */ +@AlphaComponent +abstract class Evaluator extends Identifiable { + + /** + * Evaluates the output. + * + * @param dataset a dataset that contains labels/observations and predictions. 
+ * @param paramMap parameter map that specifies the input columns and output metrics + * @return metric + */ + def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala new file mode 100644 index 0000000000000..cd84b05bfb496 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Identifiable.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import java.util.UUID + +/** + * Object with a unique id. + */ +private[ml] trait Identifiable extends Serializable { + + /** + * A unique id for the object. The default implementation concatenates the class name, "-", and 8 + * random hex chars. + */ + private[ml] val uid: String = + this.getClass.getSimpleName + "-" + UUID.randomUUID().toString.take(8) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Model.scala b/mllib/src/main/scala/org/apache/spark/ml/Model.scala new file mode 100644 index 0000000000000..cae5082b51196 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Model.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param.ParamMap + +/** + * :: AlphaComponent :: + * A fitted model, i.e., a [[Transformer]] produced by an [[Estimator]]. + * + * @tparam M model type + */ +@AlphaComponent +abstract class Model[M <: Model[M]] extends Transformer { + /** + * The parent estimator that produced this model. + */ + val parent: Estimator[M] + + /** + * Fitting parameters, such that parent.fit(..., fittingParamMap) could reproduce the model. 
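A short sketch of what these two members record after a fit, reusing the LogisticRegression estimator and the assumed `training` SchemaRDD from the sketch above:

    val lrWithParams = new LogisticRegression().setMaxIter(50).setRegParam(0.01)
    val model = lrWithParams.fit(training)

    val producedBy = model.parent eq lrWithParams   // true: the estimator instance that produced the model
    println(model.fittingParamMap)                  // e.g. {LogisticRegression-1a2b3c4d-maxIter: 50, ...}

    // Refitting with the recorded params yields an equivalent model.
    val refit = lrWithParams.fit(training, model.fittingParamMap)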
+ */ + val fittingParamMap: ParamMap +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala new file mode 100644 index 0000000000000..e545df1e37b9c --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import scala.collection.mutable.ListBuffer + +import org.apache.spark.Logging +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param.{Params, Param, ParamMap} +import org.apache.spark.sql.{SchemaRDD, StructType} + +/** + * :: AlphaComponent :: + * A stage in a pipeline, either an [[Estimator]] or a [[Transformer]]. + */ +@AlphaComponent +abstract class PipelineStage extends Serializable with Logging { + + /** + * Derives the output schema from the input schema and parameters. + */ + private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType + + /** + * Derives the output schema from the input schema and parameters, optionally with logging. + */ + protected def transformSchema( + schema: StructType, + paramMap: ParamMap, + logging: Boolean): StructType = { + if (logging) { + logDebug(s"Input schema: ${schema.json}") + } + val outputSchema = transformSchema(schema, paramMap) + if (logging) { + logDebug(s"Expected output schema: ${outputSchema.json}") + } + outputSchema + } +} + +/** + * :: AlphaComponent :: + * A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each + * of which is either an [[Estimator]] or a [[Transformer]]. When [[Pipeline.fit]] is called, the + * stages are executed in order. If a stage is an [[Estimator]], its [[Estimator.fit]] method will + * be called on the input dataset to fit a model. Then the model, which is a transformer, will be + * used to transform the dataset as the input to the next stage. If a stage is a [[Transformer]], + * its [[Transformer.transform]] method will be called to produce the dataset for the next stage. + * The fitted model from a [[Pipeline]] is an [[PipelineModel]], which consists of fitted models and + * transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as + * an identity transformer. + */ +@AlphaComponent +class Pipeline extends Estimator[PipelineModel] { + + /** param for pipeline stages */ + val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline") + def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this } + def getStages: Array[PipelineStage] = get(stages) + + /** + * Fits the pipeline to the input dataset with additional parameters. 
If a stage is an + * [[Estimator]], its [[Estimator.fit]] method will be called on the input dataset to fit a model. + * Then the model, which is a transformer, will be used to transform the dataset as the input to + * the next stage. If a stage is a [[Transformer]], its [[Transformer.transform]] method will be + * called to produce the dataset for the next stage. The fitted model from a [[Pipeline]] is an + * [[PipelineModel]], which consists of fitted models and transformers, corresponding to the + * pipeline stages. If there are no stages, the output model acts as an identity transformer. + * + * @param dataset input dataset + * @param paramMap parameter map + * @return fitted pipeline + */ + override def fit(dataset: SchemaRDD, paramMap: ParamMap): PipelineModel = { + transformSchema(dataset.schema, paramMap, logging = true) + val map = this.paramMap ++ paramMap + val theStages = map(stages) + // Search for the last estimator. + var indexOfLastEstimator = -1 + theStages.view.zipWithIndex.foreach { case (stage, index) => + stage match { + case _: Estimator[_] => + indexOfLastEstimator = index + case _ => + } + } + var curDataset = dataset + val transformers = ListBuffer.empty[Transformer] + theStages.view.zipWithIndex.foreach { case (stage, index) => + if (index <= indexOfLastEstimator) { + val transformer = stage match { + case estimator: Estimator[_] => + estimator.fit(curDataset, paramMap) + case t: Transformer => + t + case _ => + throw new IllegalArgumentException( + s"Do not support stage $stage of type ${stage.getClass}") + } + curDataset = transformer.transform(curDataset, paramMap) + transformers += transformer + } else { + transformers += stage.asInstanceOf[Transformer] + } + } + + new PipelineModel(this, map, transformers.toArray) + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + val theStages = map(stages) + require(theStages.toSet.size == theStages.size, + "Cannot have duplicate components in a pipeline.") + theStages.foldLeft(schema)((cur, stage) => stage.transformSchema(cur, paramMap)) + } +} + +/** + * :: AlphaComponent :: + * Represents a compiled pipeline. + */ +@AlphaComponent +class PipelineModel private[ml] ( + override val parent: Pipeline, + override val fittingParamMap: ParamMap, + private[ml] val stages: Array[Transformer]) + extends Model[PipelineModel] with Logging { + + /** + * Gets the model produced by the input estimator. Throws an NoSuchElementException is the input + * estimator does not exist in the pipeline. 
+ */ + def getModel[M <: Model[M]](stage: Estimator[M]): M = { + val matched = stages.filter { + case m: Model[_] => m.parent.eq(stage) + case _ => false + } + if (matched.isEmpty) { + throw new NoSuchElementException(s"Cannot find stage $stage from the pipeline.") + } else if (matched.size > 1) { + throw new IllegalStateException(s"Cannot have duplicate estimators in the sample pipeline.") + } else { + matched.head.asInstanceOf[M] + } + } + + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + transformSchema(dataset.schema, paramMap, logging = true) + stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, paramMap)) + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, paramMap)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala new file mode 100644 index 0000000000000..490e6609ad311 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml + +import scala.annotation.varargs +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.Logging +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param._ +import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.api.java.JavaSchemaRDD +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.catalyst.types._ + +/** + * :: AlphaComponent :: + * Abstract class for transformers that transform one dataset into another. + */ +@AlphaComponent +abstract class Transformer extends PipelineStage with Params { + + /** + * Transforms the dataset with optional parameters + * @param dataset input dataset + * @param paramPairs optional list of param pairs, overwrite embedded params + * @return transformed dataset + */ + @varargs + def transform(dataset: SchemaRDD, paramPairs: ParamPair[_]*): SchemaRDD = { + val map = new ParamMap() + paramPairs.foreach(map.put(_)) + transform(dataset, map) + } + + /** + * Transforms the dataset with provided parameter map as additional parameters. + * @param dataset input dataset + * @param paramMap additional parameters, overwrite embedded params + * @return transformed dataset + */ + def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD + + // Java-friendly versions of transform. + + /** + * Transforms the dataset with optional parameters. 
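Putting the stages together, here is a minimal text-classification sketch in the spirit of this API. Tokenizer, HashingTF and LogisticRegression are defined later in this patch; the LabeledDocument/Document case classes, the tiny datasets and the SparkContext/SQLContext setup (including `import sqlContext._`, as in the earlier sketch) are illustrative assumptions.

    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
    import org.apache.spark.sql.SchemaRDD

    case class LabeledDocument(id: Long, text: String, label: Double)
    case class Document(id: Long, text: String)

    val training: SchemaRDD = sc.parallelize(Seq(
      LabeledDocument(0L, "a b c d e spark", 1.0),
      LabeledDocument(1L, "b d", 0.0),
      LabeledDocument(2L, "spark f g h", 1.0),
      LabeledDocument(3L, "hadoop mapreduce", 0.0)))

    val tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words")
    val hashingTF = new HashingTF()
      .setNumFeatures(1000)
      .setInputCol(tokenizer.getOutputCol)
      .setOutputCol("features")
    val lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.01)
    val pipeline = new Pipeline()
      .setStages(Array(tokenizer, hashingTF, lr))

    // Estimator stages are fit and transformer stages are applied, in order.
    val model = pipeline.fit(training)

    // Individual fitted models can be pulled back out of the PipelineModel.
    val lrModel = model.getModel(lr)

    val test: SchemaRDD = sc.parallelize(Seq(
      Document(4L, "spark hadoop"),
      Document(5L, "a b c")))
    model.transform(test)
      .select('id, 'text, 'score, 'prediction)
      .collect()
      .foreach(println)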
+ * @param dataset input datset + * @param paramPairs optional list of param pairs, overwrite embedded params + * @return transformed dataset + */ + @varargs + def transform(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): JavaSchemaRDD = { + transform(dataset.schemaRDD, paramPairs: _*).toJavaSchemaRDD + } + + /** + * Transforms the dataset with provided parameter map as additional parameters. + * @param dataset input dataset + * @param paramMap additional parameters, overwrite embedded params + * @return transformed dataset + */ + def transform(dataset: JavaSchemaRDD, paramMap: ParamMap): JavaSchemaRDD = { + transform(dataset.schemaRDD, paramMap).toJavaSchemaRDD + } +} + +/** + * Abstract class for transformers that take one input column, apply transformation, and output the + * result as a new column. + */ +private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]] + extends Transformer with HasInputCol with HasOutputCol with Logging { + + def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T] + def setOutputCol(value: String): T = set(outputCol, value).asInstanceOf[T] + + /** + * Creates the transform function using the given param map. The input param map already takes + * account of the embedded param map. So the param values should be determined solely by the input + * param map. + */ + protected def createTransformFunc(paramMap: ParamMap): IN => OUT + + /** + * Validates the input type. Throw an exception if it is invalid. + */ + protected def validateInputType(inputType: DataType): Unit = {} + + override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + val inputType = schema(map(inputCol)).dataType + validateInputType(inputType) + if (schema.fieldNames.contains(map(outputCol))) { + throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.") + } + val output = ScalaReflection.schemaFor[OUT] + val outputFields = schema.fields :+ + StructField(map(outputCol), output.dataType, output.nullable) + StructType(outputFields) + } + + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + transformSchema(dataset.schema, paramMap, logging = true) + import dataset.sqlContext._ + val map = this.paramMap ++ paramMap + val udf = this.createTransformFunc(map) + dataset.select(Star(None), udf.call(map(inputCol).attr) as map(outputCol)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala new file mode 100644 index 0000000000000..85b8899636ca5 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
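The varargs overload makes one-off parameter overrides cheap. A small sketch using the HashingTF transformer defined later in this patch; `tokenized` stands for an assumed SchemaRDD with a "words" column.

    import org.apache.spark.ml.feature.HashingTF

    val hashingTF = new HashingTF()
      .setInputCol("words")
      .setOutputCol("features")

    // Uses the embedded params (numFeatures defaults to 2^18)...
    val features = hashingTF.transform(tokenized)

    // ...or overrides them for a single call without mutating the transformer.
    val smallFeatures = hashingTF.transform(tokenized, hashingTF.numFeatures -> 100)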
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml._ +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.storage.StorageLevel + +/** + * :: AlphaComponent :: + * Params for logistic regression. + */ +@AlphaComponent +private[classification] trait LogisticRegressionParams extends Params + with HasRegParam with HasMaxIter with HasLabelCol with HasThreshold with HasFeaturesCol + with HasScoreCol with HasPredictionCol { + + /** + * Validates and transforms the input schema with the provided param map. + * @param schema input schema + * @param paramMap additional parameters + * @param fitting whether this is in fitting + * @return output schema + */ + protected def validateAndTransformSchema( + schema: StructType, + paramMap: ParamMap, + fitting: Boolean): StructType = { + val map = this.paramMap ++ paramMap + val featuresType = schema(map(featuresCol)).dataType + // TODO: Support casting Array[Double] and Array[Float] to Vector. + require(featuresType.isInstanceOf[VectorUDT], + s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.") + if (fitting) { + val labelType = schema(map(labelCol)).dataType + require(labelType == DoubleType, + s"Cannot convert label column ${map(labelCol)} of type $labelType to a double column.") + } + val fieldNames = schema.fieldNames + require(!fieldNames.contains(map(scoreCol)), s"Score column ${map(scoreCol)} already exists.") + require(!fieldNames.contains(map(predictionCol)), + s"Prediction column ${map(predictionCol)} already exists.") + val outputFields = schema.fields ++ Seq( + StructField(map(scoreCol), DoubleType, false), + StructField(map(predictionCol), DoubleType, false)) + StructType(outputFields) + } +} + +/** + * Logistic regression. 
+ */ +class LogisticRegression extends Estimator[LogisticRegressionModel] with LogisticRegressionParams { + + setRegParam(0.1) + setMaxIter(100) + setThreshold(0.5) + + def setRegParam(value: Double): this.type = set(regParam, value) + def setMaxIter(value: Int): this.type = set(maxIter, value) + def setLabelCol(value: String): this.type = set(labelCol, value) + def setThreshold(value: Double): this.type = set(threshold, value) + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + def setScoreCol(value: String): this.type = set(scoreCol, value) + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = { + transformSchema(dataset.schema, paramMap, logging = true) + import dataset.sqlContext._ + val map = this.paramMap ++ paramMap + val instances = dataset.select(map(labelCol).attr, map(featuresCol).attr) + .map { case Row(label: Double, features: Vector) => + LabeledPoint(label, features) + }.persist(StorageLevel.MEMORY_AND_DISK) + val lr = new LogisticRegressionWithLBFGS + lr.optimizer + .setRegParam(map(regParam)) + .setNumIterations(map(maxIter)) + val lrm = new LogisticRegressionModel(this, map, lr.run(instances).weights) + instances.unpersist() + // copy model params + Params.inheritValues(map, this, lrm) + lrm + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap, fitting = true) + } +} + +/** + * :: AlphaComponent :: + * Model produced by [[LogisticRegression]]. + */ +@AlphaComponent +class LogisticRegressionModel private[ml] ( + override val parent: LogisticRegression, + override val fittingParamMap: ParamMap, + weights: Vector) + extends Model[LogisticRegressionModel] with LogisticRegressionParams { + + def setThreshold(value: Double): this.type = set(threshold, value) + def setFeaturesCol(value: String): this.type = set(featuresCol, value) + def setScoreCol(value: String): this.type = set(scoreCol, value) + def setPredictionCol(value: String): this.type = set(predictionCol, value) + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + validateAndTransformSchema(schema, paramMap, fitting = false) + } + + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + transformSchema(dataset.schema, paramMap, logging = true) + import dataset.sqlContext._ + val map = this.paramMap ++ paramMap + val score: Vector => Double = (v) => { + val margin = BLAS.dot(v, weights) + 1.0 / (1.0 + math.exp(-margin)) + } + val t = map(threshold) + val predict: Double => Double = (score) => { + if (score > t) 1.0 else 0.0 + } + dataset.select(Star(None), score.call(map(featuresCol).attr) as map(scoreCol)) + .select(Star(None), predict.call(map(scoreCol).attr) as map(predictionCol)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala new file mode 100644 index 0000000000000..0b0504e036ec9 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
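A sketch of what the fitted model appends to a dataset, using the column params declared above; `training` and `test` stand for assumed SchemaRDDs with "label"/"features" columns as in the first sketch.

    val lr = new LogisticRegression()
      .setMaxIter(100)
      .setRegParam(0.01)
      .setScoreCol("probability")   // rename the default "score" output column

    val model = lr.fit(training)

    // Appends "probability" (the sigmoid of the margin) and "prediction"
    // (scores thresholded at 0.5 unless overridden).
    val scored = model.transform(test)

    // The threshold can be changed on the model, or per call via a param pair.
    model.setThreshold(0.8)
    val strict = model.transform(test, model.threshold -> 0.3)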
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.evaluation + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml._ +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.sql.{DoubleType, Row, SchemaRDD} + +/** + * :: AlphaComponent :: + * Evaluator for binary classification, which expects two input columns: score and label. + */ +@AlphaComponent +class BinaryClassificationEvaluator extends Evaluator with Params + with HasScoreCol with HasLabelCol { + + /** param for metric name in evaluation */ + val metricName: Param[String] = new Param(this, "metricName", + "metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC")) + def getMetricName: String = get(metricName) + def setMetricName(value: String): this.type = set(metricName, value) + + def setScoreCol(value: String): this.type = set(scoreCol, value) + def setLabelCol(value: String): this.type = set(labelCol, value) + + override def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double = { + val map = this.paramMap ++ paramMap + + val schema = dataset.schema + val scoreType = schema(map(scoreCol)).dataType + require(scoreType == DoubleType, + s"Score column ${map(scoreCol)} must be double type but found $scoreType") + val labelType = schema(map(labelCol)).dataType + require(labelType == DoubleType, + s"Label column ${map(labelCol)} must be double type but found $labelType") + + import dataset.sqlContext._ + val scoreAndLabels = dataset.select(map(scoreCol).attr, map(labelCol).attr) + .map { case Row(score: Double, label: Double) => + (score, label) + } + val metrics = new BinaryClassificationMetrics(scoreAndLabels) + val metric = map(metricName) match { + case "areaUnderROC" => + metrics.areaUnderROC() + case "areaUnderPR" => + metrics.areaUnderPR() + case other => + throw new IllegalArgumentException(s"Does not support metric $other.") + } + metrics.unpersist() + metric + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala new file mode 100644 index 0000000000000..b98b1755a3584 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
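A small sketch of the evaluator in use; `scoredData` stands for an assumed SchemaRDD with double-valued "score" and "label" columns, e.g. the output of a fitted LogisticRegressionModel with default column names on labeled test data.

    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.param.ParamMap

    val evaluator = new BinaryClassificationEvaluator()
      .setMetricName("areaUnderROC")   // or "areaUnderPR"
      .setScoreCol("score")
      .setLabelCol("label")

    val auc = evaluator.evaluate(scoredData, ParamMap.empty)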
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.{IntParam, ParamMap} +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.Vector + +/** + * :: AlphaComponent :: + * Maps a sequence of terms to their term frequencies using the hashing trick. + */ +@AlphaComponent +class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] { + + /** number of features */ + val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18)) + def setNumFeatures(value: Int) = set(numFeatures, value) + def getNumFeatures: Int = get(numFeatures) + + override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = { + val hashingTF = new feature.HashingTF(paramMap(numFeatures)) + hashingTF.transform + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala new file mode 100644 index 0000000000000..896a6b83b67bf --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml._ +import org.apache.spark.ml.param._ +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.dsl._ + +/** + * Params for [[StandardScaler]] and [[StandardScalerModel]]. + */ +private[feature] trait StandardScalerParams extends Params with HasInputCol with HasOutputCol + +/** + * :: AlphaComponent :: + * Standardizes features by removing the mean and scaling to unit variance using column summary + * statistics on the samples in the training set. 
+ */ +@AlphaComponent +class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams { + + def setInputCol(value: String): this.type = set(inputCol, value) + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def fit(dataset: SchemaRDD, paramMap: ParamMap): StandardScalerModel = { + transformSchema(dataset.schema, paramMap, logging = true) + import dataset.sqlContext._ + val map = this.paramMap ++ paramMap + val input = dataset.select(map(inputCol).attr) + .map { case Row(v: Vector) => + v + } + val scaler = new feature.StandardScaler().fit(input) + val model = new StandardScalerModel(this, map, scaler) + Params.inheritValues(map, this, model) + model + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + val inputType = schema(map(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${map(inputCol)} must be a vector column") + require(!schema.fieldNames.contains(map(outputCol)), + s"Output column ${map(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false) + StructType(outputFields) + } +} + +/** + * :: AlphaComponent :: + * Model fitted by [[StandardScaler]]. + */ +@AlphaComponent +class StandardScalerModel private[ml] ( + override val parent: StandardScaler, + override val fittingParamMap: ParamMap, + scaler: feature.StandardScalerModel) + extends Model[StandardScalerModel] with StandardScalerParams { + + def setInputCol(value: String): this.type = set(inputCol, value) + def setOutputCol(value: String): this.type = set(outputCol, value) + + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + transformSchema(dataset.schema, paramMap, logging = true) + import dataset.sqlContext._ + val map = this.paramMap ++ paramMap + val scale: (Vector) => Vector = (v) => { + scaler.transform(v) + } + dataset.select(Star(None), scale.call(map(inputCol).attr) as map(outputCol)) + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + val inputType = schema(map(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${map(inputCol)} must be a vector column") + require(!schema.fieldNames.contains(map(outputCol)), + s"Output column ${map(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField(map(outputCol), new VectorUDT, false) + StructType(outputFields) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala new file mode 100644 index 0000000000000..0a6599b64c011 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
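A minimal sketch of the estimator/model pair above; `dataset` stands for an assumed SchemaRDD with a vector-valued "features" column.

    import org.apache.spark.ml.feature.StandardScaler

    val scaler = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("scaledFeatures")

    // fit() computes the column summary statistics; transform() appends the scaled vectors.
    val scalerModel = scaler.fit(dataset)
    val scaled = scalerModel.transform(dataset)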
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.sql.{DataType, StringType} + +/** + * :: AlphaComponent :: + * A tokenizer that converts the input string to lowercase and then splits it by white spaces. + */ +@AlphaComponent +class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] { + + protected override def createTransformFunc(paramMap: ParamMap): String => Seq[String] = { + _.toLowerCase.split("\\s") + } + + protected override def validateInputType(inputType: DataType): Unit = { + require(inputType == StringType, s"Input type must be string type but got $inputType.") + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/package-info.java b/mllib/src/main/scala/org/apache/spark/ml/package-info.java new file mode 100644 index 0000000000000..00d9c802e930d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/package-info.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly + * assemble and configure practical machine learning pipelines. + */ +@AlphaComponent +package org.apache.spark.ml; + +import org.apache.spark.annotation.AlphaComponent; diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala new file mode 100644 index 0000000000000..51cd48c90432a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +/** + * Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly + * assemble and configure practical machine learning pipelines. + */ +package object ml diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala new file mode 100644 index 0000000000000..8fd46aef4b99d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param + +import java.lang.reflect.Modifier + +import org.apache.spark.annotation.AlphaComponent + +import scala.annotation.varargs +import scala.collection.mutable + +import org.apache.spark.ml.Identifiable + +/** + * :: AlphaComponent :: + * A param with self-contained documentation and optionally default value. Primitive-typed param + * should use the specialized versions, which are more friendly to Java users. + * + * @param parent parent object + * @param name param name + * @param doc documentation + * @tparam T param value type + */ +@AlphaComponent +class Param[T] ( + val parent: Params, + val name: String, + val doc: String, + val defaultValue: Option[T] = None) + extends Serializable { + + /** + * Creates a param pair with the given value (for Java). + */ + def w(value: T): ParamPair[T] = this -> value + + /** + * Creates a param pair with the given value (for Scala). + */ + def ->(value: T): ParamPair[T] = ParamPair(this, value) + + override def toString: String = { + if (defaultValue.isDefined) { + s"$name: $doc (default: ${defaultValue.get})" + } else { + s"$name: $doc" + } + } +} + +// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ... + +/** Specialized version of [[Param[Double]]] for Java. */ +class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double] = None) + extends Param[Double](parent, name, doc, defaultValue) { + + override def w(value: Double): ParamPair[Double] = super.w(value) +} + +/** Specialized version of [[Param[Int]]] for Java. */ +class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int] = None) + extends Param[Int](parent, name, doc, defaultValue) { + + override def w(value: Int): ParamPair[Int] = super.w(value) +} + +/** Specialized version of [[Param[Float]]] for Java. 
*/ +class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float] = None) + extends Param[Float](parent, name, doc, defaultValue) { + + override def w(value: Float): ParamPair[Float] = super.w(value) +} + +/** Specialized version of [[Param[Long]]] for Java. */ +class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long] = None) + extends Param[Long](parent, name, doc, defaultValue) { + + override def w(value: Long): ParamPair[Long] = super.w(value) +} + +/** Specialized version of [[Param[Boolean]]] for Java. */ +class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean] = None) + extends Param[Boolean](parent, name, doc, defaultValue) { + + override def w(value: Boolean): ParamPair[Boolean] = super.w(value) +} + +/** + * A param amd its value. + */ +case class ParamPair[T](param: Param[T], value: T) + +/** + * :: AlphaComponent :: + * Trait for components that take parameters. This also provides an internal param map to store + * parameter values attached to the instance. + */ +@AlphaComponent +trait Params extends Identifiable with Serializable { + + /** Returns all params. */ + def params: Array[Param[_]] = { + val methods = this.getClass.getMethods + methods.filter { m => + Modifier.isPublic(m.getModifiers) && + classOf[Param[_]].isAssignableFrom(m.getReturnType) && + m.getParameterTypes.isEmpty + }.sortBy(_.getName) + .map(m => m.invoke(this).asInstanceOf[Param[_]]) + } + + /** + * Validates parameter values stored internally plus the input parameter map. + * Raises an exception if any parameter is invalid. + */ + def validate(paramMap: ParamMap): Unit = {} + + /** + * Validates parameter values stored internally. + * Raise an exception if any parameter value is invalid. + */ + def validate(): Unit = validate(ParamMap.empty) + + /** + * Returns the documentation of all params. + */ + def explainParams(): String = params.mkString("\n") + + /** Checks whether a param is explicitly set. */ + def isSet(param: Param[_]): Boolean = { + require(param.parent.eq(this)) + paramMap.contains(param) + } + + /** Gets a param by its name. */ + private[ml] def getParam(paramName: String): Param[Any] = { + val m = this.getClass.getMethod(paramName) + assert(Modifier.isPublic(m.getModifiers) && + classOf[Param[_]].isAssignableFrom(m.getReturnType) && + m.getParameterTypes.isEmpty) + m.invoke(this).asInstanceOf[Param[Any]] + } + + /** + * Sets a parameter in the embedded param map. + */ + private[ml] def set[T](param: Param[T], value: T): this.type = { + require(param.parent.eq(this)) + paramMap.put(param.asInstanceOf[Param[Any]], value) + this + } + + /** + * Gets the value of a parameter in the embedded param map. + */ + private[ml] def get[T](param: Param[T]): T = { + require(param.parent.eq(this)) + paramMap(param) + } + + /** + * Internal param map. + */ + protected val paramMap: ParamMap = ParamMap.empty +} + +private[ml] object Params { + + /** + * Copies parameter values from the parent estimator to the child model it produced. + * @param paramMap the param map that holds parameters of the parent + * @param parent the parent estimator + * @param child the child model + */ + def inheritValues[E <: Params, M <: E]( + paramMap: ParamMap, + parent: E, + child: M): Unit = { + parent.params.foreach { param => + if (paramMap.contains(param)) { + child.set(child.getParam(param.name), paramMap(param)) + } + } + } +} + +/** + * :: AlphaComponent :: + * A param to value map. 
+ */ +@AlphaComponent +class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) extends Serializable { + + /** + * Creates an empty param map. + */ + def this() = this(mutable.Map.empty[Param[Any], Any]) + + /** + * Puts a (param, value) pair (overwrites if the input param exists). + */ + def put[T](param: Param[T], value: T): this.type = { + map(param.asInstanceOf[Param[Any]]) = value + this + } + + /** + * Puts a list of param pairs (overwrites if the input params exists). + */ + def put(paramPairs: ParamPair[_]*): this.type = { + paramPairs.foreach { p => + put(p.param.asInstanceOf[Param[Any]], p.value) + } + this + } + + /** + * Optionally returns the value associated with a param or its default. + */ + def get[T](param: Param[T]): Option[T] = { + map.get(param.asInstanceOf[Param[Any]]) + .orElse(param.defaultValue) + .asInstanceOf[Option[T]] + } + + /** + * Gets the value of the input param or its default value if it does not exist. + * Raises a NoSuchElementException if there is no value associated with the input param. + */ + def apply[T](param: Param[T]): T = { + val value = get(param) + if (value.isDefined) { + value.get + } else { + throw new NoSuchElementException(s"Cannot find param ${param.name}.") + } + } + + /** + * Checks whether a parameter is explicitly specified. + */ + def contains(param: Param[_]): Boolean = { + map.contains(param.asInstanceOf[Param[Any]]) + } + + /** + * Filters this param map for the given parent. + */ + def filter(parent: Params): ParamMap = { + val filtered = map.filterKeys(_.parent == parent) + new ParamMap(filtered.asInstanceOf[mutable.Map[Param[Any], Any]]) + } + + /** + * Make a copy of this param map. + */ + def copy: ParamMap = new ParamMap(map.clone()) + + override def toString: String = { + map.map { case (param, value) => + s"\t${param.parent.uid}-${param.name}: $value" + }.mkString("{\n", ",\n", "\n}") + } + + /** + * Returns a new param map that contains parameters in this map and the given map, + * where the latter overwrites this if there exists conflicts. + */ + def ++(other: ParamMap): ParamMap = { + new ParamMap(this.map ++ other.map) + } + + + /** + * Adds all parameters from the input param map into this param map. + */ + def ++=(other: ParamMap): this.type = { + this.map ++= other.map + this + } + + /** + * Converts this param map to a sequence of param pairs. + */ + def toSeq: Seq[ParamPair[_]] = { + map.toSeq.map { case (param, value) => + ParamPair(param, value) + } + } +} + +object ParamMap { + + /** + * Returns an empty param map. + */ + def empty: ParamMap = new ParamMap() + + /** + * Constructs a param map by specifying its entries. + */ + @varargs + def apply(paramPairs: ParamPair[_]*): ParamMap = { + new ParamMap().put(paramPairs: _*) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala new file mode 100644 index 0000000000000..ef141d3eb2b06 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
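A short sketch of working with ParamMap directly, reusing a LogisticRegression instance (defined elsewhere in this patch) as the parent of the params:

    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.param.ParamMap

    val lr = new LogisticRegression()

    val a = ParamMap(lr.maxIter -> 20)
    a.put(lr.maxIter, 30)                          // overwrites the earlier value
    val b = ParamMap(lr.regParam -> 0.1, lr.threshold -> 0.4)

    val combined = a ++ b                          // on conflicts, the right-hand map wins
    println(combined(lr.maxIter))                  // 30
    println(combined.get(lr.featuresCol))          // Some(features): falls back to the param's default
    println(combined.contains(lr.featuresCol))     // false: defaults do not count as explicitly set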
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.param + +private[ml] trait HasRegParam extends Params { + /** param for regularization parameter */ + val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter") + def getRegParam: Double = get(regParam) +} + +private[ml] trait HasMaxIter extends Params { + /** param for max number of iterations */ + val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations") + def getMaxIter: Int = get(maxIter) +} + +private[ml] trait HasFeaturesCol extends Params { + /** param for features column name */ + val featuresCol: Param[String] = + new Param(this, "featuresCol", "features column name", Some("features")) + def getFeaturesCol: String = get(featuresCol) +} + +private[ml] trait HasLabelCol extends Params { + /** param for label column name */ + val labelCol: Param[String] = new Param(this, "labelCol", "label column name", Some("label")) + def getLabelCol: String = get(labelCol) +} + +private[ml] trait HasScoreCol extends Params { + /** param for score column name */ + val scoreCol: Param[String] = new Param(this, "scoreCol", "score column name", Some("score")) + def getScoreCol: String = get(scoreCol) +} + +private[ml] trait HasPredictionCol extends Params { + /** param for prediction column name */ + val predictionCol: Param[String] = + new Param(this, "predictionCol", "prediction column name", Some("prediction")) + def getPredictionCol: String = get(predictionCol) +} + +private[ml] trait HasThreshold extends Params { + /** param for threshold in (binary) prediction */ + val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction") + def getThreshold: Double = get(threshold) +} + +private[ml] trait HasInputCol extends Params { + /** param for input column name */ + val inputCol: Param[String] = new Param(this, "inputCol", "input column name") + def getInputCol: String = get(inputCol) +} + +private[ml] trait HasOutputCol extends Params { + /** param for output column name */ + val outputCol: Param[String] = new Param(this, "outputCol", "output column name") + def getOutputCol: String = get(outputCol) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala new file mode 100644 index 0000000000000..194b9bfd9a9e6 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tuning + +import com.github.fommil.netlib.F2jBLAS + +import org.apache.spark.Logging +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml._ +import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.sql.{SchemaRDD, StructType} + +/** + * Params for [[CrossValidator]] and [[CrossValidatorModel]]. + */ +private[ml] trait CrossValidatorParams extends Params { + /** param for the estimator to be cross-validated */ + val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection") + def getEstimator: Estimator[_] = get(estimator) + + /** param for estimator param maps */ + val estimatorParamMaps: Param[Array[ParamMap]] = + new Param(this, "estimatorParamMaps", "param maps for the estimator") + def getEstimatorParamMaps: Array[ParamMap] = get(estimatorParamMaps) + + /** param for the evaluator for selection */ + val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection") + def getEvaluator: Evaluator = get(evaluator) + + /** param for number of folds for cross validation */ + val numFolds: IntParam = + new IntParam(this, "numFolds", "number of folds for cross validation", Some(3)) + def getNumFolds: Int = get(numFolds) +} + +/** + * :: AlphaComponent :: + * K-fold cross validation. 
+ */ +@AlphaComponent +class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorParams with Logging { + + private val f2jBLAS = new F2jBLAS + + def setEstimator(value: Estimator[_]): this.type = set(estimator, value) + def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value) + def setEvaluator(value: Evaluator): this.type = set(evaluator, value) + def setNumFolds(value: Int): this.type = set(numFolds, value) + + override def fit(dataset: SchemaRDD, paramMap: ParamMap): CrossValidatorModel = { + val map = this.paramMap ++ paramMap + val schema = dataset.schema + transformSchema(dataset.schema, paramMap, logging = true) + val sqlCtx = dataset.sqlContext + val est = map(estimator) + val eval = map(evaluator) + val epm = map(estimatorParamMaps) + val numModels = epm.size + val metrics = new Array[Double](epm.size) + val splits = MLUtils.kFold(dataset, map(numFolds), 0) + splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => + val trainingDataset = sqlCtx.applySchema(training, schema).cache() + val validationDataset = sqlCtx.applySchema(validation, schema).cache() + // multi-model training + logDebug(s"Train split $splitIndex with multiple sets of parameters.") + val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]] + var i = 0 + while (i < numModels) { + val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)), map) + logDebug(s"Got metric $metric for model trained with ${epm(i)}.") + metrics(i) += metric + i += 1 + } + } + f2jBLAS.dscal(numModels, 1.0 / map(numFolds), metrics, 1) + logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") + val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1) + logInfo(s"Best set of parameters:\n${epm(bestIndex)}") + logInfo(s"Best cross-validation metric: $bestMetric.") + val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] + val cvModel = new CrossValidatorModel(this, map, bestModel) + Params.inheritValues(map, this, cvModel) + cvModel + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + val map = this.paramMap ++ paramMap + map(estimator).transformSchema(schema, paramMap) + } +} + +/** + * :: AlphaComponent :: + * Model from k-fold cross validation. + */ +@AlphaComponent +class CrossValidatorModel private[ml] ( + override val parent: CrossValidator, + override val fittingParamMap: ParamMap, + val bestModel: Model[_]) + extends Model[CrossValidatorModel] with CrossValidatorParams { + + override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + bestModel.transform(dataset, paramMap) + } + + private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { + bestModel.transformSchema(schema, paramMap) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala new file mode 100644 index 0000000000000..dafe73d82c00a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ParamGridBuilder.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
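An end-to-end model-selection sketch tying CrossValidator to the ParamGridBuilder defined next; `pipeline`, `hashingTF`, `lr`, `training` and `test` are assumed from the earlier pipeline sketch.

    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
    import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

    val paramGrid = new ParamGridBuilder()
      .addGrid(hashingTF.numFeatures, Array(100, 1000, 10000))
      .addGrid(lr.regParam, Array(0.1, 0.01))
      .build()                                 // 3 x 2 = 6 ParamMaps

    val cv = new CrossValidator()
      .setEstimator(pipeline)
      .setEstimatorParamMaps(paramGrid)
      .setEvaluator(new BinaryClassificationEvaluator)
      .setNumFolds(3)

    // Fits the pipeline once per fold per ParamMap, averages the metric across folds,
    // then refits the best setting on the full dataset.
    val cvModel = cv.fit(training)
    cvModel.transform(test)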
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tuning + +import scala.annotation.varargs +import scala.collection.mutable + +import org.apache.spark.annotation.AlphaComponent +import org.apache.spark.ml.param._ + +/** + * :: AlphaComponent :: + * Builder for a param grid used in grid search-based model selection. + */ +@AlphaComponent +class ParamGridBuilder { + + private val paramGrid = mutable.Map.empty[Param[_], Iterable[_]] + + /** + * Sets the given parameters in this grid to fixed values. + */ + def baseOn(paramMap: ParamMap): this.type = { + baseOn(paramMap.toSeq: _*) + this + } + + /** + * Sets the given parameters in this grid to fixed values. + */ + @varargs + def baseOn(paramPairs: ParamPair[_]*): this.type = { + paramPairs.foreach { p => + addGrid(p.param.asInstanceOf[Param[Any]], Seq(p.value)) + } + this + } + + /** + * Adds a param with multiple values (overwrites if the input param exists). + */ + def addGrid[T](param: Param[T], values: Iterable[T]): this.type = { + paramGrid.put(param, values) + this + } + + // specialized versions of addGrid for Java. + + /** + * Adds a double param with multiple values. + */ + def addGrid(param: DoubleParam, values: Array[Double]): this.type = { + addGrid[Double](param, values) + } + + /** + * Adds a int param with multiple values. + */ + def addGrid(param: IntParam, values: Array[Int]): this.type = { + addGrid[Int](param, values) + } + + /** + * Adds a float param with multiple values. + */ + def addGrid(param: FloatParam, values: Array[Float]): this.type = { + addGrid[Float](param, values) + } + + /** + * Adds a long param with multiple values. + */ + def addGrid(param: LongParam, values: Array[Long]): this.type = { + addGrid[Long](param, values) + } + + /** + * Adds a boolean param with true and false. + */ + def addGrid(param: BooleanParam): this.type = { + addGrid[Boolean](param, Array(true, false)) + } + + /** + * Builds and returns all combinations of parameters specified by the param grid. 
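A sketch of the cross-product expansion described here, reusing the `lr` estimator assumed in the earlier sketches; baseOn pins a parameter to a single value in every generated map.

    import org.apache.spark.ml.tuning.ParamGridBuilder

    val grid = new ParamGridBuilder()
      .baseOn(lr.maxIter -> 50)                      // fixed in every ParamMap
      .addGrid(lr.regParam, Array(0.1, 0.01, 0.001))
      .addGrid(lr.threshold, Array(0.4, 0.6))
      .build()

    println(grid.length)                             // 3 x 2 = 6
    grid.foreach(println)                            // each map also carries maxIter: 50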
+ */ + def build(): Array[ParamMap] = { + var paramMaps = Array(new ParamMap) + paramGrid.foreach { case (param, values) => + val newParamMaps = values.flatMap { v => + paramMaps.map(_.copy.put(param.asInstanceOf[Param[Any]], v)) + } + paramMaps = newParamMaps.toArray + } + paramMaps + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index e9f41758581e3..9f20cd5d00dcd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -18,6 +18,8 @@ package org.apache.spark.mllib.api.python import java.io.OutputStream +import java.nio.{ByteBuffer, ByteOrder} +import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConverters._ import scala.language.existentials @@ -27,24 +29,27 @@ import net.razorvine.pickle._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.api.python.SerDeUtil import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ -import org.apache.spark.mllib.optimization._ +import org.apache.spark.mllib.feature._ import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.random.{RandomRDDs => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} -import org.apache.spark.mllib.tree.DecisionTree -import org.apache.spark.mllib.tree.impurity._ -import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames +import org.apache.spark.mllib.stat.test.ChiSqTestResult +import org.apache.spark.mllib.tree.{RandomForest, DecisionTree} +import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} +import org.apache.spark.mllib.tree.impurity._ +import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils - /** * :: DeveloperApi :: * The Java stubs necessary for the Python mllib bindings. @@ -69,15 +74,29 @@ class PythonMLLibAPI extends Serializable { private def trainRegressionModel( learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel], data: JavaRDD[LabeledPoint], - initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = { - val initialWeights = SerDe.loads(initialWeightsBA).asInstanceOf[Vector] - // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD. 
- learner.disableUncachedWarning() - val model = learner.run(data.rdd, initialWeights) - val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(SerDe.dumps(model.weights)) - ret.add(model.intercept: java.lang.Double) - ret + initialWeights: Vector): JList[Object] = { + try { + val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights) + List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava + } finally { + data.rdd.unpersist(blocking = false) + } + } + + /** + * Return the Updater from string + */ + def getUpdaterFromString(regType: String): Updater = { + if (regType == "l2") { + new SquaredL2Updater + } else if (regType == "l1") { + new L1Updater + } else if (regType == null || regType == "none") { + new SimpleUpdater + } else { + throw new IllegalArgumentException("Invalid value for 'regType' parameter." + + " Can only be initialized using the following string values: ['l1', 'l2', None].") + } } /** @@ -88,10 +107,10 @@ class PythonMLLibAPI extends Serializable { numIterations: Int, stepSize: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte], + initialWeights: Vector, regParam: Double, regType: String, - intercept: Boolean): java.util.List[java.lang.Object] = { + intercept: Boolean): JList[Object] = { val lrAlg = new LinearRegressionWithSGD() lrAlg.setIntercept(intercept) lrAlg.optimizer @@ -99,18 +118,11 @@ class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) - if (regType == "l2") { - lrAlg.optimizer.setUpdater(new SquaredL2Updater) - } else if (regType == "l1") { - lrAlg.optimizer.setUpdater(new L1Updater) - } else if (regType != "none") { - throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." 
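[Editor's note: the try/finally caching convention introduced in trainRegressionModel above recurs in the other stubs; this standalone sketch shows the pattern, with `data`, `learner`, and `initialWeights` standing in for the stub arguments.]

    import org.apache.spark.storage.StorageLevel

    val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK)
    try {
      learner.run(cached, initialWeights)   // iterate over the cached input
    } finally {
      cached.unpersist(blocking = false)    // always release the cache, even on failure
    }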
- + " Can only be initialized using the following string values: [l1, l2, none].") - } + lrAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( lrAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -122,7 +134,7 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + initialWeights: Vector): JList[Object] = { val lassoAlg = new LassoWithSGD() lassoAlg.optimizer .setNumIterations(numIterations) @@ -132,7 +144,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( lassoAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -144,7 +156,7 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + initialWeights: Vector): JList[Object] = { val ridgeAlg = new RidgeRegressionWithSGD() ridgeAlg.optimizer .setNumIterations(numIterations) @@ -154,7 +166,7 @@ class PythonMLLibAPI extends Serializable { trainRegressionModel( ridgeAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -166,9 +178,9 @@ class PythonMLLibAPI extends Serializable { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte], + initialWeights: Vector, regType: String, - intercept: Boolean): java.util.List[java.lang.Object] = { + intercept: Boolean): JList[Object] = { val SVMAlg = new SVMWithSGD() SVMAlg.setIntercept(intercept) SVMAlg.optimizer @@ -176,18 +188,11 @@ class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) - if (regType == "l2") { - SVMAlg.optimizer.setUpdater(new SquaredL2Updater) - } else if (regType == "l1") { - SVMAlg.optimizer.setUpdater(new L1Updater) - } else if (regType != "none") { - throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." - + " Can only be initialized using the following string values: [l1, l2, none].") - } + SVMAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( SVMAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -198,10 +203,10 @@ class PythonMLLibAPI extends Serializable { numIterations: Int, stepSize: Double, miniBatchFraction: Double, - initialWeightsBA: Array[Byte], + initialWeights: Vector, regParam: Double, regType: String, - intercept: Boolean): java.util.List[java.lang.Object] = { + intercept: Boolean): JList[Object] = { val LogRegAlg = new LogisticRegressionWithSGD() LogRegAlg.setIntercept(intercept) LogRegAlg.optimizer @@ -209,18 +214,37 @@ class PythonMLLibAPI extends Serializable { .setRegParam(regParam) .setStepSize(stepSize) .setMiniBatchFraction(miniBatchFraction) - if (regType == "l2") { - LogRegAlg.optimizer.setUpdater(new SquaredL2Updater) - } else if (regType == "l1") { - LogRegAlg.optimizer.setUpdater(new L1Updater) - } else if (regType != "none") { - throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter." 
- + " Can only be initialized using the following string values: [l1, l2, none].") - } + LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType)) + trainRegressionModel( + LogRegAlg, + data, + initialWeights) + } + + /** + * Java stub for Python mllib LogisticRegressionWithLBFGS.train() + */ + def trainLogisticRegressionModelWithLBFGS( + data: JavaRDD[LabeledPoint], + numIterations: Int, + initialWeights: Vector, + regParam: Double, + regType: String, + intercept: Boolean, + corrections: Int, + tolerance: Double): JList[Object] = { + val LogRegAlg = new LogisticRegressionWithLBFGS() + LogRegAlg.setIntercept(intercept) + LogRegAlg.optimizer + .setNumIterations(numIterations) + .setRegParam(regParam) + .setNumCorrections(corrections) + .setConvergenceTol(tolerance) + LogRegAlg.optimizer.setUpdater(getUpdaterFromString(regType)) trainRegressionModel( LogRegAlg, data, - initialWeightsBA) + initialWeights) } /** @@ -228,13 +252,10 @@ class PythonMLLibAPI extends Serializable { */ def trainNaiveBayes( data: JavaRDD[LabeledPoint], - lambda: Double): java.util.List[java.lang.Object] = { + lambda: Double): JList[Object] = { val model = NaiveBayes.train(data.rdd, lambda) - val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(Vectors.dense(model.labels)) - ret.add(Vectors.dense(model.pi)) - ret.add(model.theta) - ret + List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta). + map(_.asInstanceOf[Object]).asJava } /** @@ -251,9 +272,26 @@ class PythonMLLibAPI extends Serializable { .setMaxIterations(maxIterations) .setRuns(runs) .setInitializationMode(initializationMode) - // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD. - .disableUncachedWarning() - return kMeansAlg.run(data.rdd) + try { + kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK)) + } finally { + data.rdd.unpersist(blocking = false) + } + } + + /** + * A Wrapper of MatrixFactorizationModel to provide helpfer method for Python + */ + private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorizationModel) + extends MatrixFactorizationModel(model.rank, model.userFeatures, model.productFeatures) { + + def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] = + predict(SerDe.asTupleRDD(userAndProducts.rdd)) + + def getUserFeatures = SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]]) + + def getProductFeatures = SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]]) + } /** @@ -263,12 +301,25 @@ class PythonMLLibAPI extends Serializable { * the Py4J documentation. 
*/ def trainALSModel( - ratings: JavaRDD[Rating], + ratingsJRDD: JavaRDD[Rating], rank: Int, iterations: Int, lambda: Double, - blocks: Int): MatrixFactorizationModel = { - ALS.train(ratings.rdd, rank, iterations, lambda, blocks) + blocks: Int, + nonnegative: Boolean, + seed: java.lang.Long): MatrixFactorizationModel = { + + val als = new ALS() + .setRank(rank) + .setIterations(iterations) + .setLambda(lambda) + .setBlocks(blocks) + .setNonnegative(nonnegative) + + if (seed != null) als.setSeed(seed) + + val model = als.run(ratingsJRDD.rdd) + new MatrixFactorizationModelWrapper(model) } /** @@ -283,8 +334,121 @@ class PythonMLLibAPI extends Serializable { iterations: Int, lambda: Double, blocks: Int, - alpha: Double): MatrixFactorizationModel = { - ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha) + alpha: Double, + nonnegative: Boolean, + seed: java.lang.Long): MatrixFactorizationModel = { + + val als = new ALS() + .setImplicitPrefs(true) + .setRank(rank) + .setIterations(iterations) + .setLambda(lambda) + .setBlocks(blocks) + .setAlpha(alpha) + .setNonnegative(nonnegative) + + if (seed != null) als.setSeed(seed) + + val model = als.run(ratingsJRDD.rdd) + new MatrixFactorizationModelWrapper(model) + } + + /** + * Java stub for Normalizer.transform() + */ + def normalizeVector(p: Double, vector: Vector): Vector = { + new Normalizer(p).transform(vector) + } + + /** + * Java stub for Normalizer.transform() + */ + def normalizeVector(p: Double, rdd: JavaRDD[Vector]): JavaRDD[Vector] = { + new Normalizer(p).transform(rdd) + } + + /** + * Java stub for IDF.fit(). This stub returns a + * handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on + * exit; see the Py4J documentation. + */ + def fitStandardScaler( + withMean: Boolean, + withStd: Boolean, + data: JavaRDD[Vector]): StandardScalerModel = { + new StandardScaler(withMean, withStd).fit(data.rdd) + } + + /** + * Java stub for IDF.fit(). This stub returns a + * handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on + * exit; see the Py4J documentation. + */ + def fitIDF(minDocFreq: Int, dataset: JavaRDD[Vector]): IDFModel = { + new IDF(minDocFreq).fit(dataset) + } + + /** + * Java stub for Python mllib Word2Vec fit(). This stub returns a + * handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on + * exit; see the Py4J documentation. 
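[Editor's note: the two ALS stubs above now build the estimator explicitly instead of calling the static train helpers; the equivalent direct usage is roughly the following, assuming `ratings: RDD[Rating]`.]

    import org.apache.spark.mllib.recommendation.{ALS, Rating}

    val als = new ALS()
      .setRank(10)
      .setIterations(10)
      .setLambda(0.01)
      .setBlocks(-1)          // -1 lets ALS choose the number of blocks automatically
      .setNonnegative(true)
      .setSeed(42L)
    val model = als.run(ratings)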
+ * @param dataJRDD input JavaRDD + * @param vectorSize size of vector + * @param learningRate initial learning rate + * @param numPartitions number of partitions + * @param numIterations number of iterations + * @param seed initial seed for random generator + * @return A handle to java Word2VecModelWrapper instance at python side + */ + def trainWord2Vec( + dataJRDD: JavaRDD[java.util.ArrayList[String]], + vectorSize: Int, + learningRate: Double, + numPartitions: Int, + numIterations: Int, + seed: Long): Word2VecModelWrapper = { + val word2vec = new Word2Vec() + .setVectorSize(vectorSize) + .setLearningRate(learningRate) + .setNumPartitions(numPartitions) + .setNumIterations(numIterations) + .setSeed(seed) + try { + val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER)) + new Word2VecModelWrapper(model) + } finally { + dataJRDD.rdd.unpersist(blocking = false) + } + } + + private[python] class Word2VecModelWrapper(model: Word2VecModel) { + def transform(word: String): Vector = { + model.transform(word) + } + + /** + * Transforms an RDD of words to its vector representation + * @param rdd an RDD of words + * @return an RDD of vector representations of words + */ + def transform(rdd: JavaRDD[String]): JavaRDD[Vector] = { + rdd.rdd.map(model.transform) + } + + def findSynonyms(word: String, num: Int): JList[Object] = { + val vec = transform(word) + findSynonyms(vec, num) + } + + def findSynonyms(vector: Vector, num: Int): JList[Object] = { + val result = model.findSynonyms(vector, num) + val similarity = Vectors.dense(result.map(_._2)) + val words = result.map(_._1) + List(words, similarity).map(_.asInstanceOf[Object]).asJava + } } /** @@ -293,13 +457,13 @@ class PythonMLLibAPI extends Serializable { * Extra care needs to be taken in the Python code to ensure it gets freed on exit; * see the Py4J documentation. * @param data Training data - * @param categoricalFeaturesInfoJMap Categorical features info, as Java map + * @param categoricalFeaturesInfo Categorical features info, as Java map */ def trainDecisionTreeModel( data: JavaRDD[LabeledPoint], algoStr: String, numClasses: Int, - categoricalFeaturesInfoJMap: java.util.Map[Int, Int], + categoricalFeaturesInfo: JMap[Int, Int], impurityStr: String, maxDepth: Int, maxBins: Int, @@ -315,11 +479,53 @@ class PythonMLLibAPI extends Serializable { maxDepth = maxDepth, numClassesForClassification = numClasses, maxBins = maxBins, - categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap, + categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap, minInstancesPerNode = minInstancesPerNode, minInfoGain = minInfoGain) + try { + DecisionTree.train(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), strategy) + } finally { + data.rdd.unpersist(blocking = false) + } + } - DecisionTree.train(data.rdd, strategy) + /** + * Java stub for Python mllib RandomForest.train(). + * This stub returns a handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on exit; + * see the Py4J documentation. 
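[Editor's note: the trainWord2Vec stub above mirrors the Scala-side Word2Vec API; a hedged usage sketch, assuming `sentences: RDD[Seq[String]]` holds tokenized lines of text.]

    import org.apache.spark.mllib.feature.Word2Vec

    val word2vec = new Word2Vec()
      .setVectorSize(100)
      .setLearningRate(0.025)
      .setNumPartitions(1)
      .setNumIterations(1)
      .setSeed(42L)
    val model = word2vec.fit(sentences)
    // Nearest neighbours in the embedding space, as (word, cosineSimilarity) pairs.
    val synonyms = model.findSynonyms("spark", 5)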
+ */ + def trainRandomForestModel( + data: JavaRDD[LabeledPoint], + algoStr: String, + numClasses: Int, + categoricalFeaturesInfo: JMap[Int, Int], + numTrees: Int, + featureSubsetStrategy: String, + impurityStr: String, + maxDepth: Int, + maxBins: Int, + seed: Int): RandomForestModel = { + + val algo = Algo.fromString(algoStr) + val impurity = Impurities.fromString(impurityStr) + val strategy = new Strategy( + algo = algo, + impurity = impurity, + maxDepth = maxDepth, + numClassesForClassification = numClasses, + maxBins = maxBins, + categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap) + val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK) + try { + if (algo == Algo.Classification) { + RandomForest.trainClassifier(cached, strategy, numTrees, featureSubsetStrategy, seed) + } else { + RandomForest.trainRegressor(cached, strategy, numTrees, featureSubsetStrategy, seed) + } + } finally { + cached.unpersist(blocking = false) + } } /** @@ -346,6 +552,31 @@ class PythonMLLibAPI extends Serializable { Statistics.corr(x.rdd, y.rdd, getCorrNameOrDefault(method)) } + /** + * Java stub for mllib Statistics.chiSqTest() + */ + def chiSqTest(observed: Vector, expected: Vector): ChiSqTestResult = { + if (expected == null) { + Statistics.chiSqTest(observed) + } else { + Statistics.chiSqTest(observed, expected) + } + } + + /** + * Java stub for mllib Statistics.chiSqTest(observed: Matrix) + */ + def chiSqTest(observed: Matrix): ChiSqTestResult = { + Statistics.chiSqTest(observed) + } + + /** + * Java stub for mllib Statistics.chiSqTest(RDD[LabelPoint]) + */ + def chiSqTest(data: JavaRDD[LabeledPoint]): Array[ChiSqTestResult] = { + Statistics.chiSqTest(data.rdd) + } + // used by the corr methods to retrieve the name of the correlation method passed in via pyspark private def getCorrNameOrDefault(method: String) = { if (method == null) CorrelationNames.defaultCorrName else method @@ -454,6 +685,7 @@ class PythonMLLibAPI extends Serializable { private[spark] object SerDe extends Serializable { val PYSPARK_PACKAGE = "pyspark.mllib" + val LATIN1 = "ISO-8859-1" /** * Base class used for pickle @@ -475,7 +707,7 @@ private[spark] object SerDe extends Serializable { def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = { if (obj == this) { out.write(Opcodes.GLOBAL) - out.write((module + "\n" + name + "\n").getBytes()) + out.write((module + "\n" + name + "\n").getBytes) } else { pickler.save(this) // it will be memorized by Pickler saveState(obj, out, pickler) @@ -487,7 +719,7 @@ private[spark] object SerDe extends Serializable { if (objects.length == 0 || objects.length > 3) { out.write(Opcodes.MARK) } - objects.foreach(pickler.save(_)) + objects.foreach(pickler.save) val code = objects.length match { case 1 => Opcodes.TUPLE1 case 2 => Opcodes.TUPLE2 @@ -505,7 +737,16 @@ private[spark] object SerDe extends Serializable { def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { val vector: DenseVector = obj.asInstanceOf[DenseVector] - saveObjects(out, pickler, vector.toArray) + val bytes = new Array[Byte](8 * vector.size) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + val db = bb.asDoubleBuffer() + db.put(vector.values) + + out.write(Opcodes.BINSTRING) + out.write(PickleUtils.integer_to_bytes(bytes.length)) + out.write(bytes) + out.write(Opcodes.TUPLE1) } def construct(args: Array[Object]): Object = { @@ -513,7 +754,13 @@ private[spark] object SerDe extends Serializable { if (args.length != 1) { throw new PickleException("should be 1") } - new 
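[Editor's note: the random forest stub above calls the same RandomForest entry points available to Scala users; a sketch under the assumption that `trainingData: RDD[LabeledPoint]` holds a binary classification dataset.]

    import org.apache.spark.mllib.tree.RandomForest
    import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
    import org.apache.spark.mllib.tree.impurity.Gini

    val strategy = new Strategy(
      algo = Algo.Classification,
      impurity = Gini,
      maxDepth = 5,
      numClassesForClassification = 2,
      maxBins = 32,
      categoricalFeaturesInfo = Map.empty[Int, Int])
    val model = RandomForest.trainClassifier(
      trainingData, strategy, numTrees = 50, featureSubsetStrategy = "auto", seed = 42)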
DenseVector(args(0).asInstanceOf[Array[Double]]) + val bytes = args(0).asInstanceOf[String].getBytes(LATIN1) + val bb = ByteBuffer.wrap(bytes, 0, bytes.length) + bb.order(ByteOrder.nativeOrder()) + val db = bb.asDoubleBuffer() + val ans = new Array[Double](bytes.length / 8) + db.get(ans) + Vectors.dense(ans) } } @@ -522,15 +769,30 @@ private[spark] object SerDe extends Serializable { def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { val m: DenseMatrix = obj.asInstanceOf[DenseMatrix] - saveObjects(out, pickler, m.numRows, m.numCols, m.values) + val bytes = new Array[Byte](8 * m.values.size) + val order = ByteOrder.nativeOrder() + ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values) + + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(m.numRows)) + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(m.numCols)) + out.write(Opcodes.BINSTRING) + out.write(PickleUtils.integer_to_bytes(bytes.length)) + out.write(bytes) + out.write(Opcodes.TUPLE3) } def construct(args: Array[Object]): Object = { if (args.length != 3) { throw new PickleException("should be 3") } - new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], - args(2).asInstanceOf[Array[Double]]) + val bytes = args(2).asInstanceOf[String].getBytes(LATIN1) + val n = bytes.length / 8 + val values = new Array[Double](n) + val order = ByteOrder.nativeOrder() + ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().get(values) + new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values) } } @@ -539,15 +801,40 @@ private[spark] object SerDe extends Serializable { def saveState(obj: Object, out: OutputStream, pickler: Pickler) = { val v: SparseVector = obj.asInstanceOf[SparseVector] - saveObjects(out, pickler, v.size, v.indices, v.values) + val n = v.indices.size + val indiceBytes = new Array[Byte](4 * n) + val order = ByteOrder.nativeOrder() + ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().put(v.indices) + val valueBytes = new Array[Byte](8 * n) + ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().put(v.values) + + out.write(Opcodes.BININT) + out.write(PickleUtils.integer_to_bytes(v.size)) + out.write(Opcodes.BINSTRING) + out.write(PickleUtils.integer_to_bytes(indiceBytes.length)) + out.write(indiceBytes) + out.write(Opcodes.BINSTRING) + out.write(PickleUtils.integer_to_bytes(valueBytes.length)) + out.write(valueBytes) + out.write(Opcodes.TUPLE3) } def construct(args: Array[Object]): Object = { if (args.length != 3) { throw new PickleException("should be 3") } - new SparseVector(args(0).asInstanceOf[Int], args(1).asInstanceOf[Array[Int]], - args(2).asInstanceOf[Array[Double]]) + val size = args(0).asInstanceOf[Int] + val indiceBytes = args(1).asInstanceOf[String].getBytes(LATIN1) + val valueBytes = args(2).asInstanceOf[String].getBytes(LATIN1) + val n = indiceBytes.length / 4 + val indices = new Array[Int](n) + val values = new Array[Double](n) + if (n > 0) { + val order = ByteOrder.nativeOrder() + ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().get(indices) + ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().get(values) + } + new SparseVector(size, indices, values) } } @@ -584,13 +871,24 @@ private[spark] object SerDe extends Serializable { } } + var initialized = false + // This should be called before trying to serialize any above classes + // In cluster mode, this should be put in the closure def initialize(): Unit = { - new DenseVectorPickler().register() - new DenseMatrixPickler().register() - new 
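[Editor's note: the vector and matrix picklers above avoid element-by-element pickling by blitting the underlying double arrays into a byte string in native byte order; the core trick, shown in isolation.]

    import java.nio.{ByteBuffer, ByteOrder}

    val values = Array(1.0, 2.0, 3.0)

    // Encode: 8 bytes per double, native byte order (what saveState writes).
    val bytes = new Array[Byte](8 * values.length)
    ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()).asDoubleBuffer().put(values)

    // Decode: construct() receives the bytes back as an ISO-8859-1 string and unpacks them.
    val decoded = new Array[Double](bytes.length / 8)
    ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()).asDoubleBuffer().get(decoded)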
SparseVectorPickler().register() - new LabeledPointPickler().register() - new RatingPickler().register() + SerDeUtil.initialize() + synchronized { + if (!initialized) { + new DenseVectorPickler().register() + new DenseMatrixPickler().register() + new SparseVectorPickler().register() + new LabeledPointPickler().register() + new RatingPickler().register() + initialized = true + } + } } + // will not called in Executor automatically + initialize() def dumps(obj: AnyRef): Array[Byte] = { new Pickler().dumps(obj) @@ -604,4 +902,38 @@ private[spark] object SerDe extends Serializable { def asTupleRDD(rdd: RDD[Array[Any]]): RDD[(Int, Int)] = { rdd.map(x => (x(0).asInstanceOf[Int], x(1).asInstanceOf[Int])) } + + /* convert RDD[Tuple2[,]] to RDD[Array[Any]] */ + def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = { + rdd.map(x => Array(x._1, x._2)) + } + + /** + * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by + * PySpark. + */ + def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = { + jRDD.rdd.mapPartitions { iter => + initialize() // let it called in executor + new SerDeUtil.AutoBatchedPickler(iter) + } + } + + /** + * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark. + */ + def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = { + pyRDD.rdd.mapPartitions { iter => + initialize() // let it called in executor + val unpickle = new Unpickler + iter.flatMap { row => + val obj = unpickle.loads(row) + if (batched) { + obj.asInstanceOf[JArrayList[_]].asScala + } else { + Seq(obj) + } + } + }.toJavaRDD() + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/package.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/package.scala index 87bdc8558aaf5..c67a6d3ae6cce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/package.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/package.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.api /** - * Internal support for MLLib Python API. + * Internal support for MLlib Python API. * * @see [[org.apache.spark.mllib.api.python.PythonMLLibAPI]] */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 84d3c7cebd7c8..94d757bc317ab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -64,16 +64,17 @@ class LogisticRegressionModel ( val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept val score = 1.0 / (1.0 + math.exp(-margin)) threshold match { - case Some(t) => if (score < t) 0.0 else 1.0 + case Some(t) => if (score > t) 1.0 else 0.0 case None => score } } } /** - * Train a classification model for Logistic Regression using Stochastic Gradient Descent. - * NOTE: Labels used in Logistic Regression should be {0, 1} - * + * Train a classification model for Logistic Regression using Stochastic Gradient Descent. By + * default L2 regularization is used, which can be changed via + * [[LogisticRegressionWithSGD.optimizer]]. + * NOTE: Labels used in Logistic Regression should be {0, 1}. * Using [[LogisticRegressionWithLBFGS]] is recommended over this. 
*/ class LogisticRegressionWithSGD private ( @@ -93,9 +94,10 @@ class LogisticRegressionWithSGD private ( override protected val validators = List(DataValidators.binaryLabelValidator) /** - * Construct a LogisticRegression object with default parameters + * Construct a LogisticRegression object with default parameters: {stepSize: 1.0, + * numIterations: 100, regParm: 0.01, miniBatchFraction: 1.0}. */ - def this() = this(1.0, 100, 0.0, 1.0) + def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { new LogisticRegressionModel(weights, intercept) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 80f8a1b2f1e84..dd514ff8a37f2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -65,14 +65,15 @@ class SVMModel ( intercept: Double) = { val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept threshold match { - case Some(t) => if (margin < t) 0.0 else 1.0 + case Some(t) => if (margin > t) 1.0 else 0.0 case None => margin } } } /** - * Train a Support Vector Machine (SVM) using Stochastic Gradient Descent. + * Train a Support Vector Machine (SVM) using Stochastic Gradient Descent. By default L2 + * regularization is used, which can be changed via [[SVMWithSGD.optimizer]]. * NOTE: Labels used in SVM should be {0, 1}. */ class SVMWithSGD private ( @@ -92,9 +93,10 @@ class SVMWithSGD private ( override protected val validators = List(DataValidators.binaryLabelValidator) /** - * Construct a SVM object with default parameters + * Construct a SVM object with default parameters: {stepSize: 1.0, numIterations: 100, + * regParm: 0.01, miniBatchFraction: 1.0}. */ - def this() = this(1.0, 100, 1.0, 1.0) + def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { new SVMModel(weights, intercept) @@ -185,6 +187,6 @@ object SVMWithSGD { * @return a SVMModel which has the weights and offset from training. */ def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = { - train(input, numIterations, 1.0, 1.0, 1.0) + train(input, numIterations, 1.0, 0.01, 1.0) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 7443f232ec3e7..34ea0de706f08 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -113,22 +113,13 @@ class KMeans private ( this } - /** Whether a warning should be logged if the input RDD is uncached. */ - private var warnOnUncachedInput = true - - /** Disable warnings about uncached input. */ - private[spark] def disableUncachedWarning(): this.type = { - warnOnUncachedInput = false - this - } - /** * Train a K-means model on the given set of points; `data` should be cached for high * performance, because this is an iterative algorithm. 
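[Editor's note: with the threshold comparison above fixed to treat scores strictly greater than the threshold as the positive class, and regParam now defaulting to 0.01, a typical SGD training run looks like this sketch; `training: RDD[LabeledPoint]` with {0, 1} labels is assumed, and setThreshold/clearThreshold are the existing model setters.]

    import org.apache.spark.mllib.classification.SVMWithSGD

    val model = SVMWithSGD.train(training, numIterations = 100)  // regParam now defaults to 0.01
    model.setThreshold(0.0)   // margins strictly above 0.0 map to label 1.0
    // model.clearThreshold() would make predict() return raw margins instead of 0/1 labels.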
*/ def run(data: RDD[Vector]): KMeansModel = { - if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) { + if (data.getStorageLevel == StorageLevel.NONE) { logWarning("The input data is not directly cached, which may hurt performance if its" + " parent RDDs are also uncached.") } @@ -143,7 +134,7 @@ class KMeans private ( norms.unpersist() // Warn at the end of the run as well, for increased visibility. - if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) { + if (data.getStorageLevel == StorageLevel.NONE) { logWarning("The input data was not directly cached, which may hurt performance if its" + " parent RDDs are also uncached.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala new file mode 100644 index 0000000000000..6189dce9b27da --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -0,0 +1,268 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import scala.reflect.ClassTag + +import org.apache.spark.Logging +import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.StreamingContext._ +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.Utils +import org.apache.spark.util.random.XORShiftRandom + +/** + * :: DeveloperApi :: + * StreamingKMeansModel extends MLlib's KMeansModel for streaming + * algorithms, so it can keep track of a continuously updated weight + * associated with each cluster, and also update the model by + * doing a single iteration of the standard k-means algorithm. + * + * The update algorithm uses the "mini-batch" KMeans rule, + * generalized to incorporate forgetfullness (i.e. decay). + * The update rule (for each cluster) is: + * + * c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t] + * n_t+t = n_t * a + m_t + * + * Where c_t is the previously estimated centroid for that cluster, + * n_t is the number of points assigned to it thus far, x_t is the centroid + * estimated on the current batch, and m_t is the number of points assigned + * to that centroid in the current batch. + * + * The decay factor 'a' scales the contribution of the clusters as estimated thus far, + * by applying a as a discount weighting on the current point when evaluating + * new incoming data. If a=1, all batches are weighted equally. If a=0, new centroids + * are determined entirely by recent data. Lower values correspond to + * more forgetting. 
+ * + * Decay can optionally be specified by a half life and associated + * time unit. The time unit can either be a batch of data or a single + * data point. Considering data arrived at time t, the half life h is defined + * such that at time t + h the discount applied to the data from t is 0.5. + * The definition remains the same whether the time unit is given + * as batches or points. + * + */ +@DeveloperApi +class StreamingKMeansModel( + override val clusterCenters: Array[Vector], + val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging { + + /** Perform a k-means update on a batch of data. */ + def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = { + + // find nearest cluster to each point + val closest = data.map(point => (this.predict(point), (point, 1L))) + + // get sums and counts for updating each cluster + val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => { + BLAS.axpy(1.0, p2._1, p1._1) + (p1._1, p1._2 + p2._2) + } + val dim = clusterCenters(0).size + val pointStats: Array[(Int, (Vector, Long))] = closest + .aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs) + .collect() + + val discount = timeUnit match { + case StreamingKMeans.BATCHES => decayFactor + case StreamingKMeans.POINTS => + val numNewPoints = pointStats.view.map { case (_, (_, n)) => + n + }.sum + math.pow(decayFactor, numNewPoints) + } + + // apply discount to weights + BLAS.scal(discount, Vectors.dense(clusterWeights)) + + // implement update rule + pointStats.foreach { case (label, (sum, count)) => + val centroid = clusterCenters(label) + + val updatedWeight = clusterWeights(label) + count + val lambda = count / math.max(updatedWeight, 1e-16) + + clusterWeights(label) = updatedWeight + BLAS.scal(1.0 - lambda, centroid) + BLAS.axpy(lambda / count, sum, centroid) + + // display the updated cluster centers + val display = clusterCenters(label).size match { + case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...") + case _ => centroid.toArray.mkString("[", ",", "]") + } + + logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display") + } + + // Check whether the smallest cluster is dying. If so, split the largest cluster. + val weightsWithIndex = clusterWeights.view.zipWithIndex + val (maxWeight, largest) = weightsWithIndex.maxBy(_._1) + val (minWeight, smallest) = weightsWithIndex.minBy(_._1) + if (minWeight < 1e-8 * maxWeight) { + logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.") + val weight = (maxWeight + minWeight) / 2.0 + clusterWeights(largest) = weight + clusterWeights(smallest) = weight + val largestClusterCenter = clusterCenters(largest) + val smallestClusterCenter = clusterCenters(smallest) + var j = 0 + while (j < dim) { + val x = largestClusterCenter(j) + val p = 1e-14 * math.max(math.abs(x), 1.0) + largestClusterCenter.toBreeze(j) = x + p + smallestClusterCenter.toBreeze(j) = x - p + j += 1 + } + } + + this + } +} + +/** + * :: DeveloperApi :: + * StreamingKMeans provides methods for configuring a + * streaming k-means analysis, training the model on streaming, + * and using the model to make predictions on streaming data. + * See KMeansModel for details on algorithm and update rules. 
+ * + * Use a builder pattern to construct a streaming k-means analysis + * in an application, like: + * + * val model = new StreamingKMeans() + * .setDecayFactor(0.5) + * .setK(3) + * .setRandomCenters(5, 100.0) + * .trainOn(DStream) + */ +@DeveloperApi +class StreamingKMeans( + var k: Int, + var decayFactor: Double, + var timeUnit: String) extends Logging { + + def this() = this(2, 1.0, StreamingKMeans.BATCHES) + + protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null) + + /** Set the number of clusters. */ + def setK(k: Int): this.type = { + this.k = k + this + } + + /** Set the decay factor directly (for forgetful algorithms). */ + def setDecayFactor(a: Double): this.type = { + this.decayFactor = decayFactor + this + } + + /** Set the half life and time unit ("batches" or "points") for forgetful algorithms. */ + def setHalfLife(halfLife: Double, timeUnit: String): this.type = { + if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) { + throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit) + } + this.decayFactor = math.exp(math.log(0.5) / halfLife) + logInfo("Setting decay factor to: %g ".format (this.decayFactor)) + this.timeUnit = timeUnit + this + } + + /** Specify initial centers directly. */ + def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = { + model = new StreamingKMeansModel(centers, weights) + this + } + + /** + * Initialize random centers, requiring only the number of dimensions. + * + * @param dim Number of dimensions + * @param weight Weight for each center + * @param seed Random seed + */ + def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = { + val random = new XORShiftRandom(seed) + val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian()))) + val weights = Array.fill(k)(weight) + model = new StreamingKMeansModel(centers, weights) + this + } + + /** Return the latest model. */ + def latestModel(): StreamingKMeansModel = { + model + } + + /** + * Update the clustering model by training on batches of data from a DStream. + * This operation registers a DStream for training the model, + * checks whether the cluster centers have been initialized, + * and updates the model using each batch of data from the stream. + * + * @param data DStream containing vector data + */ + def trainOn(data: DStream[Vector]) { + assertInitialized() + data.foreachRDD { (rdd, time) => + model = model.update(rdd, decayFactor, timeUnit) + } + } + + /** + * Use the clustering model to make predictions on batches of data from a DStream. + * + * @param data DStream containing vector data + * @return DStream containing predictions + */ + def predictOn(data: DStream[Vector]): DStream[Int] = { + assertInitialized() + data.map(model.predict) + } + + /** + * Use the model to make predictions on the values of a DStream and carry over its keys. + * + * @param data DStream containing (key, feature vector) pairs + * @tparam K key type + * @return DStream containing the input keys and the predictions as values + */ + def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = { + assertInitialized() + data.mapValues(model.predict) + } + + /** Check whether cluster centers have been initialized. 
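[Editor's note: putting trainOn and predictOnValues above together, a streaming usage sketch; it assumes an existing StreamingContext `ssc`, a `trainingStream: DStream[Vector]`, and a `testStream: DStream[LabeledPoint]`.]

    import org.apache.spark.mllib.clustering.StreamingKMeans

    val model = new StreamingKMeans()
      .setK(3)
      .setHalfLife(5.0, "batches")             // decayFactor = exp(ln(0.5) / 5) ≈ 0.87
      .setRandomCenters(dim = 5, weight = 0.0)

    model.trainOn(trainingStream)
    model.predictOnValues(testStream.map(lp => (lp.label, lp.features))).print()

    ssc.start()
    ssc.awaitTermination()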
*/ + private[this] def assertInitialized(): Unit = { + if (model.clusterCenters == null) { + throw new IllegalStateException( + "Initial cluster centers must be set before starting predictions") + } + } +} + +private[clustering] object StreamingKMeans { + final val BATCHES = "batches" + final val POINTS = "points" +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala index 7858ec602483f..078fbfbe4f0e1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala @@ -43,7 +43,7 @@ private[evaluation] object AreaUnderCurve { */ def of(curve: RDD[(Double, Double)]): Double = { curve.sliding(2).aggregate(0.0)( - seqOp = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points), + seqOp = (auc: Double, points: Array[(Double, Double)]) => auc + trapezoid(points), combOp = _ + _ ) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala new file mode 100644 index 0000000000000..ea10bde5fa252 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation + +import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext._ + +/** + * Evaluator for multilabel classification. + * @param predictionAndLabels an RDD of (predictions, labels) pairs, + * both are non-null Arrays, each with unique elements. 
+ */ +class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) { + + private lazy val numDocs: Long = predictionAndLabels.count() + + private lazy val numLabels: Long = predictionAndLabels.flatMap { case (_, labels) => + labels}.distinct().count() + + /** + * Returns subset accuracy + * (for equal sets of labels) + */ + lazy val subsetAccuracy: Double = predictionAndLabels.filter { case (predictions, labels) => + predictions.deep == labels.deep + }.count().toDouble / numDocs + + /** + * Returns accuracy + */ + lazy val accuracy: Double = predictionAndLabels.map { case (predictions, labels) => + labels.intersect(predictions).size.toDouble / + (labels.size + predictions.size - labels.intersect(predictions).size)}.sum / numDocs + + + /** + * Returns Hamming-loss + */ + lazy val hammingLoss: Double = predictionAndLabels.map { case (predictions, labels) => + labels.size + predictions.size - 2 * labels.intersect(predictions).size + }.sum / (numDocs * numLabels) + + /** + * Returns document-based precision averaged by the number of documents + */ + lazy val precision: Double = predictionAndLabels.map { case (predictions, labels) => + if (predictions.size > 0) { + predictions.intersect(labels).size.toDouble / predictions.size + } else { + 0 + } + }.sum / numDocs + + /** + * Returns document-based recall averaged by the number of documents + */ + lazy val recall: Double = predictionAndLabels.map { case (predictions, labels) => + labels.intersect(predictions).size.toDouble / labels.size + }.sum / numDocs + + /** + * Returns document-based f1-measure averaged by the number of documents + */ + lazy val f1Measure: Double = predictionAndLabels.map { case (predictions, labels) => + 2.0 * predictions.intersect(labels).size / (predictions.size + labels.size) + }.sum / numDocs + + private lazy val tpPerClass = predictionAndLabels.flatMap { case (predictions, labels) => + predictions.intersect(labels) + }.countByValue() + + private lazy val fpPerClass = predictionAndLabels.flatMap { case (predictions, labels) => + predictions.diff(labels) + }.countByValue() + + private lazy val fnPerClass = predictionAndLabels.flatMap { case(predictions, labels) => + labels.diff(predictions) + }.countByValue() + + /** + * Returns precision for a given label (category) + * @param label the label. + */ + def precision(label: Double) = { + val tp = tpPerClass(label) + val fp = fpPerClass.getOrElse(label, 0L) + if (tp + fp == 0) 0 else tp.toDouble / (tp + fp) + } + + /** + * Returns recall for a given label (category) + * @param label the label. + */ + def recall(label: Double) = { + val tp = tpPerClass(label) + val fn = fnPerClass.getOrElse(label, 0L) + if (tp + fn == 0) 0 else tp.toDouble / (tp + fn) + } + + /** + * Returns f1-measure for a given label (category) + * @param label the label. 
+ */ + def f1Measure(label: Double) = { + val p = precision(label) + val r = recall(label) + if((p + r) == 0) 0 else 2 * p * r / (p + r) + } + + private lazy val sumTp = tpPerClass.foldLeft(0L) { case (sum, (_, tp)) => sum + tp } + private lazy val sumFpClass = fpPerClass.foldLeft(0L) { case (sum, (_, fp)) => sum + fp } + private lazy val sumFnClass = fnPerClass.foldLeft(0L) { case (sum, (_, fn)) => sum + fn } + + /** + * Returns micro-averaged label-based precision + * (equals to micro-averaged document-based precision) + */ + lazy val microPrecision = { + val sumFp = fpPerClass.foldLeft(0L){ case(cum, (_, fp)) => cum + fp} + sumTp.toDouble / (sumTp + sumFp) + } + + /** + * Returns micro-averaged label-based recall + * (equals to micro-averaged document-based recall) + */ + lazy val microRecall = { + val sumFn = fnPerClass.foldLeft(0.0){ case(cum, (_, fn)) => cum + fn} + sumTp.toDouble / (sumTp + sumFn) + } + + /** + * Returns micro-averaged label-based f1-measure + * (equals to micro-averaged document-based f1-measure) + */ + lazy val microF1Measure = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass) + + /** + * Returns the sequence of labels in ascending order + */ + lazy val labels: Array[Double] = tpPerClass.keys.toArray.sorted +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala new file mode 100644 index 0000000000000..93a7353e2c070 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation + +import scala.reflect.ClassTag + +import org.apache.spark.Logging +import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.Experimental +import org.apache.spark.rdd.RDD + +/** + * ::Experimental:: + * Evaluator for ranking algorithms. + * + * @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs. + */ +@Experimental +class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]) + extends Logging with Serializable { + + /** + * Compute the average precision of all the queries, truncated at ranking position k. + * + * If for a query, the ranking algorithm returns n (n < k) results, the precision value will be + * computed as #(relevant items retrieved) / k. This formula also applies when the size of the + * ground truth set is less than k. + * + * If a query has an empty ground truth set, zero will be used as precision together with + * a log warning. + * + * See the following paper for detail: + * + * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. 
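[Editor's note: a usage sketch of the MultilabelMetrics evaluator defined above, assuming an existing SparkContext `sc`; labels are encoded as Double class indices.]

    import org.apache.spark.mllib.evaluation.MultilabelMetrics

    val predictionAndLabels = sc.parallelize(Seq(
      (Array(0.0, 1.0), Array(0.0, 2.0)),   // (predicted label set, true label set)
      (Array(0.0, 2.0), Array(0.0, 1.0)),
      (Array.empty[Double], Array(0.0))))
    val metrics = new MultilabelMetrics(predictionAndLabels)
    println(metrics.hammingLoss)
    println(metrics.microF1Measure)
    println(metrics.precision(0.0))   // per-label precision for class 0.0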
Kekalainen + * + * @param k the position to compute the truncated precision, must be positive + * @return the average precision at the first k ranking positions + */ + def precisionAt(k: Int): Double = { + require(k > 0, "ranking position k should be positive") + predictionAndLabels.map { case (pred, lab) => + val labSet = lab.toSet + + if (labSet.nonEmpty) { + val n = math.min(pred.length, k) + var i = 0 + var cnt = 0 + while (i < n) { + if (labSet.contains(pred(i))) { + cnt += 1 + } + i += 1 + } + cnt.toDouble / k + } else { + logWarning("Empty ground truth set, check input data") + 0.0 + } + }.mean + } + + /** + * Returns the mean average precision (MAP) of all the queries. + * If a query has an empty ground truth set, the average precision will be zero and a log + * warining is generated. + */ + lazy val meanAveragePrecision: Double = { + predictionAndLabels.map { case (pred, lab) => + val labSet = lab.toSet + + if (labSet.nonEmpty) { + var i = 0 + var cnt = 0 + var precSum = 0.0 + val n = pred.length + while (i < n) { + if (labSet.contains(pred(i))) { + cnt += 1 + precSum += cnt.toDouble / (i + 1) + } + i += 1 + } + precSum / labSet.size + } else { + logWarning("Empty ground truth set, check input data") + 0.0 + } + }.mean + } + + /** + * Compute the average NDCG value of all the queries, truncated at ranking position k. + * The discounted cumulative gain at position k is computed as: + * sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1), + * and the NDCG is obtained by dividing the DCG value on the ground truth set. In the current + * implementation, the relevance value is binary. + + * If a query has an empty ground truth set, zero will be used as ndcg together with + * a log warning. + * + * See the following paper for detail: + * + * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen + * + * @param k the position to compute the truncated ndcg, must be positive + * @return the average ndcg at the first k ranking positions + */ + def ndcgAt(k: Int): Double = { + require(k > 0, "ranking position k should be positive") + predictionAndLabels.map { case (pred, lab) => + val labSet = lab.toSet + + if (labSet.nonEmpty) { + val labSetSize = labSet.size + val n = math.min(math.max(pred.length, labSetSize), k) + var maxDcg = 0.0 + var dcg = 0.0 + var i = 0 + while (i < n) { + val gain = 1.0 / math.log(i + 2) + if (labSet.contains(pred(i))) { + dcg += gain + } + if (i < labSetSize) { + maxDcg += gain + } + i += 1 + } + dcg / maxDcg + } else { + logWarning("Empty ground truth set, check input data") + 0.0 + } + }.mean + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala new file mode 100644 index 0000000000000..693117d820580 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
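[Editor's note: a corresponding sketch for the RankingMetrics evaluator above, again assuming a SparkContext `sc`; each pair is (predicted ranking, ground-truth set).]

    import org.apache.spark.mllib.evaluation.RankingMetrics

    val predictionAndLabels = sc.parallelize(Seq(
      (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)),
      (Array(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array(1, 2, 3))))
    val metrics = new RankingMetrics(predictionAndLabels)
    println(metrics.precisionAt(5))        // truncated precision
    println(metrics.meanAveragePrecision)  // MAP
    println(metrics.ndcgAt(5))             // truncated NDCG with binary relevance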
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation + +import org.apache.spark.annotation.Experimental +import org.apache.spark.rdd.RDD +import org.apache.spark.Logging +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} + +/** + * :: Experimental :: + * Evaluator for regression. + * + * @param predictionAndObservations an RDD of (prediction, observation) pairs. + */ +@Experimental +class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging { + + /** + * Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors. + */ + private lazy val summary: MultivariateStatisticalSummary = { + val summary: MultivariateStatisticalSummary = predictionAndObservations.map { + case (prediction, observation) => Vectors.dense(observation, observation - prediction) + }.aggregate(new MultivariateOnlineSummarizer())( + (summary, v) => summary.add(v), + (sum1, sum2) => sum1.merge(sum2) + ) + summary + } + + /** + * Returns the explained variance regression score. + * explainedVariance = 1 - variance(y - \hat{y}) / variance(y) + * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]] + */ + def explainedVariance: Double = { + 1 - summary.variance(1) / summary.variance(0) + } + + /** + * Returns the mean absolute error, which is a risk function corresponding to the + * expected value of the absolute error loss or l1-norm loss. + */ + def meanAbsoluteError: Double = { + summary.normL1(1) / summary.count + } + + /** + * Returns the mean squared error, which is a risk function corresponding to the + * expected value of the squared error loss or quadratic loss. + */ + def meanSquaredError: Double = { + val rmse = summary.normL2(1) / math.sqrt(summary.count) + rmse * rmse + } + + /** + * Returns the root mean squared error, which is defined as the square root of + * the mean squared error. + */ + def rootMeanSquaredError: Double = { + summary.normL2(1) / math.sqrt(summary.count) + } + + /** + * Returns R^2^, the coefficient of determination. + * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + */ + def r2: Double = { + 1 - math.pow(summary.normL2(1), 2) / (summary.variance(0) * (summary.count - 1)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala index 562663ad36b40..be3319d60ce25 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala @@ -24,26 +24,43 @@ private[evaluation] trait BinaryClassificationMetricComputer extends Serializabl def apply(c: BinaryConfusionMatrix): Double } -/** Precision. */ +/** Precision. Defined as 1.0 when there are no positive examples. 
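[Editor's note: RegressionMetrics above can be exercised the same way; a sketch with predictions paired against observed values, assuming a SparkContext `sc`.]

    import org.apache.spark.mllib.evaluation.RegressionMetrics

    val predictionAndObservations = sc.parallelize(
      Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)))
    val metrics = new RegressionMetrics(predictionAndObservations)
    println(metrics.meanAbsoluteError)
    println(metrics.rootMeanSquaredError)
    println(metrics.r2)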
*/ private[evaluation] object Precision extends BinaryClassificationMetricComputer { - override def apply(c: BinaryConfusionMatrix): Double = - c.numTruePositives.toDouble / (c.numTruePositives + c.numFalsePositives) + override def apply(c: BinaryConfusionMatrix): Double = { + val totalPositives = c.numTruePositives + c.numFalsePositives + if (totalPositives == 0) { + 1.0 + } else { + c.numTruePositives.toDouble / totalPositives + } + } } -/** False positive rate. */ +/** False positive rate. Defined as 0.0 when there are no negative examples. */ private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer { - override def apply(c: BinaryConfusionMatrix): Double = - c.numFalsePositives.toDouble / c.numNegatives + override def apply(c: BinaryConfusionMatrix): Double = { + if (c.numNegatives == 0) { + 0.0 + } else { + c.numFalsePositives.toDouble / c.numNegatives + } + } } -/** Recall. */ +/** Recall. Defined as 0.0 when there are no positive examples. */ private[evaluation] object Recall extends BinaryClassificationMetricComputer { - override def apply(c: BinaryConfusionMatrix): Double = - c.numTruePositives.toDouble / c.numPositives + override def apply(c: BinaryConfusionMatrix): Double = { + if (c.numPositives == 0) { + 0.0 + } else { + c.numTruePositives.toDouble / c.numPositives + } + } } /** - * F-Measure. + * F-Measure. Defined as 0 if both precision and recall are 0. EG in the case that all examples + * are false positives. * @param beta the beta constant in F-Measure * @see http://en.wikipedia.org/wiki/F1_score */ @@ -52,6 +69,10 @@ private[evaluation] case class FMeasure(beta: Double) extends BinaryClassificati override def apply(c: BinaryConfusionMatrix): Double = { val precision = Precision(c) val recall = Recall(c) - (1.0 + beta2) * (precision * recall) / (beta2 * precision + recall) + if (precision + recall == 0) { + 0.0 + } else { + (1.0 + beta2) * (precision * recall) / (beta2 * precision + recall) + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala index 3afb47767281c..a9c2e23717896 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala @@ -17,16 +17,16 @@ package org.apache.spark.mllib.feature -import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} +import breeze.linalg.{norm => brzNorm} import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} /** * :: Experimental :: * Normalizes samples individually to unit L^p^ norm * - * For any 1 <= p < Double.PositiveInfinity, normalizes samples using + * For any 1 <= p < Double.PositiveInfinity, normalizes samples using * sum(abs(vector).^p^)^(1/p)^ as norm. * * For p = Double.PositiveInfinity, max(abs(vector)) will be used as norm for normalization. @@ -47,22 +47,31 @@ class Normalizer(p: Double) extends VectorTransformer { * @return normalized vector. If the norm of the input is zero, it will return the input vector. */ override def transform(vector: Vector): Vector = { - var norm = vector.toBreeze.norm(p) + val norm = brzNorm(vector.toBreeze, p) if (norm != 0.0) { // For dense vector, we've to allocate new memory for new output vector. // However, for sparse vector, the `index` array will not be changed, // so we can re-use it to save memory. 
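[Editor's note: the zero-denominator conventions adopted above (precision 1.0 with no predicted positives, recall and F-measure 0.0) reduce to this small helper, shown only as a sketch of the formula; `fMeasure` is a hypothetical name, not part of the API.]

    // F_beta = (1 + beta^2) * P * R / (beta^2 * P + R), defined as 0 when P = R = 0.
    def fMeasure(precision: Double, recall: Double, beta: Double = 1.0): Double = {
      val beta2 = beta * beta
      if (precision + recall == 0) 0.0
      else (1.0 + beta2) * precision * recall / (beta2 * precision + recall)
    }

    fMeasure(precision = 0.8, recall = 0.5)   // ≈ 0.615, the harmonic mean for beta = 1
    fMeasure(0.0, 0.0)                        // 0.0 rather than NaN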
- vector.toBreeze match { - case dv: BDV[Double] => Vectors.fromBreeze(dv :/ norm) - case sv: BSV[Double] => - val output = new BSV[Double](sv.index, sv.data.clone(), sv.length) + vector match { + case dv: DenseVector => + val values = dv.values.clone() + val size = values.size var i = 0 - while (i < output.data.length) { - output.data(i) /= norm + while (i < size) { + values(i) /= norm i += 1 } - Vectors.fromBreeze(output) + Vectors.dense(values) + case sv: SparseVector => + val values = sv.values.clone() + val nnz = values.size + var i = 0 + while (i < nnz) { + values(i) /= norm + i += 1 + } + Vectors.sparse(sv.size, sv.indices, values) case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index 4dfd1f0ab8134..8c4c5db5258d5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -17,11 +17,9 @@ package org.apache.spark.mllib.feature -import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} - import org.apache.spark.Logging import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD @@ -77,8 +75,8 @@ class StandardScalerModel private[mllib] ( require(mean.size == variance.size) - private lazy val factor: BDV[Double] = { - val f = BDV.zeros[Double](variance.size) + private lazy val factor: Array[Double] = { + val f = Array.ofDim[Double](variance.size) var i = 0 while (i < f.size) { f(i) = if (variance(i) != 0.0) 1.0 / math.sqrt(variance(i)) else 0.0 @@ -87,6 +85,11 @@ class StandardScalerModel private[mllib] ( f } + // Since `shift` will be only used in `withMean` branch, we have it as + // `lazy val` so it will be evaluated in that branch. Note that we don't + // want to create this array multiple times in `transform` function. + private lazy val shift: Array[Double] = mean.toArray + /** * Applies standardization transformation on a vector. * @@ -97,30 +100,57 @@ class StandardScalerModel private[mllib] ( override def transform(vector: Vector): Vector = { require(mean.size == vector.size) if (withMean) { - vector.toBreeze match { - case dv: BDV[Double] => - val output = vector.toBreeze.copy - var i = 0 - while (i < output.length) { - output(i) = (output(i) - mean(i)) * (if (withStd) factor(i) else 1.0) - i += 1 + // By default, Scala generates Java methods for member variables. So every time when + // the member variables are accessed, `invokespecial` will be called which is expensive. + // This can be avoid by having a local reference of `shift`. + val localShift = shift + vector match { + case dv: DenseVector => + val values = dv.values.clone() + val size = values.size + if (withStd) { + // Having a local reference of `factor` to avoid overhead as the comment before. 
+ val localFactor = factor + var i = 0 + while (i < size) { + values(i) = (values(i) - localShift(i)) * localFactor(i) + i += 1 + } + } else { + var i = 0 + while (i < size) { + values(i) -= localShift(i) + i += 1 + } } - Vectors.fromBreeze(output) + Vectors.dense(values) case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } } else if (withStd) { - vector.toBreeze match { - case dv: BDV[Double] => Vectors.fromBreeze(dv :* factor) - case sv: BSV[Double] => + // Having a local reference of `factor` to avoid overhead as the comment before. + val localFactor = factor + vector match { + case dv: DenseVector => + val values = dv.values.clone() + val size = values.size + var i = 0 + while(i < size) { + values(i) *= localFactor(i) + i += 1 + } + Vectors.dense(values) + case sv: SparseVector => // For sparse vector, the `index` array inside sparse vector object will not be changed, // so we can re-use it to save memory. - val output = new BSV[Double](sv.index, sv.data.clone(), sv.length) + val indices = sv.indices + val values = sv.values.clone() + val nnz = values.size var i = 0 - while (i < output.data.length) { - output.data(i) *= factor(output.index(i)) + while (i < nnz) { + values(i) *= localFactor(indices(i)) i += 1 } - Vectors.fromBreeze(output) + Vectors.sparse(sv.size, indices, values) case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } } else { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala index 415a845332d45..7358c1c84f79c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.feature import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD @@ -48,4 +49,14 @@ trait VectorTransformer extends Serializable { data.map(x => this.transform(x)) } + /** + * Applies transformation on an JavaRDD[Vector]. + * + * @param data JavaRDD[Vector] to be transformed. + * @return transformed JavaRDD[Vector]. + */ + def transform(data: JavaRDD[Vector]): JavaRDD[Vector] = { + transform(data.rdd) + } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index fc1444705364a..7960f3cab576f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -67,7 +67,7 @@ private case class VocabWord( class Word2Vec extends Serializable with Logging { private var vectorSize = 100 - private var startingAlpha = 0.025 + private var learningRate = 0.025 private var numPartitions = 1 private var numIterations = 1 private var seed = Utils.random.nextLong() @@ -84,7 +84,7 @@ class Word2Vec extends Serializable with Logging { * Sets initial learning rate (default: 0.025). 
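As a rough usage sketch of the rewritten Normalizer (not part of the patch): the constructor and transform signatures below are the ones shown above, and a zero-norm vector is returned unchanged, as documented.

import org.apache.spark.mllib.feature.Normalizer
import org.apache.spark.mllib.linalg.Vectors

object NormalizerSketch {
  def main(args: Array[String]): Unit = {
    val l2 = new Normalizer(2.0)                        // unit L2 norm
    val lInf = new Normalizer(Double.PositiveInfinity)  // divides by max(abs(value))
    val v = Vectors.dense(3.0, 4.0)
    println(l2.transform(v))    // [0.6, 0.8]
    println(lInf.transform(v))  // [0.75, 1.0]
    // A zero vector is returned unchanged instead of producing NaNs.
    println(l2.transform(Vectors.dense(0.0, 0.0)))
  }
}

The rewritten StandardScalerModel.transform above follows the same clone-and-scale pattern for dense and sparse input.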
*/ def setLearningRate(learningRate: Double): this.type = { - this.startingAlpha = learningRate + this.learningRate = learningRate this } @@ -286,7 +286,7 @@ class Word2Vec extends Serializable with Logging { val syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) val syn1Global = new Array[Float](vocabSize * vectorSize) - var alpha = startingAlpha + var alpha = learningRate for (k <- 1 to numIterations) { val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) @@ -300,8 +300,8 @@ class Word2Vec extends Serializable with Logging { lwc = wordCount // TODO: discount by iteration? alpha = - startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1)) - if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001 + learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1)) + if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001 logInfo("wordCount = " + wordCount + ", alpha = " + alpha) } wc += sentence.size @@ -432,18 +432,18 @@ class Word2VecModel private[mllib] ( throw new IllegalStateException(s"$word not in vocabulary") } } - + /** * Find synonyms of a word * @param word a word * @param num number of synonyms to find - * @return array of (word, similarity) + * @return array of (word, cosineSimilarity) */ def findSynonyms(word: String, num: Int): Array[(String, Double)] = { val vector = transform(word) findSynonyms(vector,num) } - + /** * Find synonyms of the vector representation of a word * @param vector vector representation of a word @@ -461,4 +461,11 @@ class Word2VecModel private[mllib] ( .tail .toArray } + + /** + * Returns a map of words to their vector representations. + */ + def getVectors: Map[String, Array[Float]] = { + model + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 54ee930d61003..89539e600f48c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -25,7 +25,7 @@ import org.apache.spark.Logging /** * BLAS routines for MLlib's vectors and matrices. */ -private[mllib] object BLAS extends Serializable with Logging { +private[spark] object BLAS extends Serializable with Logging { @transient private var _f2jBLAS: NetlibBLAS = _ @transient private var _nativeBLAS: NetlibBLAS = _ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 2cc52e94282ba..327366a1a3a82 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -17,12 +17,10 @@ package org.apache.spark.mllib.linalg -import java.util.Arrays +import java.util.{Random, Arrays} import breeze.linalg.{Matrix => BM, DenseMatrix => BDM, CSCMatrix => BSM} -import org.apache.spark.util.random.XORShiftRandom - /** * Trait for a local matrix. */ @@ -67,14 +65,14 @@ sealed trait Matrix extends Serializable { } /** Convenience method for `Matrix`^T^-`DenseMatrix` multiplication. 
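Returning to the Word2Vec additions above (the learningRate rename and the new getVectors accessor), a toy end-to-end sketch. It assumes the pre-existing Word2Vec.fit(RDD[Seq[String]]) entry point and the default minimum word count of 5, neither of which appears in these hunks; the repeated sentences exist only to clear that threshold.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.feature.Word2Vec

object Word2VecVectorsSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("w2v-sketch"))
    try {
      // Repeat each sentence so every word clears the assumed default minimum count of 5.
      val sentences = sc.parallelize(
        Seq.fill(10)(Seq("spark", "mllib", "word2vec", "feature", "vector")))
      val model = new Word2Vec().setLearningRate(0.025).fit(sentences)
      // getVectors exposes the raw word -> float-array mapping trained above.
      val vectors: Map[String, Array[Float]] = model.getVectors
      println(vectors.keys.toSeq.sorted.mkString(", "))
      model.findSynonyms("spark", 2).foreach { case (w, sim) => println(s"$w -> $sim") }
    } finally {
      sc.stop()
    }
  }
}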
*/ - def transposeMultiply(y: DenseMatrix): DenseMatrix = { + private[mllib] def transposeMultiply(y: DenseMatrix): DenseMatrix = { val C: DenseMatrix = Matrices.zeros(numCols, y.numCols).asInstanceOf[DenseMatrix] BLAS.gemm(true, false, 1.0, this, y, 0.0, C) C } /** Convenience method for `Matrix`^T^-`DenseVector` multiplication. */ - def transposeMultiply(y: DenseVector): DenseVector = { + private[mllib] def transposeMultiply(y: DenseVector): DenseVector = { val output = new DenseVector(new Array[Double](numCols)) BLAS.gemv(true, 1.0, this, y, 0.0, output) output @@ -291,22 +289,22 @@ object Matrices { * Generate a `DenseMatrix` consisting of i.i.d. uniform random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix + * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in U(0, 1) */ - def rand(numRows: Int, numCols: Int): Matrix = { - val rand = new XORShiftRandom - new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rand.nextDouble())) + def rand(numRows: Int, numCols: Int, rng: Random): Matrix = { + new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextDouble())) } /** * Generate a `DenseMatrix` consisting of i.i.d. gaussian random numbers. * @param numRows number of rows of the matrix * @param numCols number of columns of the matrix + * @param rng a random number generator * @return `DenseMatrix` with size `numRows` x `numCols` and values in N(0, 1) */ - def randn(numRows: Int, numCols: Int): Matrix = { - val rand = new XORShiftRandom - new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rand.nextGaussian())) + def randn(numRows: Int, numCols: Int, rng: Random): Matrix = { + new DenseMatrix(numRows, numCols, Array.fill(numRows * numCols)(rng.nextGaussian())) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 6af225b7f49f7..c6d5fe5bc678c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -17,22 +17,26 @@ package org.apache.spark.mllib.linalg -import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable} import java.util +import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable} import scala.annotation.varargs import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} -import org.apache.spark.mllib.util.NumericParser import org.apache.spark.SparkException +import org.apache.spark.mllib.util.NumericParser +import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row} +import org.apache.spark.sql.catalyst.types._ /** * Represents a numeric vector, whose index type is Int and value type is Double. * * Note: Users should not implement this interface. */ +@SQLUserDefinedType(udt = classOf[VectorUDT]) sealed trait Vector extends Serializable { /** @@ -72,6 +76,77 @@ sealed trait Vector extends Serializable { def copy: Vector = { throw new NotImplementedError(s"copy is not implemented for ${this.getClass}.") } + + /** + * Applies a function `f` to all the active elements of dense and sparse vector. 
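A small sketch of the reworked random-matrix factories above, which now take the RNG from the caller; Matrix.toArray is assumed to be the existing column-major accessor and is not part of this patch.

import java.util.Random
import org.apache.spark.mllib.linalg.Matrices

object RandomMatrixSketch {
  def main(args: Array[String]): Unit = {
    // The caller now supplies the RNG, so generated matrices are reproducible.
    val m1 = Matrices.rand(2, 3, new Random(42L))
    val m2 = Matrices.rand(2, 3, new Random(42L))
    println(m1.toArray.sameElements(m2.toArray))  // true: same seed, same values
    println(Matrices.randn(2, 2, new Random(7L)).toArray.mkString(", "))
  }
}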
+ * + * @param f the function takes two parameters where the first parameter is the index of + * the vector with type `Int`, and the second parameter is the corresponding value + * with type `Double`. + */ + private[spark] def foreachActive(f: (Int, Double) => Unit) +} + +/** + * User-defined type for [[Vector]] which allows easy interaction with SQL + * via [[org.apache.spark.sql.SchemaRDD]]. + */ +private[spark] class VectorUDT extends UserDefinedType[Vector] { + + override def sqlType: StructType = { + // type: 0 = sparse, 1 = dense + // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse + // vectors. The "values" field is nullable because we might want to add binary vectors later, + // which uses "size" and "indices", but not "values". + StructType(Seq( + StructField("type", ByteType, nullable = false), + StructField("size", IntegerType, nullable = true), + StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true), + StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true))) + } + + override def serialize(obj: Any): Row = { + val row = new GenericMutableRow(4) + obj match { + case sv: SparseVector => + row.setByte(0, 0) + row.setInt(1, sv.size) + row.update(2, sv.indices.toSeq) + row.update(3, sv.values.toSeq) + case dv: DenseVector => + row.setByte(0, 1) + row.setNullAt(1) + row.setNullAt(2) + row.update(3, dv.values.toSeq) + } + row + } + + override def deserialize(datum: Any): Vector = { + datum match { + // TODO: something wrong with UDT serialization + case v: Vector => + v + case row: Row => + require(row.length == 4, + s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4") + val tpe = row.getByte(0) + tpe match { + case 0 => + val size = row.getInt(1) + val indices = row.getAs[Iterable[Int]](2).toArray + val values = row.getAs[Iterable[Double]](3).toArray + new SparseVector(size, indices, values) + case 1 => + val values = row.getAs[Iterable[Double]](3).toArray + new DenseVector(values) + } + } + } + + override def pyUDT: String = "pyspark.mllib.linalg.VectorUDT" + + override def userClass: Class[Vector] = classOf[Vector] } /** @@ -171,7 +246,7 @@ object Vectors { private[mllib] def fromBreeze(breezeVector: BV[Double]): Vector = { breezeVector match { case v: BDV[Double] => - if (v.offset == 0 && v.stride == 1) { + if (v.offset == 0 && v.stride == 1 && v.length == v.data.length) { new DenseVector(v.data) } else { new DenseVector(v.toArray) // Can't use underlying array directly, so make a new one @@ -191,6 +266,7 @@ object Vectors { /** * A dense vector represented by a value array. */ +@SQLUserDefinedType(udt = classOf[VectorUDT]) class DenseVector(val values: Array[Double]) extends Vector { override def size: Int = values.length @@ -206,6 +282,17 @@ class DenseVector(val values: Array[Double]) extends Vector { override def copy: DenseVector = { new DenseVector(values.clone()) } + + private[spark] override def foreachActive(f: (Int, Double) => Unit) = { + var i = 0 + val localValuesSize = values.size + val localValues = values + + while (i < localValuesSize) { + f(i, localValues(i)) + i += 1 + } + } } /** @@ -215,6 +302,7 @@ class DenseVector(val values: Array[Double]) extends Vector { * @param indices index array, assume to be strictly increasing. * @param values value array, must have the same length as the index array. 
*/ +@SQLUserDefinedType(udt = classOf[VectorUDT]) class SparseVector( override val size: Int, val indices: Array[Int], @@ -241,4 +329,16 @@ class SparseVector( } private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) + + private[spark] override def foreachActive(f: (Int, Double) => Unit) = { + var i = 0 + val localValuesSize = values.size + val localIndices = indices + val localValues = values + + while (i < localValuesSize) { + f(localIndices(i), localValues(i)) + i += 1 + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 8380058cf9b41..10a515af88802 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -111,7 +111,10 @@ class RowMatrix( */ def computeGramianMatrix(): Matrix = { val n = numCols().toInt - val nt: Int = n * (n + 1) / 2 + checkNumColumns(n) + // Computes n*(n+1)/2, avoiding overflow in the multiplication. + // This succeeds when n <= 65535, which is checked above + val nt: Int = if (n % 2 == 0) ((n / 2) * (n + 1)) else (n * ((n + 1) / 2)) // Compute the upper triangular part of the gram matrix. val GU = rows.treeAggregate(new BDV[Double](new Array[Double](nt)))( @@ -123,6 +126,16 @@ class RowMatrix( RowMatrix.triuToFull(n, GU.data) } + private def checkNumColumns(cols: Int): Unit = { + if (cols > 65535) { + throw new IllegalArgumentException(s"Argument with more than 65535 cols: $cols") + } + if (cols > 10000) { + val mem = cols * cols * 8 + logWarning(s"$cols columns will require at least $mem bytes of memory!") + } + } + /** * Computes singular value decomposition of this matrix. Denote this matrix by A (m x n). This * will compute matrices U, S, V such that A ~= U * S * V', where S contains the leading k @@ -139,7 +152,7 @@ class RowMatrix( * storing the right singular vectors, is computed via matrix multiplication as * U = A * (V * S^-1^), if requested by user. The actual method to use is determined * automatically based on the cost: - * - If n is small (n < 100) or k is large compared with n (k > n / 2), we compute the Gramian + * - If n is small (n < 100) or k is large compared with n (k > n / 2), we compute the Gramian * matrix first and then compute its top eigenvalues and eigenvectors locally on the driver. * This requires a single pass with O(n^2^) storage on each executor and on the driver, and * O(n^2^ k) time on the driver. @@ -156,7 +169,8 @@ class RowMatrix( * @note The conditions that decide which method to use internally and the default parameters are * subject to change. * - * @param k number of leading singular values to keep (0 < k <= n). It might return less than k if + * @param k number of leading singular values to keep (0 < k <= n). + * It might return less than k if * there are numerically zero singular values or there are not enough Ritz values * converged before the maximum number of Arnoldi update iterations is reached (in case * that matrix A is ill-conditioned). @@ -179,7 +193,7 @@ class RowMatrix( /** * The actual SVD implementation, visible for testing. 
* - * @param k number of leading singular values to keep (0 < k <= n) + * @param k number of leading singular values to keep (0 < k <= n) * @param computeU whether to compute U * @param rCond the reciprocal condition number * @param maxIter max number of iterations (if ARPACK is used) @@ -301,12 +315,7 @@ class RowMatrix( */ def computeCovariance(): Matrix = { val n = numCols().toInt - - if (n > 10000) { - val mem = n * n * java.lang.Double.SIZE / java.lang.Byte.SIZE - logWarning(s"The number of columns $n is greater than 10000! " + - s"We need at least $mem bytes of memory.") - } + checkNumColumns(n) val (m, mean) = rows.treeAggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))( seqOp = (s: (Long, BDV[Double]), v: Vector) => (s._1 + 1L, s._2 += v.toBreeze), diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index a6912056395d7..0857877951c82 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -160,14 +160,15 @@ object GradientDescent extends Logging { val stochasticLossHistory = new ArrayBuffer[Double](numIterations) val numExamples = data.count() - val miniBatchSize = numExamples * miniBatchFraction // if no data, return initial weights to avoid NaNs if (numExamples == 0) { - - logInfo("GradientDescent.runMiniBatchSGD returning initial weights, no data found") + logWarning("GradientDescent.runMiniBatchSGD returning initial weights, no data found") return (initialWeights, stochasticLossHistory.toArray) + } + if (numExamples * miniBatchFraction < 1) { + logWarning("The miniBatchFraction is too small") } // Initialize weights as a column vector @@ -185,25 +186,31 @@ object GradientDescent extends Logging { val bcWeights = data.context.broadcast(weights) // Sample a subset (fraction miniBatchFraction) of the total data // compute and sum up the subgradients on this subset (this is one map-reduce) - val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i) - .treeAggregate((BDV.zeros[Double](n), 0.0))( - seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => - val l = gradient.compute(features, label, bcWeights.value, Vectors.fromBreeze(grad)) - (grad, loss + l) + val (gradientSum, lossSum, miniBatchSize) = data.sample(false, miniBatchFraction, 42 + i) + .treeAggregate((BDV.zeros[Double](n), 0.0, 0L))( + seqOp = (c, v) => { + // c: (grad, loss, count), v: (label, features) + val l = gradient.compute(v._2, v._1, bcWeights.value, Vectors.fromBreeze(c._1)) + (c._1, c._2 + l, c._3 + 1) }, - combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) => - (grad1 += grad2, loss1 + loss2) + combOp = (c1, c2) => { + // c: (grad, loss, count) + (c1._1 += c2._1, c1._2 + c2._2, c1._3 + c2._3) }) - /** - * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration - * and regVal is the regularization value computed in the previous iteration as well. - */ - stochasticLossHistory.append(lossSum / miniBatchSize + regVal) - val update = updater.compute( - weights, Vectors.fromBreeze(gradientSum / miniBatchSize), stepSize, i, regParam) - weights = update._1 - regVal = update._2 + if (miniBatchSize > 0) { + /** + * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration + * and regVal is the regularization value computed in the previous iteration as well. 
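The packed upper-triangular size used by computeGramianMatrix above can be sanity-checked in isolation; a quick standalone sketch, plain Scala with no Spark dependencies:

// Standalone check of the overflow-safe n * (n + 1) / 2 used above for the
// packed upper-triangular Gramian; valid because the column count is capped at 65535.
object TriangularSizeSketch {
  def triSize(n: Int): Int =
    if (n % 2 == 0) (n / 2) * (n + 1) else n * ((n + 1) / 2)

  def main(args: Array[String]): Unit = {
    val n = 65535
    println(triSize(n))       // 2147450880, still inside the Int range
    println(n * (n + 1) / 2)  // the naive product overflows Int and wraps negative
  }
}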
+ */ + stochasticLossHistory.append(lossSum / miniBatchSize + regVal) + val update = updater.compute( + weights, Vectors.fromBreeze(gradientSum / miniBatchSize.toDouble), stepSize, i, regParam) + weights = update._1 + regVal = update._2 + } else { + logWarning(s"Iteration ($i/$numIterations). The size of sampled batch is zero") + } } logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format( diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala index e4b436b023794..fef062e02b6ec 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala @@ -79,7 +79,7 @@ private[mllib] object NNLS { // stopping condition def stop(step: Double, ndir: Double, nx: Double): Boolean = { ((step.isNaN) // NaN - || (step < 1e-6) // too small or negative + || (step < 1e-7) // too small or negative || (step > 1e40) // too small; almost certainly numerical problems || (ndir < 1e-12 * nx) // gradient relatively too small || (ndir < 1e-32) // gradient absolutely too small; numerical issues may lurk diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala index 28179fbc450c0..51f9b8657c640 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -17,8 +17,7 @@ package org.apache.spark.mllib.random -import cern.jet.random.Poisson -import cern.jet.random.engine.DRand +import org.apache.commons.math3.distribution.PoissonDistribution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} @@ -89,12 +88,13 @@ class StandardNormalGenerator extends RandomDataGenerator[Double] { @DeveloperApi class PoissonGenerator(val mean: Double) extends RandomDataGenerator[Double] { - private var rng = new Poisson(mean, new DRand) + private var rng = new PoissonDistribution(mean) - override def nextValue(): Double = rng.nextDouble() + override def nextValue(): Double = rng.sample() override def setSeed(seed: Long) { - rng = new Poisson(mean, new DRand(seed.toInt)) + rng = new PoissonDistribution(mean) + rng.reseedRandomGenerator(seed) } override def copy(): PoissonGenerator = new PoissonGenerator(mean) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index b5e403bc8c14d..57c0768084e41 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.rdd import scala.language.implicitConversions import scala.reflect.ClassTag +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.HashPartitioner import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD @@ -28,8 +29,8 @@ import org.apache.spark.util.Utils /** * Machine learning specific RDD functions. 
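The GradientDescent change above counts the sampled records inside the aggregation instead of assuming numExamples * miniBatchFraction. A simplified sketch of that pattern, using plain RDD.aggregate over a numeric RDD rather than MLlib's treeAggregate and gradient types:

import org.apache.spark.{SparkConf, SparkContext}

object MiniBatchCountSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("minibatch-sketch"))
    try {
      val data = sc.parallelize(1 to 1000, 4)
      val miniBatchFraction = 0.01
      // Bernoulli sampling returns a random number of records, so the real count is
      // accumulated alongside the sum instead of assuming numExamples * miniBatchFraction.
      val (sum, count) = data.sample(false, miniBatchFraction, 42L)  // withReplacement = false
        .aggregate((0.0, 0L))(
          (c, v) => (c._1 + v, c._2 + 1),
          (c1, c2) => (c1._1 + c2._1, c1._2 + c2._2))
      if (count > 0) println(s"sampled $count records, mean ${sum / count}")
      else println("sampled batch was empty; the update would be skipped")
    } finally {
      sc.stop()
    }
  }
}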
*/ -private[mllib] -class RDDFunctions[T: ClassTag](self: RDD[T]) { +@DeveloperApi +class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { /** * Returns a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding @@ -39,10 +40,10 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) { * trigger a Spark job if the parent RDD has more than one partitions and the window size is * greater than 1. */ - def sliding(windowSize: Int): RDD[Seq[T]] = { + def sliding(windowSize: Int): RDD[Array[T]] = { require(windowSize > 0, s"Sliding window size must be positive, but got $windowSize.") if (windowSize == 1) { - self.map(Seq(_)) + self.map(Array(_)) } else { new SlidingRDD[T](self, windowSize) } @@ -112,7 +113,7 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) { } } -private[mllib] +@DeveloperApi object RDDFunctions { /** Implicit conversion from an RDD to RDDFunctions. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala index dd80782c0f001..35e81fcb3de0d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala @@ -45,15 +45,16 @@ class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T] */ private[mllib] class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int) - extends RDD[Seq[T]](parent) { + extends RDD[Array[T]](parent) { require(windowSize > 1, s"Window size must be greater than 1, but got $windowSize.") - override def compute(split: Partition, context: TaskContext): Iterator[Seq[T]] = { + override def compute(split: Partition, context: TaskContext): Iterator[Array[T]] = { val part = split.asInstanceOf[SlidingRDDPartition[T]] (firstParent[T].iterator(part.prev, context) ++ part.tail) .sliding(windowSize) .withPartial(false) + .map(_.toArray) } override def getPreferredLocations(split: Partition): Seq[String] = diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 84d192db53e26..90ac252226006 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -20,20 +20,20 @@ package org.apache.spark.mllib.recommendation import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.math.{abs, sqrt} -import scala.util.Random -import scala.util.Sorting +import scala.util.{Random, Sorting} import scala.util.hashing.byteswap32 import org.jblas.{DoubleMatrix, SimpleBlas, Solve} +import org.apache.spark.{HashPartitioner, Logging, Partitioner} +import org.apache.spark.SparkContext._ import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.api.java.JavaRDD import org.apache.spark.broadcast.Broadcast -import org.apache.spark.{Logging, HashPartitioner, Partitioner} -import org.apache.spark.storage.StorageLevel +import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -import org.apache.spark.mllib.optimization.NNLS /** * Out-link information for a user or product block. 
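A usage sketch for the now-public sliding, whose windows are returned as Array[T] after the change above; the implicit conversion comes from the RDDFunctions companion object, and partial trailing windows are dropped.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.rdd.RDDFunctions._

object SlidingSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sliding-sketch"))
    try {
      // sliding(3) now yields Array[Int] windows instead of Seq[Int].
      val windows = sc.parallelize(1 to 6, 2).sliding(3).collect()
      windows.foreach(w => println(w.mkString("[", ",", "]")))
      // expected windows: [1,2,3] [2,3,4] [3,4,5] [4,5,6]
    } finally {
      sc.stop()
    }
  }
}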
This includes the original user/product IDs @@ -325,6 +325,11 @@ class ALS private ( new MatrixFactorizationModel(rank, usersOut, productsOut) } + /** + * Java-friendly version of [[ALS.run]]. + */ + def run(ratings: JavaRDD[Rating]): MatrixFactorizationModel = run(ratings.rdd) + /** * Computes the (`rank x rank`) matrix `YtY`, where `Y` is the (`nui x rank`) matrix of factors * for each user (or product), in a distributed fashion. @@ -741,7 +746,7 @@ object ALS { * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into - * @param alpha confidence parameter (only applies when immplicitPrefs = true) + * @param alpha confidence parameter * @param seed random seed */ def trainImplicit( @@ -768,7 +773,7 @@ object ALS { * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) * @param blocks level of parallelism to split computation into - * @param alpha confidence parameter (only applies when immplicitPrefs = true) + * @param alpha confidence parameter */ def trainImplicit( ratings: RDD[Rating], @@ -792,6 +797,7 @@ object ALS { * @param rank number of features to use * @param iterations number of iterations of ALS (recommended: 10-20) * @param lambda regularization factor (recommended: 0.01) + * @param alpha confidence parameter */ def trainImplicit(ratings: RDD[Rating], rank: Int, iterations: Int, lambda: Double, alpha: Double) : MatrixFactorizationModel = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 66b58ba770160..ed2f8b41bcae5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -17,27 +17,49 @@ package org.apache.spark.mllib.recommendation +import java.lang.{Integer => JavaInteger} + import org.jblas.DoubleMatrix -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.api.java.JavaRDD +import org.apache.spark.Logging +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD} import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.api.python.SerDe +import org.apache.spark.storage.StorageLevel /** * Model representing the result of matrix factorization. * + * Note: If you create the model directly using constructor, please be aware that fast prediction + * requires cached user/product features and their associated partitioners. + * * @param rank Rank for the features in this model. * @param userFeatures RDD of tuples where each tuple represents the userId and * the features computed for this user. * @param productFeatures RDD of tuples where each tuple represents the productId * and the features computed for this product. */ -class MatrixFactorizationModel private[mllib] ( +class MatrixFactorizationModel( val rank: Int, val userFeatures: RDD[(Int, Array[Double])], - val productFeatures: RDD[(Int, Array[Double])]) extends Serializable { + val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging { + + require(rank > 0) + validateFeatures("User", userFeatures) + validateFeatures("Product", productFeatures) + + /** Validates factors and warns users if there are performance concerns. 
*/ + private def validateFeatures(name: String, features: RDD[(Int, Array[Double])]): Unit = { + require(features.first()._2.size == rank, + s"$name feature dimension does not match the rank $rank.") + if (features.partitioner.isEmpty) { + logWarning(s"$name factor does not have a partitioner. " + + "Prediction on individual records could be slow.") + } + if (features.getStorageLevel == StorageLevel.NONE) { + logWarning(s"$name factor is not cached. Prediction could be slow.") + } + } + /** Predict the rating of one user for one product. */ def predict(user: Int, product: Int): Double = { val userVector = new DoubleMatrix(userFeatures.lookup(user).head) @@ -65,6 +87,13 @@ class MatrixFactorizationModel private[mllib] ( } } + /** + * Java-friendly version of [[MatrixFactorizationModel.predict]]. + */ + def predict(usersProducts: JavaPairRDD[JavaInteger, JavaInteger]): JavaRDD[Rating] = { + predict(usersProducts.rdd.asInstanceOf[RDD[(Int, Int)]]).toJavaRDD() + } + /** * Recommends products to a user. * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index d0fe4179685ca..0287f04e2c777 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -75,6 +75,8 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double def predict(testData: Vector): Double = { predictPoint(testData, weights, intercept) } + + override def toString() = "(weights=%s, intercept=%s)".format(weights, intercept) } /** @@ -134,15 +136,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] this } - /** Whether a warning should be logged if the input RDD is uncached. */ - private var warnOnUncachedInput = true - - /** Disable warnings about uncached input. */ - private[spark] def disableUncachedWarning(): this.type = { - warnOnUncachedInput = false - this - } - /** * Run the algorithm with the configured parameters on an input * RDD of LabeledPoint entries. @@ -159,7 +152,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ def run(input: RDD[LabeledPoint], initialWeights: Vector): M = { - if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) { + if (input.getStorageLevel == StorageLevel.NONE) { logWarning("The input data is not directly cached, which may hurt performance if its" + " parent RDDs are also uncached.") } @@ -239,7 +232,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] } // Warn at the end of the run as well, for increased visibility. 
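With the constructor made public above, a MatrixFactorizationModel can be assembled directly from factor RDDs; the sketch below partitions and caches the factors so validateFeatures does not log its performance warnings. The factor values are made up for illustration.

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel
import org.apache.spark.storage.StorageLevel

object MFModelSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("mf-sketch"))
    try {
      val rank = 2
      // Partition and cache the factors so the validation above stays quiet.
      val users = sc.parallelize(Seq(1 -> Array(0.1, 0.2), 2 -> Array(0.3, 0.4)))
        .partitionBy(new HashPartitioner(2)).persist(StorageLevel.MEMORY_ONLY)
      val products = sc.parallelize(Seq(10 -> Array(1.0, 0.5), 11 -> Array(0.2, 0.8)))
        .partitionBy(new HashPartitioner(2)).persist(StorageLevel.MEMORY_ONLY)
      val model = new MatrixFactorizationModel(rank, users, products)
      println(model.predict(1, 10))  // 0.1 * 1.0 + 0.2 * 0.5 = 0.2
    } finally {
      sc.stop()
    }
  }
}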
- if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) { + if (input.getStorageLevel == StorageLevel.NONE) { logWarning("The input data was not directly cached, which may hurt performance if its" + " parent RDDs are also uncached.") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index 17c753c56681f..2067b36f246b3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -17,6 +17,8 @@ package org.apache.spark.mllib.regression +import scala.beans.BeanInfo + import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.NumericParser import org.apache.spark.SparkException @@ -27,6 +29,7 @@ import org.apache.spark.SparkException * @param label Label for this data point. * @param features List of features for this data point. */ +@BeanInfo case class LabeledPoint(label: Double, features: Vector) { override def toString: String = { "(%s,%s)".format(label, features) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index cb0d39e759a9f..f9791c6571782 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -67,9 +67,9 @@ class LassoWithSGD private ( /** * Construct a Lasso object with default parameters: {stepSize: 1.0, numIterations: 100, - * regParam: 1.0, miniBatchFraction: 1.0}. + * regParam: 0.01, miniBatchFraction: 1.0}. */ - def this() = this(1.0, 100, 1.0, 1.0) + def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { new LassoModel(weights, intercept) @@ -161,6 +161,6 @@ object LassoWithSGD { def train( input: RDD[LabeledPoint], numIterations: Int): LassoModel = { - train(input, numIterations, 1.0, 1.0, 1.0) + train(input, numIterations, 1.0, 0.01, 1.0) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index a826deb695ee1..c8cad773f5efb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -68,9 +68,9 @@ class RidgeRegressionWithSGD private ( /** * Construct a RidgeRegression object with default parameters: {stepSize: 1.0, numIterations: 100, - * regParam: 1.0, miniBatchFraction: 1.0}. + * regParam: 0.01, miniBatchFraction: 1.0}. 
*/ - def this() = this(1.0, 100, 1.0, 1.0) + def this() = this(1.0, 100, 0.01, 1.0) override protected def createModel(weights: Vector, intercept: Double) = { new RidgeRegressionModel(weights, intercept) @@ -143,7 +143,7 @@ object RidgeRegressionWithSGD { numIterations: Int, stepSize: Double, regParam: Double): RidgeRegressionModel = { - train(input, numIterations, stepSize, regParam, 1.0) + train(input, numIterations, stepSize, regParam, 0.01) } /** @@ -158,6 +158,6 @@ object RidgeRegressionWithSGD { def train( input: RDD[LabeledPoint], numIterations: Int): RidgeRegressionModel = { - train(input, numIterations, 1.0, 1.0, 1.0) + train(input, numIterations, 1.0, 0.01, 1.0) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 3025d4837cab4..fcc2a148791bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -17,8 +17,6 @@ package org.apache.spark.mllib.stat -import breeze.linalg.{DenseVector => BDV} - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.{Vectors, Vector} @@ -40,14 +38,14 @@ import org.apache.spark.mllib.linalg.{Vectors, Vector} class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with Serializable { private var n = 0 - private var currMean: BDV[Double] = _ - private var currM2n: BDV[Double] = _ - private var currM2: BDV[Double] = _ - private var currL1: BDV[Double] = _ + private var currMean: Array[Double] = _ + private var currM2n: Array[Double] = _ + private var currM2: Array[Double] = _ + private var currL1: Array[Double] = _ private var totalCnt: Long = 0 - private var nnz: BDV[Double] = _ - private var currMax: BDV[Double] = _ - private var currMin: BDV[Double] = _ + private var nnz: Array[Double] = _ + private var currMax: Array[Double] = _ + private var currMin: Array[Double] = _ /** * Add a new sample to this summarizer, and update the statistical summary. @@ -60,35 +58,36 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S require(sample.size > 0, s"Vector should have dimension larger than zero.") n = sample.size - currMean = BDV.zeros[Double](n) - currM2n = BDV.zeros[Double](n) - currM2 = BDV.zeros[Double](n) - currL1 = BDV.zeros[Double](n) - nnz = BDV.zeros[Double](n) - currMax = BDV.fill(n)(Double.MinValue) - currMin = BDV.fill(n)(Double.MaxValue) + currMean = Array.ofDim[Double](n) + currM2n = Array.ofDim[Double](n) + currM2 = Array.ofDim[Double](n) + currL1 = Array.ofDim[Double](n) + nnz = Array.ofDim[Double](n) + currMax = Array.fill[Double](n)(Double.MinValue) + currMin = Array.fill[Double](n)(Double.MaxValue) } require(n == sample.size, s"Dimensions mismatch when adding new sample." + s" Expecting $n but got ${sample.size}.") - sample.toBreeze.activeIterator.foreach { - case (_, 0.0) => // Skip explicit zero elements. 
- case (i, value) => - if (currMax(i) < value) { - currMax(i) = value + sample.foreachActive { (index, value) => + if (value != 0.0) { + if (currMax(index) < value) { + currMax(index) = value } - if (currMin(i) > value) { - currMin(i) = value + if (currMin(index) > value) { + currMin(index) = value } - val tmpPrevMean = currMean(i) - currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0) - currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean) - currM2(i) += value * value - currL1(i) += math.abs(value) + val prevMean = currMean(index) + val diff = value - prevMean + currMean(index) = prevMean + diff / (nnz(index) + 1.0) + currM2n(index) += (value - currMean(index)) * diff + currM2(index) += value * value + currL1(index) += math.abs(value) - nnz(i) += 1.0 + nnz(index) += 1.0 + } } totalCnt += 1 @@ -107,47 +106,38 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " + s"Expecting $n but got ${other.n}.") totalCnt += other.totalCnt - val deltaMean: BDV[Double] = currMean - other.currMean var i = 0 while (i < n) { - // merge mean together - if (other.currMean(i) != 0.0) { - currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) / - (nnz(i) + other.nnz(i)) - } - // merge m2n together - if (nnz(i) + other.nnz(i) != 0.0) { - currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) / - (nnz(i) + other.nnz(i)) - } - // merge m2 together - if (nnz(i) + other.nnz(i) != 0.0) { + val thisNnz = nnz(i) + val otherNnz = other.nnz(i) + val totalNnz = thisNnz + otherNnz + if (totalNnz != 0.0) { + val deltaMean = other.currMean(i) - currMean(i) + // merge mean together + currMean(i) += deltaMean * otherNnz / totalNnz + // merge m2n together + currM2n(i) += other.currM2n(i) + deltaMean * deltaMean * thisNnz * otherNnz / totalNnz + // merge m2 together currM2(i) += other.currM2(i) - } - // merge l1 together - if (nnz(i) + other.nnz(i) != 0.0) { + // merge l1 together currL1(i) += other.currL1(i) + // merge max and min + currMax(i) = math.max(currMax(i), other.currMax(i)) + currMin(i) = math.min(currMin(i), other.currMin(i)) } - - if (currMax(i) < other.currMax(i)) { - currMax(i) = other.currMax(i) - } - if (currMin(i) > other.currMin(i)) { - currMin(i) = other.currMin(i) - } + nnz(i) = totalNnz i += 1 } - nnz += other.nnz } else if (totalCnt == 0 && other.totalCnt != 0) { this.n = other.n - this.currMean = other.currMean.copy - this.currM2n = other.currM2n.copy - this.currM2 = other.currM2.copy - this.currL1 = other.currL1.copy + this.currMean = other.currMean.clone + this.currM2n = other.currM2n.clone + this.currM2 = other.currM2.clone + this.currL1 = other.currL1.clone this.totalCnt = other.totalCnt - this.nnz = other.nnz.copy - this.currMax = other.currMax.copy - this.currMin = other.currMin.copy + this.nnz = other.nnz.clone + this.currMax = other.currMax.clone + this.currMin = other.currMin.clone } this } @@ -155,19 +145,19 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S override def mean: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - val realMean = BDV.zeros[Double](n) + val realMean = Array.ofDim[Double](n) var i = 0 while (i < n) { realMean(i) = currMean(i) * (nnz(i) / totalCnt) i += 1 } - Vectors.fromBreeze(realMean) + Vectors.dense(realMean) } override def variance: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - val 
realVariance = BDV.zeros[Double](n) + val realVariance = Array.ofDim[Double](n) val denominator = totalCnt - 1.0 @@ -182,8 +172,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S i += 1 } } - - Vectors.fromBreeze(realVariance) + Vectors.dense(realVariance) } override def count: Long = totalCnt @@ -191,7 +180,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S override def numNonzeros: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - Vectors.fromBreeze(nnz) + Vectors.dense(nnz) } override def max: Vector = { @@ -202,7 +191,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 i += 1 } - Vectors.fromBreeze(currMax) + Vectors.dense(currMax) } override def min: Vector = { @@ -213,25 +202,25 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0 i += 1 } - Vectors.fromBreeze(currMin) + Vectors.dense(currMin) } override def normL2: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - val realMagnitude = BDV.zeros[Double](n) + val realMagnitude = Array.ofDim[Double](n) var i = 0 while (i < currM2.size) { realMagnitude(i) = math.sqrt(currM2(i)) i += 1 } - - Vectors.fromBreeze(realMagnitude) + Vectors.dense(realMagnitude) } override def normL1: Vector = { require(totalCnt > 0, s"Nothing has been added to this summarizer.") - Vectors.fromBreeze(currL1) + + Vectors.dense(currL1) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala index 0089419c2c5d4..ea82d39b72c03 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.stat.test import breeze.linalg.{DenseMatrix => BDM} -import cern.jet.stat.Probability.chiSquareComplemented +import org.apache.commons.math3.distribution.ChiSquaredDistribution import org.apache.spark.{SparkException, Logging} import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} @@ -33,7 +33,7 @@ import scala.collection.mutable * on an input of type `Matrix` in which independence between columns is assessed. * We also provide a method for computing the chi-squared statistic between each feature and the * label for an input `RDD[LabeledPoint]`, return an `Array[ChiSquaredTestResult]` of size = - * number of features in the inpuy RDD. + * number of features in the input RDD. * * Supported methods for goodness of fit: `pearson` (default) * Supported methods for independence: `pearson` (default) @@ -139,7 +139,7 @@ private[stat] object ChiSqTest extends Logging { } /* - * Pearon's goodness of fit test on the input observed and expected counts/relative frequencies. + * Pearson's goodness of fit test on the input observed and expected counts/relative frequencies. * Uniform distribution is assumed when `expected` is not passed in. 
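A small sketch of MultivariateOnlineSummarizer after the Breeze-to-array rewrite above, exercising add on dense and sparse input and the merge path (merge returns this, as shown in the hunk):

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer

object SummarizerSketch {
  def main(args: Array[String]): Unit = {
    // Two partial summaries, merged the way a distributed aggregation would combine them.
    val s1 = new MultivariateOnlineSummarizer
    s1.add(Vectors.dense(1.0, 0.0))
    s1.add(Vectors.dense(3.0, 2.0))
    val s2 = new MultivariateOnlineSummarizer
    s2.add(Vectors.sparse(2, Array(1), Array(4.0)))
    val merged = s1.merge(s2)
    println(merged.count)        // 3
    println(merged.mean)         // [1.333..., 2.0]: means are taken over all rows
    println(merged.numNonzeros)  // [2.0, 2.0]: explicit zeros are not counted
  }
}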
*/ def chiSquared(observed: Vector, @@ -188,12 +188,12 @@ private[stat] object ChiSqTest extends Logging { } } val df = size - 1 - val pValue = chiSquareComplemented(df, statistic) + val pValue = 1.0 - new ChiSquaredDistribution(df).cumulativeProbability(statistic) new ChiSqTestResult(pValue, df, statistic, PEARSON.name, NullHypothesis.goodnessOfFit.toString) } /* - * Pearon's independence test on the input contingency matrix. + * Pearson's independence test on the input contingency matrix. * TODO: optimize for SparseMatrix when it becomes supported. */ def chiSquaredMatrix(counts: Matrix, methodName:String = PEARSON.name): ChiSqTestResult = { @@ -238,7 +238,13 @@ private[stat] object ChiSqTest extends Logging { j += 1 } val df = (numCols - 1) * (numRows - 1) - val pValue = chiSquareComplemented(df, statistic) - new ChiSqTestResult(pValue, df, statistic, methodName, NullHypothesis.independence.toString) + if (df == 0) { + // 1 column or 1 row. Constant distribution is independent of anything. + // pValue = 1.0 and statistic = 0.0 in this case. + new ChiSqTestResult(1.0, 0, 0.0, methodName, NullHypothesis.independence.toString) + } else { + val pValue = 1.0 - new ChiSquaredDistribution(df).cumulativeProbability(statistic) + new ChiSqTestResult(pValue, df, statistic, methodName, NullHypothesis.independence.toString) + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index b311d10023894..3d91867c896d9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD @@ -56,13 +58,19 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] * @return DecisionTreeModel that can be used for prediction */ - def train(input: RDD[LabeledPoint]): DecisionTreeModel = { + def run(input: RDD[LabeledPoint]): DecisionTreeModel = { // Note: random seed will not be used since numTrees = 1. val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0) - val rfModel = rf.train(input) + val rfModel = rf.run(input) rfModel.trees(0) } + /** + * Trains a decision tree model over an RDD. This is deprecated because it hides the static + * methods with the same name in Java. 
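The p-value above is now the upper tail of commons-math3's chi-squared CDF rather than colt's chiSquareComplemented; a one-line check of that computation:

import org.apache.commons.math3.distribution.ChiSquaredDistribution

object ChiSqPValueSketch {
  def main(args: Array[String]): Unit = {
    // Same computation as above: the p-value is the upper tail of the chi-squared CDF.
    val statistic = 3.84
    val df = 1.0
    val pValue = 1.0 - new ChiSquaredDistribution(df).cumulativeProbability(statistic)
    println(pValue)  // roughly 0.05 for a statistic of 3.84 with one degree of freedom
  }
}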
+ */ + @deprecated("Please use DecisionTree.run instead.", "1.2.0") + def train(input: RDD[LabeledPoint]): DecisionTreeModel = run(input) } object DecisionTree extends Serializable with Logging { @@ -84,7 +92,7 @@ object DecisionTree extends Serializable with Logging { * @return DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { - new DecisionTree(strategy).train(input) + new DecisionTree(strategy).run(input) } /** @@ -110,7 +118,7 @@ object DecisionTree extends Serializable with Logging { impurity: Impurity, maxDepth: Int): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth) - new DecisionTree(strategy).train(input) + new DecisionTree(strategy).run(input) } /** @@ -138,7 +146,7 @@ object DecisionTree extends Serializable with Logging { maxDepth: Int, numClassesForClassification: Int): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification) - new DecisionTree(strategy).train(input) + new DecisionTree(strategy).run(input) } /** @@ -175,7 +183,7 @@ object DecisionTree extends Serializable with Logging { categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins, quantileCalculationStrategy, categoricalFeaturesInfo) - new DecisionTree(strategy).train(input) + new DecisionTree(strategy).run(input) } /** @@ -435,6 +443,11 @@ object DecisionTree extends Serializable with Logging { * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param nodeQueue Queue of nodes to split, with values (treeIndex, node). * Updated with new non-leaf nodes which are created. + * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where + * each value in the array is the data point's node Id + * for a corresponding tree. This is used to prevent the need + * to pass the entire tree to the executors during + * the node stat aggregation phase. */ private[tree] def findBestSplits( input: RDD[BaggedPoint[TreePoint]], @@ -445,7 +458,8 @@ object DecisionTree extends Serializable with Logging { splits: Array[Array[Split]], bins: Array[Array[Bin]], nodeQueue: mutable.Queue[(Int, Node)], - timer: TimeTracker = new TimeTracker): Unit = { + timer: TimeTracker = new TimeTracker, + nodeIdCache: Option[NodeIdCache] = None): Unit = { /* * The high-level descriptions of the best split optimizations are noted here. @@ -477,6 +491,37 @@ object DecisionTree extends Serializable with Logging { logDebug("isMulticlass = " + metadata.isMulticlass) logDebug("isMulticlassWithCategoricalFeatures = " + metadata.isMulticlassWithCategoricalFeatures) + logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString) + + /** + * Performs a sequential aggregation over a partition for a particular tree and node. + * + * For each feature, the aggregate sufficient statistics are updated for the relevant + * bins. + * + * @param treeIndex Index of the tree that we want to perform aggregation for. + * @param nodeInfo The node info for the tree node. + * @param agg Array storing aggregate calculation, with a set of sufficient statistics + * for each (node, feature, bin). + * @param baggedPoint Data point being aggregated. 
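A toy run of the renamed entry point: the static DecisionTree.train helpers above now delegate to run, and the instance-level train is deprecated. The dataset below is made up and only meant to show the call shape.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.impurity.Gini

object DecisionTreeRunSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("dt-sketch"))
    try {
      val data = sc.parallelize(Seq(
        LabeledPoint(0.0, Vectors.dense(0.0)),
        LabeledPoint(0.0, Vectors.dense(1.0)),
        LabeledPoint(1.0, Vectors.dense(9.0)),
        LabeledPoint(1.0, Vectors.dense(10.0))))
      // The static helpers now delegate to run(); the instance-level train() is deprecated.
      val model = DecisionTree.train(data, Algo.Classification, Gini, 2)  // maxDepth = 2
      println(model.predict(Vectors.dense(9.5)))  // expected 1.0
    } finally {
      sc.stop()
    }
  }
}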
+ */ + def nodeBinSeqOp( + treeIndex: Int, + nodeInfo: RandomForest.NodeIndexInfo, + agg: Array[DTStatsAggregator], + baggedPoint: BaggedPoint[TreePoint]): Unit = { + if (nodeInfo != null) { + val aggNodeIndex = nodeInfo.nodeIndexInGroup + val featuresForNode = nodeInfo.featureSubset + val instanceWeight = baggedPoint.subsampleWeights(treeIndex) + if (metadata.unorderedFeatures.isEmpty) { + orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode) + } else { + mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures, + instanceWeight, featuresForNode) + } + } + } /** * Performs a sequential aggregation over a partition. @@ -495,20 +540,25 @@ object DecisionTree extends Serializable with Logging { treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, bins, metadata.unorderedFeatures) - val nodeInfo = nodeIndexToInfo.getOrElse(nodeIndex, null) - // If the example does not reach a node in this group, then nodeIndex = null. - if (nodeInfo != null) { - val aggNodeIndex = nodeInfo.nodeIndexInGroup - val featuresForNode = nodeInfo.featureSubset - val instanceWeight = baggedPoint.subsampleWeights(treeIndex) - if (metadata.unorderedFeatures.isEmpty) { - orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode) - } else { - mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures, - instanceWeight, featuresForNode) - } - } + nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) + } + + agg + } + + /** + * Do the same thing as binSeqOp, but with nodeIdCache. + */ + def binSeqOpWithNodeIdCache( + agg: Array[DTStatsAggregator], + dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = { + treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => + val baggedPoint = dataPoint._1 + val nodeIdCache = dataPoint._2 + val nodeIndex = nodeIdCache(treeIndex) + nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) } + agg } @@ -532,6 +582,14 @@ object DecisionTree extends Serializable with Logging { Some(mutableNodeToFeatures.toMap) } + // array of nodes to train indexed by node index in group + val nodes = new Array[Node](numNodes) + nodesForGroup.foreach { case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node + } + } + // Calculate best splits for all nodes in the group timer.start("chooseSplits") @@ -543,7 +601,26 @@ object DecisionTree extends Serializable with Logging { // Finally, only best Splits for nodes are collected to driver to construct decision tree. 
val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo) val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures) - val nodeToBestSplits = + + val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) { + input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points => + // Construct a nodeStatsAggregators array to hold node aggregate stats, + // each node will have a nodeStatsAggregator + val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex => + val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => + Some(nodeToFeatures(nodeIndex)) + } + new DTStatsAggregator(metadata, featuresForNode) + } + + // iterator all instances in current partition and update aggregate stats + points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _)) + + // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs, + // which can be combined with other partition using `reduceByKey` + nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator + } + } else { input.mapPartitions { points => // Construct a nodeStatsAggregators array to hold node aggregate stats, // each node will have a nodeStatsAggregator @@ -560,7 +637,10 @@ object DecisionTree extends Serializable with Logging { // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs, // which can be combined with other partition using `reduceByKey` nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator - }.reduceByKey((a, b) => a.merge(b)) + } + } + + val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b)) .map { case (nodeIndex, aggStats) => val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures => Some(nodeToFeatures(nodeIndex)) @@ -568,12 +648,19 @@ object DecisionTree extends Serializable with Logging { // find best split for each node val (split: Split, stats: InformationGainStats, predict: Predict) = - binsToBestSplit(aggStats, splits, featuresForNode) + binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) (nodeIndex, (split, stats, predict)) }.collectAsMap() timer.stop("chooseSplits") + val nodeIdUpdaters = if (nodeIdCache.nonEmpty) { + Array.fill[mutable.Map[Int, NodeIndexUpdater]]( + metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]()) + } else { + null + } + // Iterate over all nodes in this group. nodesForGroup.foreach { case (treeIndex, nodesForTree) => nodesForTree.foreach { node => @@ -587,17 +674,37 @@ object DecisionTree extends Serializable with Logging { // Extract info for this node. Create children if not leaf. 
val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth) assert(node.id == nodeIndex) - node.predict = predict.predict + node.predict = predict node.isLeaf = isLeaf node.stats = Some(stats) + node.impurity = stats.impurity logDebug("Node = " + node) if (!isLeaf) { node.split = Some(split) - node.leftNode = Some(Node.emptyNode(Node.leftChildIndex(nodeIndex))) - node.rightNode = Some(Node.emptyNode(Node.rightChildIndex(nodeIndex))) - nodeQueue.enqueue((treeIndex, node.leftNode.get)) - nodeQueue.enqueue((treeIndex, node.rightNode.get)) + val childIsLeaf = (Node.indexToLevel(nodeIndex) + 1) == metadata.maxDepth + val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0) + val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0) + node.leftNode = Some(Node(Node.leftChildIndex(nodeIndex), + stats.leftPredict, stats.leftImpurity, leftChildIsLeaf)) + node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex), + stats.rightPredict, stats.rightImpurity, rightChildIsLeaf)) + + if (nodeIdCache.nonEmpty) { + val nodeIndexUpdater = NodeIndexUpdater( + split = split, + nodeIndex = nodeIndex) + nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater) + } + + // enqueue left child and right child if they are not leaves + if (!leftChildIsLeaf) { + nodeQueue.enqueue((treeIndex, node.leftNode.get)) + } + if (!rightChildIsLeaf) { + nodeQueue.enqueue((treeIndex, node.rightNode.get)) + } + logDebug("leftChildIndex = " + node.leftNode.get.id + ", impurity = " + stats.leftImpurity) logDebug("rightChildIndex = " + node.rightNode.get.id + @@ -606,6 +713,10 @@ object DecisionTree extends Serializable with Logging { } } + if (nodeIdCache.nonEmpty) { + // Update the cache if needed. + nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, bins) + } } /** @@ -617,7 +728,8 @@ object DecisionTree extends Serializable with Logging { private def calculateGainForSplit( leftImpurityCalculator: ImpurityCalculator, rightImpurityCalculator: ImpurityCalculator, - metadata: DecisionTreeMetadata): InformationGainStats = { + metadata: DecisionTreeMetadata, + impurity: Double): InformationGainStats = { val leftCount = leftImpurityCalculator.count val rightCount = rightImpurityCalculator.count @@ -630,11 +742,6 @@ object DecisionTree extends Serializable with Logging { val totalCount = leftCount + rightCount - val parentNodeAgg = leftImpurityCalculator.copy - parentNodeAgg.add(rightImpurityCalculator) - - val impurity = parentNodeAgg.calculate() - val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 val rightImpurity = rightImpurityCalculator.calculate() @@ -649,7 +756,18 @@ object DecisionTree extends Serializable with Logging { return InformationGainStats.invalidInformationGainStats } - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity) + // calculate left and right predict + val leftPredict = calculatePredict(leftImpurityCalculator) + val rightPredict = calculatePredict(rightImpurityCalculator) + + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, + leftPredict, rightPredict) + } + + private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = { + val predict = impurityCalculator.predict + val prob = impurityCalculator.prob(predict) + new Predict(predict, prob) } /** @@ -657,17 +775,17 @@ object DecisionTree extends Serializable with Logging { * Note that this function is called only once for each node. 
* @param leftImpurityCalculator left node aggregates for a split * @param rightImpurityCalculator right node aggregates for a split - * @return predict value for current node + * @return predict value and impurity for current node */ - private def calculatePredict( + private def calculatePredictImpurity( leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator): Predict = { + rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { val parentNodeAgg = leftImpurityCalculator.copy parentNodeAgg.add(rightImpurityCalculator) - val predict = parentNodeAgg.predict - val prob = parentNodeAgg.prob(predict) + val predict = calculatePredict(parentNodeAgg) + val impurity = parentNodeAgg.calculate() - new Predict(predict, prob) + (predict, impurity) } /** @@ -678,10 +796,16 @@ object DecisionTree extends Serializable with Logging { private def binsToBestSplit( binAggregates: DTStatsAggregator, splits: Array[Array[Split]], - featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = { + featuresForNode: Option[Array[Int]], + node: Node): (Split, InformationGainStats, Predict) = { - // calculate predict only once - var predict: Option[Predict] = None + // calculate predict and impurity if current node is top node + val level = Node.indexToLevel(node.id) + var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) { + None + } else { + Some((node.predict, node.impurity)) + } // For each (feature, split), calculate the gain, and select the best (feature, split). val (bestSplit, bestSplitStats) = @@ -708,9 +832,10 @@ object DecisionTree extends Serializable with Logging { val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) rightChildStats.subtract(leftChildStats) - predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata) + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) (splitIdx, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) @@ -722,9 +847,10 @@ object DecisionTree extends Serializable with Logging { Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) - predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata) + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) (splitIndex, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) @@ -794,9 +920,10 @@ object DecisionTree extends Serializable with Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) rightChildStats.subtract(leftChildStats) - predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) + predictWithImpurity = Some(predictWithImpurity.getOrElse( + calculatePredictImpurity(leftChildStats, rightChildStats))) 
val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata) + rightChildStats, binAggregates.metadata, predictWithImpurity.get._2) (splitIndex, gainStats) }.maxBy(_._2.gain) val categoriesForSplit = @@ -807,9 +934,7 @@ object DecisionTree extends Serializable with Logging { } }.maxBy(_._2.gain) - assert(predict.isDefined, "must calculate predict for each node") - - (bestSplit, bestSplitStats, predict.get) + (bestSplit, bestSplitStats, predictWithImpurity.get._1) } /** @@ -874,32 +999,39 @@ object DecisionTree extends Serializable with Logging { // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { - val numSplits = metadata.numSplits(featureIndex) - val numBins = metadata.numBins(featureIndex) if (metadata.isContinuous(featureIndex)) { - val numSamples = sampledInput.length + val featureSamples = sampledInput.map(lp => lp.features(featureIndex)) + val featureSplits = findSplitsForContinuousFeature(featureSamples, + metadata, featureIndex) + + val numSplits = featureSplits.length + val numBins = numSplits + 1 + logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits") splits(featureIndex) = new Array[Split](numSplits) bins(featureIndex) = new Array[Bin](numBins) - val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted - val stride: Double = numSamples.toDouble / metadata.numBins(featureIndex) - logDebug("stride = " + stride) - for (splitIndex <- 0 until numSplits) { - val sampleIndex = splitIndex * stride.toInt - // Set threshold halfway in between 2 samples. - val threshold = (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0 + + var splitIndex = 0 + while (splitIndex < numSplits) { + val threshold = featureSplits(splitIndex) splits(featureIndex)(splitIndex) = new Split(featureIndex, threshold, Continuous, List()) + splitIndex += 1 } bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), splits(featureIndex)(0), Continuous, Double.MinValue) - for (splitIndex <- 1 until numSplits) { + + splitIndex = 1 + while (splitIndex < numSplits) { bins(featureIndex)(splitIndex) = new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex), Continuous, Double.MinValue) + splitIndex += 1 } bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1), new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) } else { + val numSplits = metadata.numSplits(featureIndex) + val numBins = metadata.numBins(featureIndex) // Categorical feature val featureArity = metadata.featureArity(featureIndex) if (metadata.isUnordered(featureIndex)) { @@ -976,4 +1108,77 @@ object DecisionTree extends Serializable with Logging { categories } + /** + * Find splits for a continuous feature + * NOTE: Returned number of splits is set based on `featureSamples` and + * could be different from the specified `numSplits`. + * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly. 
+ * @param featureSamples feature values of each sample + * @param metadata decision tree metadata + * NOTE: `metadata.numbins` will be changed accordingly + * if there are not enough splits to be found + * @param featureIndex feature index to find splits + * @return array of splits + */ + private[tree] def findSplitsForContinuousFeature( + featureSamples: Array[Double], + metadata: DecisionTreeMetadata, + featureIndex: Int): Array[Double] = { + require(metadata.isContinuous(featureIndex), + "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.") + + val splits = { + val numSplits = metadata.numSplits(featureIndex) + + // get count for each distinct value + val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) => + m + ((x, m.getOrElse(x, 0) + 1)) + } + // sort distinct values + val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray + + // if possible splits is not enough or just enough, just return all possible splits + val possibleSplits = valueCounts.length + if (possibleSplits <= numSplits) { + valueCounts.map(_._1) + } else { + // stride between splits + val stride: Double = featureSamples.length.toDouble / (numSplits + 1) + logDebug("stride = " + stride) + + // iterate `valueCount` to find splits + val splits = new ArrayBuffer[Double] + var index = 1 + // currentCount: sum of counts of values that have been visited + var currentCount = valueCounts(0)._2 + // targetCount: target value for `currentCount`. + // If `currentCount` is closest value to `targetCount`, + // then current value is a split threshold. + // After finding a split threshold, `targetCount` is added by stride. + var targetCount = stride + while (index < valueCounts.length) { + val previousCount = currentCount + currentCount += valueCounts(index)._2 + val previousGap = math.abs(previousCount - targetCount) + val currentGap = math.abs(currentCount - targetCount) + // If adding count of current value to currentCount + // makes the gap between currentCount and targetCount smaller, + // previous value is a split threshold. + if (previousGap < currentGap) { + splits.append(valueCounts(index - 1)._1) + targetCount += stride + } + index += 1 + } + + splits.toArray + } + } + + assert(splits.length > 0) + // set number of splits accordingly + metadata.setNumSplits(featureIndex, splits.length) + + splits + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala new file mode 100644 index 0000000000000..61f6b1313f82e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
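[Editor's note] To make the stride-based threshold selection in `findSplitsForContinuousFeature` concrete, here is a standalone re-sketch of the same idea (the MLlib method itself is `private[tree]` and also updates the metadata); function name and the toy sample are illustrative.

```scala
def continuousSplits(featureSamples: Array[Double], numSplits: Int): Array[Double] = {
  // count occurrences of each distinct value, sorted by value
  val valueCounts = featureSamples.groupBy(identity).mapValues(_.length).toArray.sortBy(_._1)
  if (valueCounts.length <= numSplits) return valueCounts.map(_._1)
  val stride = featureSamples.length.toDouble / (numSplits + 1)
  val splits = scala.collection.mutable.ArrayBuffer.empty[Double]
  var currentCount = valueCounts(0)._2
  var targetCount = stride
  var index = 1
  while (index < valueCounts.length) {
    val previousCount = currentCount
    currentCount += valueCounts(index)._2
    // if adding the current value overshoots the target more than stopping before it,
    // the previous value becomes a split threshold
    if (math.abs(previousCount - targetCount) < math.abs(currentCount - targetCount)) {
      splits += valueCounts(index - 1)._1
      targetCount += stride
    }
    index += 1
  }
  splits.toArray
}

// e.g. continuousSplits(Array(1, 1, 1, 2, 2, 3, 3, 3, 3, 3).map(_.toDouble), 2)
// gives Array(1.0, 2.0): stride = 10 / 3 ≈ 3.33, and cumulative counts 3 and 5 are the
// closest to the targets 3.33 and 6.67, so the feature ends up with 2 splits and 3 bins.
```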
+ */ + +package org.apache.spark.mllib.tree + +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.BoostingStrategy +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.impl.TimeTracker +import org.apache.spark.mllib.tree.impurity.Variance +import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + +/** + * :: Experimental :: + * A class that implements + * [[http://en.wikipedia.org/wiki/Gradient_boosting Stochastic Gradient Boosting]] + * for regression and binary classification. + * + * The implementation is based upon: + * J.H. Friedman. "Stochastic Gradient Boosting." 1999. + * + * Notes on Gradient Boosting vs. TreeBoost: + * - This implementation is for Stochastic Gradient Boosting, not for TreeBoost. + * - Both algorithms learn tree ensembles by minimizing loss functions. + * - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes + * based on the loss function, whereas the original gradient boosting method does not. + * - When the loss is SquaredError, these methods give the same result, but they could differ + * for other loss functions. + * + * @param boostingStrategy Parameters for the gradient boosting algorithm. + */ +@Experimental +class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) + extends Serializable with Logging { + + /** + * Method to train a gradient boosting model + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return a gradient boosted trees model that can be used for prediction + */ + def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { + val algo = boostingStrategy.treeStrategy.algo + algo match { + case Regression => GradientBoostedTrees.boost(input, boostingStrategy) + case Classification => + // Map labels to -1, +1 so binary classification can be treated as regression. + val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + GradientBoostedTrees.boost(remappedInput, boostingStrategy) + case _ => + throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") + } + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#run]]. + */ + def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { + run(input.rdd) + } +} + + +object GradientBoostedTrees extends Logging { + + /** + * Method to train a gradient boosting model. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * For classification, labels should take values {0, 1, ..., numClasses-1}. + * For regression, labels are real numbers. + * @param boostingStrategy Configuration options for the boosting algorithm. 
+ * @return a gradient boosted trees model that can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { + new GradientBoostedTrees(boostingStrategy).run(input) + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees$#train]] + */ + def train( + input: JavaRDD[LabeledPoint], + boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { + train(input.rdd, boostingStrategy) + } + + /** + * Internal method for performing regression using trees as base learners. + * @param input training dataset + * @param boostingStrategy boosting parameters + * @return a gradient boosted trees model that can be used for prediction + */ + private def boost( + input: RDD[LabeledPoint], + boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { + + val timer = new TimeTracker() + timer.start("total") + timer.start("init") + + boostingStrategy.assertValid() + + // Initialize gradient boosting parameters + val numIterations = boostingStrategy.numIterations + val baseLearners = new Array[DecisionTreeModel](numIterations) + val baseLearnerWeights = new Array[Double](numIterations) + val loss = boostingStrategy.loss + val learningRate = boostingStrategy.learningRate + // Prepare strategy for individual trees, which use regression with variance impurity. + val treeStrategy = boostingStrategy.treeStrategy.copy + treeStrategy.algo = Regression + treeStrategy.impurity = Variance + treeStrategy.assertValid() + + // Cache input + if (input.getStorageLevel == StorageLevel.NONE) { + input.persist(StorageLevel.MEMORY_AND_DISK) + } + + timer.stop("init") + + logDebug("##########") + logDebug("Building tree 0") + logDebug("##########") + var data = input + + // Initialize tree + timer.start("building tree 0") + val firstTreeModel = new DecisionTree(treeStrategy).run(data) + baseLearners(0) = firstTreeModel + baseLearnerWeights(0) = 1.0 + val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0)) + logDebug("error of gbt = " + loss.computeError(startingModel, input)) + // Note: A model of type regression is used since we require raw prediction + timer.stop("building tree 0") + + // psuedo-residual for second iteration + data = input.map(point => LabeledPoint(loss.gradient(startingModel, point), + point.features)) + + var m = 1 + while (m < numIterations) { + timer.start(s"building tree $m") + logDebug("###################################################") + logDebug("Gradient boosting tree iteration " + m) + logDebug("###################################################") + val model = new DecisionTree(treeStrategy).run(data) + timer.stop(s"building tree $m") + // Create partial model + baseLearners(m) = model + // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. + // Technically, the weight should be optimized for the particular loss. + // However, the behavior should be reasonable, though not optimal. 
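[Editor's note] A usage sketch of the `GradientBoostedTrees.train` API added by this patch, using the `BoostingStrategy.defaultParams` helper that appears further down in the diff. The local master, object name, toy regression data, and parameter values are illustrative assumptions only.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy

object GbtUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("gbt-sketch").setMaster("local[2]"))
    val data = sc.parallelize(Seq(
      LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
      LabeledPoint(2.0, Vectors.dense(1.0, 0.0)),
      LabeledPoint(3.0, Vectors.dense(1.0, 1.0)),
      LabeledPoint(4.0, Vectors.dense(2.0, 1.0))))

    val boostingStrategy = BoostingStrategy.defaultParams("Regression")
    boostingStrategy.numIterations = 10   // keep the toy run short
    boostingStrategy.learningRate = 0.1

    val model = GradientBoostedTrees.train(data, boostingStrategy)
    println(model.predict(Vectors.dense(1.0, 1.0)))
    sc.stop()
  }
}
```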
+ baseLearnerWeights(m) = learningRate + // Note: A model of type regression is used since we require raw prediction + val partialModel = new GradientBoostedTreesModel( + Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1)) + logDebug("error of gbt = " + loss.computeError(partialModel, input)) + // Update data with pseudo-residuals + data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point), + point.features)) + m += 1 + } + + timer.stop("total") + + logInfo("Internal timing for DecisionTree:") + logInfo(s"$timer") + + new GradientBoostedTreesModel( + boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala index fa7a26f17c3ca..482d3395516e7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -17,17 +17,18 @@ package org.apache.spark.mllib.tree -import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.collection.JavaConverters._ import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker} +import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, NodeIdCache, + TimeTracker, TreePoint} import org.apache.spark.mllib.tree.impurity.Impurities import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD @@ -36,7 +37,8 @@ import org.apache.spark.util.Utils /** * :: Experimental :: - * A class which implements a random forest learning algorithm for classification and regression. + * A class that implements a [[http://en.wikipedia.org/wiki/Random_forest Random Forest]] + * learning algorithm for classification and regression. * It supports both continuous and categorical features. * * The settings for featureSubsetStrategy are based on the following references: @@ -59,7 +61,7 @@ import org.apache.spark.util.Utils * if numTrees == 1, set to "all"; * if numTrees > 1 (forest) set to "sqrt" for classification and * to "onethird" for regression. - * @param seed Random seed for bootstrapping and choosing feature subsets. + * @param seed Random seed for bootstrapping and choosing feature subsets. */ @Experimental private class RandomForest ( @@ -69,6 +71,47 @@ private class RandomForest ( private val seed: Int) extends Serializable with Logging { + /* + ALGORITHM + This is a sketch of the algorithm to help new developers. + + The algorithm partitions data by instances (rows). + On each iteration, the algorithm splits a set of nodes. In order to choose the best split + for a given node, sufficient statistics are collected from the distributed data. + For each node, the statistics are collected to some worker node, and that worker selects + the best split. + + This setup requires discretization of continuous features. 
This binning is done in the + findSplitsBins() method during initialization, after which each continuous feature becomes + an ordered discretized feature with at most maxBins possible values. + + The main loop in the algorithm operates on a queue of nodes (nodeQueue). These nodes + lie at the periphery of the tree being trained. If multiple trees are being trained at once, + then this queue contains nodes from all of them. Each iteration works roughly as follows: + On the master node: + - Some number of nodes are pulled off of the queue (based on the amount of memory + required for their sufficient statistics). + - For random forests, if featureSubsetStrategy is not "all," then a subset of candidate + features are chosen for each node. See method selectNodesToSplit(). + On worker nodes, via method findBestSplits(): + - The worker makes one pass over its subset of instances. + - For each (tree, node, feature, split) tuple, the worker collects statistics about + splitting. Note that the set of (tree, node) pairs is limited to the nodes selected + from the queue for this iteration. The set of features considered can also be limited + based on featureSubsetStrategy. + - For each node, the statistics for that node are aggregated to a particular worker + via reduceByKey(). The designated worker chooses the best (feature, split) pair, + or chooses to stop splitting if the stopping criteria are met. + On the master node: + - The master collects all decisions about splitting nodes and updates the model. + - The updated model is passed to the workers on the next iteration. + This process continues until the node queue is empty. + + Most of the methods in this implementation support the statistics aggregation, which is + the heaviest part of the computation. In general, this implementation is bound by either + the cost of statistics computation on workers or by communicating the sufficient statistics. + */ + strategy.assertValid() require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.") require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy), @@ -78,9 +121,9 @@ private class RandomForest ( /** * Method to train a decision tree model over an RDD * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] - * @return RandomForestModel that can be used for prediction + * @return a random forest model that can be used for prediction */ - def train(input: RDD[LabeledPoint]): RandomForestModel = { + def run(input: RDD[LabeledPoint]): RandomForestModel = { val timer = new TimeTracker() @@ -111,11 +154,20 @@ private class RandomForest ( // Bin feature values (TreePoint representation). // Cache input RDD for speedup during multiple passes. 
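[Editor's note] The algorithm sketch above is a queue-driven loop. A schematic, not the MLlib code, of that control flow under the stated assumptions (the case class, the higher-order parameters, and their names are invented for illustration): groups of queued nodes are split in one distributed pass, and only non-leaf children go back on the queue.

```scala
import scala.collection.mutable

// treeIndex identifies the tree (for forests); nodeId is the node index within it
case class PendingNode(treeIndex: Int, nodeId: Int)

def trainByGroups(
    roots: Seq[PendingNode],
    selectGroup: mutable.Queue[PendingNode] => Seq[PendingNode],
    splitGroup: Seq[PendingNode] => Seq[PendingNode]): Unit = {
  val nodeQueue = mutable.Queue(roots: _*)
  while (nodeQueue.nonEmpty) {
    val group = selectGroup(nodeQueue)   // dequeues as many nodes as memory for their stats allows
    val children = splitGroup(group)     // one distributed pass; returns only non-leaf children
    children.foreach(child => nodeQueue.enqueue(child))
  }
}
```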
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) - val baggedInput = if (numTrees > 1) { - BaggedPoint.convertToBaggedRDD(treeInput, numTrees, seed) - } else { - BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) - }.persist(StorageLevel.MEMORY_AND_DISK) + + val (subsample, withReplacement) = { + // TODO: Have a stricter check for RF in the strategy + val isRandomForest = numTrees > 1 + if (isRandomForest) { + (1.0, true) + } else { + (strategy.subsamplingRate, false) + } + } + + val baggedInput + = BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement, seed) + .persist(StorageLevel.MEMORY_AND_DISK) // depth of the decision tree val maxDepth = strategy.maxDepth @@ -150,6 +202,19 @@ private class RandomForest ( * in lower levels). */ + // Create an RDD of node Id cache. + // At first, all the rows belong to the root nodes (node Id == 1). + val nodeIdCache = if (strategy.useNodeIdCache) { + Some(NodeIdCache.init( + data = baggedInput, + numTrees = numTrees, + checkpointDir = strategy.checkpointDir, + checkpointInterval = strategy.checkpointInterval, + initVal = 1)) + } else { + None + } + // FIFO queue of nodes to train: (treeIndex, node) val nodeQueue = new mutable.Queue[(Int, Node)]() @@ -172,17 +237,24 @@ private class RandomForest ( // Choose node splits, and enqueue new nodes as needed. timer.start("findBestSplits") DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup, - treeToNodeToIndexInfo, splits, bins, nodeQueue, timer) + treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache) timer.stop("findBestSplits") } + baggedInput.unpersist() + timer.stop("total") logInfo("Internal timing for DecisionTree:") logInfo(s"$timer") + // Delete any remaining checkpoints used for node Id cache. + if (nodeIdCache.nonEmpty) { + nodeIdCache.get.deleteAllCheckpoints() + } + val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo)) - RandomForestModel.build(trees) + new RandomForestModel(strategy.algo, trees) } } @@ -200,10 +272,9 @@ object RandomForest extends Serializable with Logging { * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; - * if numTrees > 1 (forest) set to "sqrt" for classification and - * to "onethird" for regression. + * if numTrees > 1 (forest) set to "sqrt". * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return RandomForestModel that can be used for prediction + * @return a random forest model that can be used for prediction */ def trainClassifier( input: RDD[LabeledPoint], @@ -214,7 +285,7 @@ object RandomForest extends Serializable with Logging { require(strategy.algo == Classification, s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}") val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed) - rf.train(input) + rf.run(input) } /** @@ -231,8 +302,7 @@ object RandomForest extends Serializable with Logging { * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; - * if numTrees > 1 (forest) set to "sqrt" for classification and - * to "onethird" for regression. + * if numTrees > 1 (forest) set to "sqrt". * @param impurity Criterion used for information gain calculation. * Supported values: "gini" (recommended) or "entropy". 
* @param maxDepth Maximum depth of the tree. @@ -241,7 +311,7 @@ object RandomForest extends Serializable with Logging { * @param maxBins maximum number of bins used for splitting features * (suggested value: 100) * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return RandomForestModel that can be used for prediction + * @return a random forest model that can be used for prediction */ def trainClassifier( input: RDD[LabeledPoint], @@ -288,10 +358,9 @@ object RandomForest extends Serializable with Logging { * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; - * if numTrees > 1 (forest) set to "sqrt" for classification and - * to "onethird" for regression. + * if numTrees > 1 (forest) set to "onethird". * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return RandomForestModel that can be used for prediction + * @return a random forest model that can be used for prediction */ def trainRegressor( input: RDD[LabeledPoint], @@ -302,7 +371,7 @@ object RandomForest extends Serializable with Logging { require(strategy.algo == Regression, s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}") val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed) - rf.train(input) + rf.run(input) } /** @@ -318,8 +387,7 @@ object RandomForest extends Serializable with Logging { * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". * If "auto" is set, this parameter is set based on numTrees: * if numTrees == 1, set to "all"; - * if numTrees > 1 (forest) set to "sqrt" for classification and - * to "onethird" for regression. + * if numTrees > 1 (forest) set to "onethird". * @param impurity Criterion used for information gain calculation. * Supported values: "variance". * @param maxDepth Maximum depth of the tree. @@ -328,7 +396,7 @@ object RandomForest extends Serializable with Logging { * @param maxBins maximum number of bins used for splitting features * (suggested value: 100) * @param seed Random seed for bootstrapping and choosing feature subsets. - * @return RandomForestModel that can be used for prediction + * @return a random forest model that can be used for prediction */ def trainRegressor( input: RDD[LabeledPoint], @@ -448,5 +516,4 @@ object RandomForest extends Serializable with Logging { 3 * totalBins } } - } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala new file mode 100644 index 0000000000000..e703adbdbfbb3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
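[Editor's note] A hedged usage sketch of the renamed `run`/`trainClassifier` path shown above, via the `trainClassifier` overload that takes a `Strategy` (argument order assumed to be input, strategy, numTrees, featureSubsetStrategy, seed). `Strategy.defaultStrategy` is the helper added later in this patch; the data, object name, and parameter values are illustrative.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.configuration.Strategy

object RandomForestUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("rf-sketch").setMaster("local[2]"))
    val data = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
      LabeledPoint(0.0, Vectors.dense(0.0, 2.0)),
      LabeledPoint(1.0, Vectors.dense(3.0, 0.0)),
      LabeledPoint(1.0, Vectors.dense(4.0, 0.0))))

    val strategy = Strategy.defaultStrategy("Classification")  // binary, Gini, maxDepth 10
    // "auto" resolves to "sqrt" because numTrees > 1
    val model = RandomForest.trainClassifier(data, strategy, 5, "auto", 42)
    println(model.predict(Vectors.dense(3.5, 0.0)))
    sc.stop()
  }
}
```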
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +import scala.beans.BeanProperty + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} + +/** + * :: Experimental :: + * Configuration options for [[org.apache.spark.mllib.tree.GradientBoostedTrees]]. + * + * @param treeStrategy Parameters for the tree algorithm. We support regression and binary + * classification for boosting. Impurity setting will be ignored. + * @param loss Loss function used for minimization during gradient boosting. + * @param numIterations Number of iterations of boosting. In other words, the number of + * weak hypotheses used in the final model. + * @param learningRate Learning rate for shrinking the contribution of each estimator. The + * learning rate should be between in the interval (0, 1] + */ +@Experimental +case class BoostingStrategy( + // Required boosting parameters + @BeanProperty var treeStrategy: Strategy, + @BeanProperty var loss: Loss, + // Optional boosting parameters + @BeanProperty var numIterations: Int = 100, + @BeanProperty var learningRate: Double = 0.1) extends Serializable { + + /** + * Check validity of parameters. + * Throws exception if invalid. + */ + private[tree] def assertValid(): Unit = { + treeStrategy.algo match { + case Classification => + require(treeStrategy.numClassesForClassification == 2, + "Only binary classification is supported for boosting.") + case Regression => + // nothing + case _ => + throw new IllegalArgumentException( + s"BoostingStrategy given invalid algo parameter: ${treeStrategy.algo}." + + s" Valid settings are: Classification, Regression.") + } + require(learningRate > 0 && learningRate <= 1, + "Learning rate should be in range (0, 1]. Provided learning rate is " + s"$learningRate.") + } +} + +@Experimental +object BoostingStrategy { + + /** + * Returns default configuration for the boosting algorithm + * @param algo Learning goal. Supported: + * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]], + * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]] + * @return Configuration for boosting algorithm + */ + def defaultParams(algo: String): BoostingStrategy = { + val treeStrategy = Strategy.defaultStrategy(algo) + treeStrategy.maxDepth = 3 + algo match { + case "Classification" => + treeStrategy.numClassesForClassification = 2 + new BoostingStrategy(treeStrategy, LogLoss) + case "Regression" => + new BoostingStrategy(treeStrategy, SquaredError) + case _ => + throw new IllegalArgumentException(s"$algo is not supported by the boosting.") + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala new file mode 100644 index 0000000000000..b5bf732d1b33a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +/** + * Enum to select ensemble combining strategy for base learners + */ +private[tree] object EnsembleCombiningStrategy extends Enumeration { + type EnsembleCombiningStrategy = Value + val Average, Sum, Vote = Value +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index caaccbfb8ad16..d75f38433c081 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.tree.configuration +import scala.beans.BeanProperty import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental @@ -43,7 +44,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * for choosing how to split on features at each node. * More bins give higher granularity. * @param quantileCalculationStrategy Algorithm for calculating quantiles. Supported: - * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]] + * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]] * @param categoricalFeaturesInfo A map storing information about the categorical variables and the * number of discrete values they take. For example, an entry (n -> * k) implies the feature n is categorical with k categories 0, @@ -58,31 +59,35 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ * this split will not be considered as a valid split. * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is * 256 MB. + * @param subsamplingRate Fraction of the training data used for learning decision tree. + * @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will + * maintain a separate RDD of node Id cache for each row. + * @param checkpointDir If the node Id cache is used, it will help to checkpoint + * the node Id cache periodically. This is the checkpoint directory + * to be used for the node Id cache. + * @param checkpointInterval How often to checkpoint when the node Id cache gets updated. + * E.g. 10 means that the cache will get checkpointed every 10 updates. 
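[Editor's note] Because the `Strategy` fields become `@BeanProperty var`s in this patch, the new knobs documented above can be set directly after construction. A spark-shell style sketch; the directory and values are illustrative, and `defaultStrategy` is the companion-object helper added further down.

```scala
import org.apache.spark.mllib.tree.configuration.Strategy

val strategy = Strategy.defaultStrategy("Classification")
strategy.maxDepth = 8
strategy.subsamplingRate = 0.8                        // fraction of training data per tree
strategy.useNodeIdCache = true                        // keep node ids per row instead of shipping trees
strategy.checkpointDir = Some("/tmp/node-id-cache")   // hypothetical checkpoint directory
strategy.checkpointInterval = 10                      // checkpoint the cache every 10 updates
```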
*/ @Experimental class Strategy ( - val algo: Algo, - val impurity: Impurity, - val maxDepth: Int, - val numClassesForClassification: Int = 2, - val maxBins: Int = 32, - val quantileCalculationStrategy: QuantileStrategy = Sort, - val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), - val minInstancesPerNode: Int = 1, - val minInfoGain: Double = 0.0, - val maxMemoryInMB: Int = 256) extends Serializable { + @BeanProperty var algo: Algo, + @BeanProperty var impurity: Impurity, + @BeanProperty var maxDepth: Int, + @BeanProperty var numClassesForClassification: Int = 2, + @BeanProperty var maxBins: Int = 32, + @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort, + @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](), + @BeanProperty var minInstancesPerNode: Int = 1, + @BeanProperty var minInfoGain: Double = 0.0, + @BeanProperty var maxMemoryInMB: Int = 256, + @BeanProperty var subsamplingRate: Double = 1, + @BeanProperty var useNodeIdCache: Boolean = false, + @BeanProperty var checkpointDir: Option[String] = None, + @BeanProperty var checkpointInterval: Int = 10) extends Serializable { - if (algo == Classification) { - require(numClassesForClassification >= 2) - } - require(minInstancesPerNode >= 1, - s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") - require(maxMemoryInMB <= 10240, - s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") - - val isMulticlassClassification = + def isMulticlassClassification = algo == Classification && numClassesForClassification > 2 - val isMulticlassWithCategoricalFeatures + def isMulticlassWithCategoricalFeatures = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) /** @@ -99,6 +104,23 @@ class Strategy ( categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) } + /** + * Sets Algorithm using a String. + */ + def setAlgo(algo: String): Unit = algo match { + case "Classification" => setAlgo(Classification) + case "Regression" => setAlgo(Regression) + } + + /** + * Sets categoricalFeaturesInfo using a Java Map. + */ + def setCategoricalFeaturesInfo( + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer]): Unit = { + setCategoricalFeaturesInfo( + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap) + } + /** * Check validity of parameters. * Throws exception if invalid. @@ -130,6 +152,33 @@ class Strategy ( s"DecisionTree Strategy given invalid categoricalFeaturesInfo setting:" + s" feature $feature has $arity categories. The number of categories should be >= 2.") } + require(minInstancesPerNode >= 1, + s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode") + require(maxMemoryInMB <= 10240, + s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB") } + /** Returns a shallow copy of this instance. 
*/ + def copy: Strategy = { + new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins, + quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain, + maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointDir, checkpointInterval) + } +} + +@Experimental +object Strategy { + + /** + * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]] + * @param algo "Classification" or "Regression" + */ + def defaultStrategy(algo: String): Strategy = algo match { + case "Classification" => + new Strategy(algo = Classification, impurity = Gini, maxDepth = 10, + numClassesForClassification = 2) + case "Regression" => + new Strategy(algo = Regression, impurity = Variance, maxDepth = 10, + numClassesForClassification = 0) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala index 937c8a2ac5836..089010c81ffb6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala @@ -17,18 +17,18 @@ package org.apache.spark.mllib.tree.impl -import cern.jet.random.Poisson -import cern.jet.random.engine.DRand +import org.apache.commons.math3.distribution.PoissonDistribution import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils +import org.apache.spark.util.random.XORShiftRandom /** * Internal representation of a datapoint which belongs to several subsamples of the same dataset, * particularly for bagging (e.g., for random forests). * * This holds one instance, as well as an array of weights which represent the (weighted) - * number of times which this instance appears in each subsample. + * number of times which this instance appears in each subsamplingRate. * E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively. * @@ -45,27 +45,71 @@ private[tree] object BaggedPoint { /** * Convert an input dataset into its BaggedPoint representation, - * choosing subsample counts for each instance. - * Each subsample has the same number of instances as the original dataset, - * and is created by subsampling with replacement. - * @param input Input dataset. - * @param numSubsamples Number of subsamples of this RDD to take. - * @param seed Random seed. - * @return BaggedPoint dataset representation + * choosing subsamplingRate counts for each instance. + * Each subsamplingRate has the same number of instances as the original dataset, + * and is created by subsampling without replacement. + * @param input Input dataset. + * @param subsamplingRate Fraction of the training data used for learning decision tree. + * @param numSubsamples Number of subsamples of this RDD to take. + * @param withReplacement Sampling with/without replacement. + * @param seed Random seed. + * @return BaggedPoint dataset representation. 
*/ - def convertToBaggedRDD[Datum]( + def convertToBaggedRDD[Datum] ( input: RDD[Datum], + subsamplingRate: Double, numSubsamples: Int, + withReplacement: Boolean, seed: Int = Utils.random.nextInt()): RDD[BaggedPoint[Datum]] = { + if (withReplacement) { + convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed) + } else { + if (numSubsamples == 1 && subsamplingRate == 1.0) { + convertToBaggedRDDWithoutSampling(input) + } else { + convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed) + } + } + } + + private def convertToBaggedRDDSamplingWithoutReplacement[Datum] ( + input: RDD[Datum], + subsamplingRate: Double, + numSubsamples: Int, + seed: Int): RDD[BaggedPoint[Datum]] = { + input.mapPartitionsWithIndex { (partitionIndex, instances) => + // Use random seed = seed + partitionIndex + 1 to make generation reproducible. + val rng = new XORShiftRandom + rng.setSeed(seed + partitionIndex + 1) + instances.map { instance => + val subsampleWeights = new Array[Double](numSubsamples) + var subsampleIndex = 0 + while (subsampleIndex < numSubsamples) { + val x = rng.nextDouble() + subsampleWeights(subsampleIndex) = { + if (x < subsamplingRate) 1.0 else 0.0 + } + subsampleIndex += 1 + } + new BaggedPoint(instance, subsampleWeights) + } + } + } + + private def convertToBaggedRDDSamplingWithReplacement[Datum] ( + input: RDD[Datum], + subsample: Double, + numSubsamples: Int, + seed: Int): RDD[BaggedPoint[Datum]] = { input.mapPartitionsWithIndex { (partitionIndex, instances) => - // TODO: Support different sampling rates, and sampling without replacement. // Use random seed = seed + partitionIndex + 1 to make generation reproducible. - val poisson = new Poisson(1.0, new DRand(seed + partitionIndex + 1)) + val poisson = new PoissonDistribution(subsample) + poisson.reseedRandomGenerator(seed + partitionIndex + 1) instances.map { instance => val subsampleWeights = new Array[Double](numSubsamples) var subsampleIndex = 0 while (subsampleIndex < numSubsamples) { - subsampleWeights(subsampleIndex) = poisson.nextInt() + subsampleWeights(subsampleIndex) = poisson.sample() subsampleIndex += 1 } new BaggedPoint(instance, subsampleWeights) @@ -73,7 +117,8 @@ private[tree] object BaggedPoint { } } - def convertToBaggedRDDWithoutSampling[Datum](input: RDD[Datum]): RDD[BaggedPoint[Datum]] = { + private def convertToBaggedRDDWithoutSampling[Datum] ( + input: RDD[Datum]): RDD[BaggedPoint[Datum]] = { input.map(datum => new BaggedPoint(datum, Array(1.0))) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index 55f422dff0d71..ce8825cc03229 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -64,12 +64,6 @@ private[tree] class DTStatsAggregator( numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins) } - /** - * Indicator for each feature of whether that feature is an unordered feature. - * TODO: Is Array[Boolean] any faster? - */ - def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex) - /** * Total number of elements stored in this aggregator */ @@ -128,21 +122,13 @@ private[tree] class DTStatsAggregator( * Pre-compute feature offset for use with [[featureUpdate]]. * For ordered features only. 
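[Editor's note] A plain-Scala sketch of how one instance's subsample weights are drawn in `convertToBaggedRDD` above: Poisson(subsamplingRate) counts when sampling with replacement (bootstrap), Bernoulli(subsamplingRate) indicators when sampling without. It substitutes `scala.util.Random` for Spark's internal `XORShiftRandom`; the function name is illustrative.

```scala
import scala.util.Random
import org.apache.commons.math3.distribution.PoissonDistribution

def subsampleWeights(
    numSubsamples: Int,
    subsamplingRate: Double,
    withReplacement: Boolean,
    seed: Long): Array[Double] = {
  if (withReplacement) {
    val poisson = new PoissonDistribution(subsamplingRate)   // requires subsamplingRate > 0
    poisson.reseedRandomGenerator(seed)
    Array.fill(numSubsamples)(poisson.sample().toDouble)     // expected count per bootstrap sample
  } else {
    val rng = new Random(seed)
    Array.fill(numSubsamples)(if (rng.nextDouble() < subsamplingRate) 1.0 else 0.0)
  }
}
```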
*/ - def getFeatureOffset(featureIndex: Int): Int = { - require(!isUnordered(featureIndex), - s"DTStatsAggregator.getFeatureOffset is for ordered features only, but was called" + - s" for unordered feature $featureIndex.") - featureOffsets(featureIndex) - } + def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex) /** * Pre-compute feature offset for use with [[featureUpdate]]. * For unordered features only. */ def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = { - require(isUnordered(featureIndex), - s"DTStatsAggregator.getLeftRightFeatureOffsets is for unordered features only," + - s" but was called for ordered feature $featureIndex.") val baseOffset = featureOffsets(featureIndex) (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 212dce25236e0..5bc0f2635c6b1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.tree.impl import scala.collection.mutable +import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ @@ -75,6 +76,17 @@ private[tree] class DecisionTreeMetadata( numBins(featureIndex) - 1 } + + /** + * Set number of splits for a continuous feature. + * For a continuous feature, number of bins is number of splits plus 1. + */ + def setNumSplits(featureIndex: Int, numSplits: Int) { + require(isContinuous(featureIndex), + s"Only number of bin for a continuous feature can be set.") + numBins(featureIndex) = numSplits + 1 + } + /** * Indicates if feature subsampling is being used. */ @@ -82,7 +94,7 @@ private[tree] class DecisionTreeMetadata( } -private[tree] object DecisionTreeMetadata { +private[tree] object DecisionTreeMetadata extends Logging { /** * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters. @@ -103,6 +115,10 @@ private[tree] object DecisionTreeMetadata { } val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt + if (maxPossibleBins < strategy.maxBins) { + logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" + + s" (= number of training instances)") + } // We check the number of bins here against maxPossibleBins. // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala new file mode 100644 index 0000000000000..83011b48b7d9b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import scala.collection.mutable + +import org.apache.hadoop.fs.{Path, FileSystem} + +import org.apache.spark.rdd.RDD +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.storage.StorageLevel +import org.apache.spark.mllib.tree.model.{Bin, Node, Split} + +/** + * :: DeveloperApi :: + * This is used by the node id cache to find the child id that a data point would belong to. + * @param split Split information. + * @param nodeIndex The current node index of a data point that this will update. + */ +@DeveloperApi +private[tree] case class NodeIndexUpdater( + split: Split, + nodeIndex: Int) { + /** + * Determine a child node index based on the feature value and the split. + * @param binnedFeatures Binned feature values. + * @param bins Bin information to convert the bin indices to approximate feature values. + * @return Child node index to update to. + */ + def updateNodeIndex(binnedFeatures: Array[Int], bins: Array[Array[Bin]]): Int = { + if (split.featureType == Continuous) { + val featureIndex = split.feature + val binIndex = binnedFeatures(featureIndex) + val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold + if (featureValueUpperBound <= split.threshold) { + Node.leftChildIndex(nodeIndex) + } else { + Node.rightChildIndex(nodeIndex) + } + } else { + if (split.categories.contains(binnedFeatures(split.feature).toDouble)) { + Node.leftChildIndex(nodeIndex) + } else { + Node.rightChildIndex(nodeIndex) + } + } + } +} + +/** + * :: DeveloperApi :: + * A given TreePoint would belong to a particular node per tree. + * Each row in the nodeIdsForInstances RDD is an array over trees of the node index + * in each tree. Initially, values should all be 1 for root node. + * The nodeIdsForInstances RDD needs to be updated at each iteration. + * @param nodeIdsForInstances The initial values in the cache + * (should be an Array of all 1's (meaning the root nodes)). + * @param checkpointDir The checkpoint directory where + * the checkpointed files will be stored. + * @param checkpointInterval The checkpointing interval + * (how often should the cache be checkpointed.). + */ +@DeveloperApi +private[tree] class NodeIdCache( + var nodeIdsForInstances: RDD[Array[Int]], + val checkpointDir: Option[String], + val checkpointInterval: Int) { + + // Keep a reference to a previous node Ids for instances. + // Because we will keep on re-persisting updated node Ids, + // we want to unpersist the previous RDD. + private var prevNodeIdsForInstances: RDD[Array[Int]] = null + + // To keep track of the past checkpointed RDDs. + private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]() + private var rddUpdateCount = 0 + + // If a checkpoint directory is given, and there's no prior checkpoint directory, + // then set the checkpoint directory with the given one. 
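[Editor's note] The routing decision in `NodeIndexUpdater.updateNodeIndex` above reduces to a small pure function. A sketch with plain arguments instead of the private `Split`/`Bin` types (parameter names are illustrative); child indices follow the 1-based scheme where the left child of node i is 2i and the right child is 2i + 1.

```scala
def childIndex(
    nodeIndex: Int,
    isContinuous: Boolean,
    binUpperBound: Double,          // upper bound of the instance's bin for this feature
    threshold: Double,              // split threshold (continuous splits)
    category: Double,               // the instance's category (categorical splits)
    leftCategories: Set[Double]): Int = {
  val goesLeft =
    if (isContinuous) binUpperBound <= threshold
    else leftCategories.contains(category)
  if (goesLeft) 2 * nodeIndex else 2 * nodeIndex + 1
}
```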
+ if (checkpointDir.nonEmpty && nodeIdsForInstances.sparkContext.getCheckpointDir.isEmpty) { + nodeIdsForInstances.sparkContext.setCheckpointDir(checkpointDir.get) + } + + /** + * Update the node index values in the cache. + * This updates the RDD and its lineage. + * TODO: Passing bin information to executors seems unnecessary and costly. + * @param data The RDD of training rows. + * @param nodeIdUpdaters A map of node index updaters. + * The key is the indices of nodes that we want to update. + * @param bins Bin information needed to find child node indices. + */ + def updateNodeIndices( + data: RDD[BaggedPoint[TreePoint]], + nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]], + bins: Array[Array[Bin]]): Unit = { + if (prevNodeIdsForInstances != null) { + // Unpersist the previous one if one exists. + prevNodeIdsForInstances.unpersist() + } + + prevNodeIdsForInstances = nodeIdsForInstances + nodeIdsForInstances = data.zip(nodeIdsForInstances).map { + dataPoint => { + var treeId = 0 + while (treeId < nodeIdUpdaters.length) { + val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(dataPoint._2(treeId), null) + if (nodeIdUpdater != null) { + val newNodeIndex = nodeIdUpdater.updateNodeIndex( + binnedFeatures = dataPoint._1.datum.binnedFeatures, + bins = bins) + dataPoint._2(treeId) = newNodeIndex + } + + treeId += 1 + } + + dataPoint._2 + } + } + + // Keep on persisting new ones. + nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK) + rddUpdateCount += 1 + + // Handle checkpointing if the directory is not None. + if (nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty && + (rddUpdateCount % checkpointInterval) == 0) { + // Let's see if we can delete previous checkpoints. + var canDelete = true + while (checkpointQueue.size > 1 && canDelete) { + // We can delete the oldest checkpoint iff + // the next checkpoint actually exists in the file system. + if (checkpointQueue.get(1).get.getCheckpointFile != None) { + val old = checkpointQueue.dequeue() + + // Since the old checkpoint is not deleted by Spark, + // we'll manually delete it here. + val fs = FileSystem.get(old.sparkContext.hadoopConfiguration) + fs.delete(new Path(old.getCheckpointFile.get), true) + } else { + canDelete = false + } + } + + nodeIdsForInstances.checkpoint() + checkpointQueue.enqueue(nodeIdsForInstances) + } + } + + /** + * Call this after training is finished to delete any remaining checkpoints. + */ + def deleteAllCheckpoints(): Unit = { + while (checkpointQueue.size > 0) { + val old = checkpointQueue.dequeue() + if (old.getCheckpointFile != None) { + val fs = FileSystem.get(old.sparkContext.hadoopConfiguration) + fs.delete(new Path(old.getCheckpointFile.get), true) + } + } + } +} + +@DeveloperApi +private[tree] object NodeIdCache { + /** + * Initialize the node Id cache with initial node Id values. + * @param data The RDD of training rows. + * @param numTrees The number of trees that we want to create cache for. + * @param checkpointDir The checkpoint directory where the checkpointed files will be stored. + * @param checkpointInterval The checkpointing interval + * (how often should the cache be checkpointed.). + * @param initVal The initial values in the cache. + * @return A node Id cache containing an RDD of initial root node Indices. 
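[Editor's note] The checkpoint handling above follows a general pattern: an RDD that is re-derived every iteration is persisted, checkpointed every few updates, and older checkpoint files are deleted once a newer checkpoint has materialized. A hedged, generic sketch of that pattern (class and method names are invented, not the MLlib internals):

```scala
import scala.collection.mutable
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

class IterativeCheckpointer[T](checkpointInterval: Int) {
  private val checkpointQueue = mutable.Queue[RDD[T]]()
  private var updateCount = 0

  def update(newRdd: RDD[T]): Unit = {
    newRdd.persist(StorageLevel.MEMORY_AND_DISK)
    updateCount += 1
    if (newRdd.sparkContext.getCheckpointDir.nonEmpty && updateCount % checkpointInterval == 0) {
      // delete the oldest checkpoint only once a newer one has actually been written
      while (checkpointQueue.size > 1 && checkpointQueue(1).getCheckpointFile.isDefined) {
        val old = checkpointQueue.dequeue()
        old.getCheckpointFile.foreach { file =>
          FileSystem.get(old.sparkContext.hadoopConfiguration).delete(new Path(file), true)
        }
      }
      newRdd.checkpoint()
      checkpointQueue.enqueue(newRdd)
    }
  }
}
```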
+ */ + def init( + data: RDD[BaggedPoint[TreePoint]], + numTrees: Int, + checkpointDir: Option[String], + checkpointInterval: Int, + initVal: Int = 1): NodeIdCache = { + new NodeIdCache( + data.map(_ => Array.fill[Int](numTrees)(initVal)), + checkpointDir, + checkpointInterval) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala new file mode 100644 index 0000000000000..d1bde15e6b150 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.loss + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.model.TreeEnsembleModel +import org.apache.spark.rdd.RDD + +/** + * :: DeveloperApi :: + * Class for absolute error loss calculation (for regression). + * + * The absolute (L1) error is defined as: + * |y - F(x)| + * where y is the label and F(x) is the model prediction for features x. + */ +@DeveloperApi +object AbsoluteError extends Loss { + + /** + * Method to calculate the gradients for the gradient boosting calculation for least + * absolute error calculation. + * The gradient with respect to F(x) is: sign(F(x) - y) + * @param model Ensemble model + * @param point Instance of the training dataset + * @return Loss gradient + */ + override def gradient( + model: TreeEnsembleModel, + point: LabeledPoint): Double = { + if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0 + } + + /** + * Method to calculate loss of the base learner for the gradient boosting calculation. + * Note: This method is not used by the gradient boosting algorithm but is useful for debugging + * purposes. + * @param model Ensemble model + * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return Mean absolute error of model on data + */ + override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { + data.map { y => + val err = model.predict(y.features) - y.label + math.abs(err) + }.mean() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala new file mode 100644 index 0000000000000..7ce9fa6f86c42 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.loss + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.model.TreeEnsembleModel +import org.apache.spark.rdd.RDD + +/** + * :: DeveloperApi :: + * Class for log loss calculation (for classification). + * This uses twice the binomial negative log likelihood, called "deviance" in Friedman (1999). + * + * The log loss is defined as: + * 2 log(1 + exp(-2 y F(x))) + * where y is a label in {-1, 1} and F(x) is the model prediction for features x. + */ +@DeveloperApi +object LogLoss extends Loss { + + /** + * Method to calculate the loss gradients for the gradient boosting calculation for binary + * classification + * The gradient with respect to F(x) is: - 4 y / (1 + exp(2 y F(x))) + * @param model Ensemble model + * @param point Instance of the training dataset + * @return Loss gradient + */ + override def gradient( + model: TreeEnsembleModel, + point: LabeledPoint): Double = { + val prediction = model.predict(point.features) + - 4.0 * point.label / (1.0 + math.exp(2.0 * point.label * prediction)) + } + + /** + * Method to calculate loss of the base learner for the gradient boosting calculation. + * Note: This method is not used by the gradient boosting algorithm but is useful for debugging + * purposes. + * @param model Ensemble model + * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return Mean log loss of model on data + */ + override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { + data.map { case point => + val prediction = model.predict(point.features) + val margin = 2.0 * point.label * prediction + // The following are equivalent to 2.0 * log(1 + exp(-margin)) but are more numerically + // stable. + if (margin >= 0) { + 2.0 * math.log1p(math.exp(-margin)) + } else { + 2.0 * (-margin + math.log1p(math.exp(margin))) + } + }.mean() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala new file mode 100644 index 0000000000000..4bca9039ebe1d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
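The branching in LogLoss.computeError above is a standard trick for evaluating 2 * log(1 + exp(-margin)) without overflow; this small self-contained sketch (illustrative names only) shows the two forms agreeing for moderate margins while the naive form blows up for large negative ones:

// Sketch of the numerically stable evaluation of 2 * log(1 + exp(-margin)).
object LogLossStabilitySketch {
  def naive(margin: Double): Double = 2.0 * math.log(1.0 + math.exp(-margin))

  def stable(margin: Double): Double =
    if (margin >= 0) 2.0 * math.log1p(math.exp(-margin))
    else 2.0 * (-margin + math.log1p(math.exp(margin)))

  def main(args: Array[String]): Unit = {
    println(naive(1.5))     // ~0.4028, both forms agree for moderate margins
    println(stable(1.5))    // ~0.4028
    println(naive(-800.0))  // Infinity: exp(800) overflows a Double
    println(stable(-800.0)) // 1600.0: the rewritten form stays finite
  }
}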
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.loss + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.model.TreeEnsembleModel +import org.apache.spark.rdd.RDD + +/** + * :: DeveloperApi :: + * Trait for adding "pluggable" loss functions for the gradient boosting algorithm. + */ +@DeveloperApi +trait Loss extends Serializable { + + /** + * Method to calculate the gradients for the gradient boosting calculation. + * @param model Model of the weak learner. + * @param point Instance of the training dataset. + * @return Loss gradient. + */ + def gradient( + model: TreeEnsembleModel, + point: LabeledPoint): Double + + /** + * Method to calculate error of the base learner for the gradient boosting calculation. + * Note: This method is not used by the gradient boosting algorithm but is useful for debugging + * purposes. + * @param model Model of the weak learner. + * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return + */ + def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala new file mode 100644 index 0000000000000..42c9ead9884b4 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.loss + +object Losses { + + def fromString(name: String): Loss = name match { + case "leastSquaresError" => SquaredError + case "leastAbsoluteError" => AbsoluteError + case "logLoss" => LogLoss + case _ => throw new IllegalArgumentException(s"Did not recognize Loss name: $name") + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala new file mode 100644 index 0000000000000..50ecaa2f86f35 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
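Losses.fromString above is the single lookup point for the pluggable Loss implementations; assuming the mllib classes from this patch are on the classpath, a hedged REPL-style usage sketch looks like:

import org.apache.spark.mllib.tree.loss.{Loss, Losses}

// Resolve a Loss implementation by name; unknown names throw IllegalArgumentException.
val squared: Loss = Losses.fromString("leastSquaresError")
val absolute: Loss = Losses.fromString("leastAbsoluteError")
val logistic: Loss = Losses.fromString("logLoss")
// Each returned object exposes gradient(model, point) and computeError(model, data),
// as declared by the Loss trait above.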
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.loss + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.model.TreeEnsembleModel +import org.apache.spark.rdd.RDD + +/** + * :: DeveloperApi :: + * Class for squared error loss calculation. + * + * The squared (L2) error is defined as: + * (y - F(x))**2 + * where y is the label and F(x) is the model prediction for features x. + */ +@DeveloperApi +object SquaredError extends Loss { + + /** + * Method to calculate the gradients for the gradient boosting calculation for least + * squares error calculation. + * The gradient with respect to F(x) is: - 2 (y - F(x)) + * @param model Ensemble model + * @param point Instance of the training dataset + * @return Loss gradient + */ + override def gradient( + model: TreeEnsembleModel, + point: LabeledPoint): Double = { + 2.0 * (model.predict(point.features) - point.label) + } + + /** + * Method to calculate loss of the base learner for the gradient boosting calculation. + * Note: This method is not used by the gradient boosting algorithm but is useful for debugging + * purposes. + * @param model Ensemble model + * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @return Mean squared error of model on data + */ + override def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double = { + data.map { y => + val err = model.predict(y.features) - y.label + err * err + }.mean() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index ec1d99ab26f9c..a5760963068c3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -18,9 +18,10 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.Vector /** * :: Experimental :: @@ -52,6 +53,17 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable features.map(x => predict(x)) } + + /** + * Predict values for the given data set using the model trained. + * + * @param features JavaRDD representing data points to be predicted + * @return JavaRDD of predictions for each of the given data points + */ + def predict(features: JavaRDD[Vector]): JavaRDD[Double] = { + predict(features.rdd) + } + /** * Get number of nodes in tree, including leaf nodes. 
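SquaredError's gradient, 2 * (F(x) - y), can be sanity-checked numerically; the sketch below (plain Scala, editor's illustration) compares it with a central finite difference of the loss:

// Finite-difference check that d/dF (F - y)^2 = 2 * (F - y).
object SquaredErrorGradientSketch {
  def loss(prediction: Double, label: Double): Double = {
    val err = prediction - label
    err * err
  }

  def analyticGradient(prediction: Double, label: Double): Double =
    2.0 * (prediction - label)

  def main(args: Array[String]): Unit = {
    val (prediction, label, h) = (3.0, 1.0, 1e-6)
    val numeric = (loss(prediction + h, label) - loss(prediction - h, label)) / (2 * h)
    println(analyticGradient(prediction, label)) // 4.0
    println(numeric)                             // ~4.0
  }
}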
*/ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index a89e71e115806..9a50ecb550c38 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -26,13 +26,17 @@ import org.apache.spark.annotation.DeveloperApi * @param impurity current node impurity * @param leftImpurity left node impurity * @param rightImpurity right node impurity + * @param leftPredict left node predict + * @param rightPredict right node predict */ @DeveloperApi class InformationGainStats( val gain: Double, val impurity: Double, val leftImpurity: Double, - val rightImpurity: Double) extends Serializable { + val rightImpurity: Double, + val leftPredict: Predict, + val rightPredict: Predict) extends Serializable { override def toString = { "gain = %f, impurity = %f, left impurity = %f, right impurity = %f" @@ -58,5 +62,6 @@ private[tree] object InformationGainStats { * denote that current split doesn't satisfies minimum info gain or * minimum number of instances per node. */ - val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0) + val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, + new Predict(0.0, 0.0), new Predict(0.0, 0.0)) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 56c3e25d9285f..2179da8dbe03e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -32,7 +32,8 @@ import org.apache.spark.mllib.linalg.Vector * * @param id integer node id, from 1 * @param predict predicted value at the node - * @param isLeaf whether the leaf is a node + * @param impurity current node impurity + * @param isLeaf whether the node is a leaf * @param split split to calculate left and right nodes * @param leftNode left child * @param rightNode right child @@ -41,7 +42,8 @@ import org.apache.spark.mllib.linalg.Vector @DeveloperApi class Node ( val id: Int, - var predict: Double, + var predict: Predict, + var impurity: Double, var isLeaf: Boolean, var split: Option[Split], var leftNode: Option[Node], @@ -49,7 +51,7 @@ class Node ( var stats: Option[InformationGainStats]) extends Serializable with Logging { override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + - "split = " + split + ", stats = " + stats + "impurity = " + impurity + "split = " + split + ", stats = " + stats /** * build the left node and right nodes if not leaf @@ -62,6 +64,7 @@ class Node ( logDebug("id = " + id + ", split = " + split) logDebug("stats = " + stats) logDebug("predict = " + predict) + logDebug("impurity = " + impurity) if (!isLeaf) { leftNode = Some(nodes(Node.leftChildIndex(id))) rightNode = Some(nodes(Node.rightChildIndex(id))) @@ -77,7 +80,7 @@ class Node ( */ def predict(features: Vector) : Double = { if (isLeaf) { - predict + predict.predict } else{ if (split.get.featureType == Continuous) { if (features(split.get.feature) <= split.get.threshold) { @@ -109,7 +112,7 @@ class Node ( } else { Some(rightNode.get.deepCopy()) } - new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats) + new Node(id, predict, impurity, isLeaf, split, leftNodeCopy, rightNodeCopy, 
stats) } /** @@ -154,7 +157,7 @@ class Node ( } val prefix: String = " " * indentFactor if (isLeaf) { - prefix + s"Predict: $predict\n" + prefix + s"Predict: ${predict.predict}\n" } else { prefix + s"If ${splitToString(split.get, left=true)}\n" + leftNode.get.subtreeToString(indentFactor + 1) + @@ -170,7 +173,27 @@ private[tree] object Node { /** * Return a node with the given node id (but nothing else set). */ - def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, 0, false, None, None, None, None) + def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, new Predict(Double.MinValue), -1.0, + false, None, None, None, None) + + /** + * Construct a node with nodeIndex, predict, impurity and isLeaf parameters. + * This is used in `DecisionTree.findBestSplits` to construct child nodes + * after finding the best splits for parent nodes. + * Other fields are set at next level. + * @param nodeIndex integer node id, from 1 + * @param predict predicted value at the node + * @param impurity current node impurity + * @param isLeaf whether the node is a leaf + * @return new node instance + */ + def apply( + nodeIndex: Int, + predict: Predict, + impurity: Double, + isLeaf: Boolean): Node = { + new Node(nodeIndex, predict, impurity, isLeaf, None, None, None, None) + } /** * Return the index of the left child of this node. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala index d8476b5cd7bc7..004838ee5ba0e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala @@ -17,12 +17,15 @@ package org.apache.spark.mllib.tree.model +import org.apache.spark.annotation.DeveloperApi + /** * Predicted value for a node * @param predict predicted value * @param prob probability of the label (classification only) */ -private[tree] class Predict( +@DeveloperApi +class Predict( val predict: Double, val prob: Double = 0.0) extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala deleted file mode 100644 index 4d66d6d81caa5..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
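With Predict now public and Node carrying an impurity field, leaves can be built through the new Node.apply shown above; a brief usage sketch with illustrative values (assumes the patched mllib classes on the classpath):

import org.apache.spark.mllib.tree.model.{Node, Predict}

// Build a pure leaf predicting class 1.0 with an estimated probability of 0.8.
// The arguments follow the new Node.apply(nodeIndex, predict, impurity, isLeaf) signature.
val leaf: Node = Node(
  nodeIndex = 7,
  predict = new Predict(predict = 1.0, prob = 0.8),
  impurity = 0.0,
  isLeaf = true)

// Placeholder nodes still start with an invalid predict and impurity until populated.
val placeholder: Node = Node.emptyNode(nodeIndex = 8)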
- */ - -package org.apache.spark.mllib.tree.model - -import scala.collection.mutable - -import org.apache.spark.annotation.Experimental -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.rdd.RDD - -/** - * :: Experimental :: - * Random forest model for classification or regression. - * This model stores a collection of [[DecisionTreeModel]] instances and uses them to make - * aggregate predictions. - * @param trees Trees which make up this forest. This cannot be empty. - * @param algo algorithm type -- classification or regression - */ -@Experimental -class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) extends Serializable { - - require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.") - - /** - * Predict values for a single data point. - * - * @param features array representing a single data point - * @return Double prediction from the trained model - */ - def predict(features: Vector): Double = { - algo match { - case Classification => - val predictionToCount = new mutable.HashMap[Int, Int]() - trees.foreach { tree => - val prediction = tree.predict(features).toInt - predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1 - } - predictionToCount.maxBy(_._2)._1 - case Regression => - trees.map(_.predict(features)).sum / trees.size - } - } - - /** - * Predict values for the given data set. - * - * @param features RDD representing data points to be predicted - * @return RDD[Double] where each entry contains the corresponding prediction - */ - def predict(features: RDD[Vector]): RDD[Double] = { - features.map(x => predict(x)) - } - - /** - * Get number of trees in forest. - */ - def numTrees: Int = trees.size - - /** - * Get total number of nodes, summed over all trees in the forest. - */ - def totalNumNodes: Int = trees.map(tree => tree.numNodes).sum - - /** - * Print a summary of the model. - */ - override def toString: String = algo match { - case Classification => - s"RandomForestModel classifier with $numTrees trees" - case Regression => - s"RandomForestModel regressor with $numTrees trees" - case _ => throw new IllegalArgumentException( - s"RandomForestModel given unknown algo parameter: $algo.") - } - - /** - * Print the full model to a string. - */ - def toDebugString: String = { - val header = toString + "\n" - header + trees.zipWithIndex.map { case (tree, treeIndex) => - s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4) - }.fold("")(_ + _) - } - -} - -private[tree] object RandomForestModel { - - def build(trees: Array[DecisionTreeModel]): RandomForestModel = { - require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.") - val algo: Algo = trees(0).algo - require(trees.forall(_.algo == algo), - "RandomForestModel cannot combine trees which have different output types" + - " (classification/regression).") - new RandomForestModel(trees, algo) - } - -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala new file mode 100644 index 0000000000000..22997110de8dd --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
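The deleted RandomForestModel tallied unweighted votes per class; the TreeEnsembleModel that replaces it below generalizes this to weighted voting. A self-contained sketch of that weighted-vote rule over plain predictions (illustrative names):

// Weighted majority vote over per-tree class predictions, mirroring predictByVoting below.
object WeightedVoteSketch {
  def vote(predictions: Seq[Int], weights: Seq[Double]): Int = {
    val tally = scala.collection.mutable.Map.empty[Int, Double]
    predictions.zip(weights).foreach { case (cls, weight) =>
      tally(cls) = tally.getOrElse(cls, 0.0) + weight
    }
    tally.maxBy(_._2)._1
  }

  def main(args: Array[String]): Unit = {
    // Two trees say class 1, one heavier tree says class 0: class 0 wins 1.5 to 1.0.
    println(vote(Seq(1, 1, 0), Seq(0.5, 0.5, 1.5))) // 0
  }
}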
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import scala.collection.mutable + +import com.github.fommil.netlib.BLAS.{getInstance => blas} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._ +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * Represents a random forest model. + * + * @param algo algorithm for the ensemble model, either Classification or Regression + * @param trees tree ensembles + */ +@Experimental +class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel]) + extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0), + combiningStrategy = if (algo == Classification) Vote else Average) { + + require(trees.forall(_.algo == algo)) +} + +/** + * :: Experimental :: + * Represents a gradient boosted trees model. + * + * @param algo algorithm for the ensemble model, either Classification or Regression + * @param trees tree ensembles + * @param treeWeights tree ensemble weights + */ +@Experimental +class GradientBoostedTreesModel( + override val algo: Algo, + override val trees: Array[DecisionTreeModel], + override val treeWeights: Array[Double]) + extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) { + + require(trees.size == treeWeights.size) +} + +/** + * Represents a tree ensemble model. + * + * @param algo algorithm for the ensemble model, either Classification or Regression + * @param trees tree ensembles + * @param treeWeights tree ensemble weights + * @param combiningStrategy strategy for combining the predictions, not used for regression. + */ +private[tree] sealed class TreeEnsembleModel( + protected val algo: Algo, + protected val trees: Array[DecisionTreeModel], + protected val treeWeights: Array[Double], + protected val combiningStrategy: EnsembleCombiningStrategy) extends Serializable { + + require(numTrees > 0, "TreeEnsembleModel cannot be created without trees.") + + private val sumWeights = math.max(treeWeights.sum, 1e-15) + + /** + * Predicts for a single data point using the weighted sum of ensemble predictions. + * + * @param features array representing a single data point + * @return predicted category from the trained model + */ + private def predictBySumming(features: Vector): Double = { + val treePredictions = trees.map(_.predict(features)) + blas.ddot(numTrees, treePredictions, 1, treeWeights, 1) + } + + /** + * Classifies a single data point based on (weighted) majority votes. 
+ */ + private def predictByVoting(features: Vector): Double = { + val votes = mutable.Map.empty[Int, Double] + trees.view.zip(treeWeights).foreach { case (tree, weight) => + val prediction = tree.predict(features).toInt + votes(prediction) = votes.getOrElse(prediction, 0.0) + weight + } + votes.maxBy(_._2)._1 + } + + /** + * Predict values for a single data point using the model trained. + * + * @param features array representing a single data point + * @return predicted category from the trained model + */ + def predict(features: Vector): Double = { + (algo, combiningStrategy) match { + case (Regression, Sum) => + predictBySumming(features) + case (Regression, Average) => + predictBySumming(features) / sumWeights + case (Classification, Sum) => // binary classification + val prediction = predictBySumming(features) + // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info. + if (prediction > 0.0) 1.0 else 0.0 + case (Classification, Vote) => + predictByVoting(features) + case _ => + throw new IllegalArgumentException( + "TreeEnsembleModel given unsupported (algo, combiningStrategy) combination: " + + s"($algo, $combiningStrategy).") + } + } + + /** + * Predict values for the given data set. + * + * @param features RDD representing data points to be predicted + * @return RDD[Double] where each entry contains the corresponding prediction + */ + def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x)) + + /** + * Java-friendly version of [[org.apache.spark.mllib.tree.model.TreeEnsembleModel#predict]]. + */ + def predict(features: JavaRDD[Vector]): JavaRDD[java.lang.Double] = { + predict(features.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] + } + + /** + * Print a summary of the model. + */ + override def toString: String = { + algo match { + case Classification => + s"TreeEnsembleModel classifier with $numTrees trees\n" + case Regression => + s"TreeEnsembleModel regressor with $numTrees trees\n" + case _ => throw new IllegalArgumentException( + s"TreeEnsembleModel given unknown algo parameter: $algo.") + } + } + + /** + * Print the full model to a string. + */ + def toDebugString: String = { + val header = toString + "\n" + header + trees.zipWithIndex.map { case (tree, treeIndex) => + s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4) + }.fold("")(_ + _) + } + + /** + * Get number of trees in forest. + */ + def numTrees: Int = trees.size + + /** + * Get total number of nodes, summed over all trees in the forest. 
+ */ + def totalNumNodes: Int = trees.map(_.numNodes).sum +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index ca35100aa99c6..9353351af72a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -26,7 +26,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.rdd.PartitionwiseSampledRDD -import org.apache.spark.util.random.BernoulliSampler +import org.apache.spark.util.random.BernoulliCellSampler import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.storage.StorageLevel @@ -76,7 +76,7 @@ object MLUtils { .map { line => val items = line.split(' ') val label = items.head.toDouble - val (indices, values) = items.tail.map { item => + val (indices, values) = items.tail.filter(_.nonEmpty).map { item => val indexAndValue = item.split(':') val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based. val value = indexAndValue(1).toDouble @@ -196,8 +196,8 @@ object MLUtils { /** * Load labeled data from a file. The data format used here is - * , ... - * where , are feature values in Double and is the corresponding label as Double. + * L, f1 f2 ... + * where f1, f2 are feature values in Double and L is the corresponding label as Double. * * @param sc SparkContext * @param dir Directory to the input data files. @@ -219,8 +219,8 @@ object MLUtils { /** * Save labeled data to a file. The data format used here is - * , ... - * where , are feature values in Double and is the corresponding label as Double. + * L, f1 f2 ... + * where f1, f2 are feature values in Double and L is the corresponding label as Double. * * @param data An RDD of LabeledPoints containing data to be saved. * @param dir Directory to save the data. @@ -244,7 +244,7 @@ object MLUtils { def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = { val numFoldsF = numFolds.toFloat (1 to numFolds).map { fold => - val sampler = new BernoulliSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF, + val sampler = new BernoulliCellSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF, complement = false) val validation = new PartitionwiseSampledRDD(rdd, sampler, true, seed) val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), true, seed) diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java new file mode 100644 index 0000000000000..42846677ed285 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
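The filter(_.nonEmpty) added to MLUtils.loadLibSVMFile above makes the parser tolerate repeated spaces between tokens; a standalone sketch of that per-line parsing logic (no Spark, illustrative names):

// Parse one LIBSVM-format line such as "1.0 3:4.5  7:2.0" into (label, indices, values).
// Blank tokens produced by doubled spaces are dropped, as in the MLUtils fix above.
object LibSVMLineSketch {
  def parse(line: String): (Double, Array[Int], Array[Double]) = {
    val items = line.split(' ')
    val label = items.head.toDouble
    val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
      val Array(index, value) = item.split(':')
      (index.toInt - 1, value.toDouble) // convert 1-based indices to 0-based
    }.unzip
    (label, indices, values)
  }

  def main(args: Array[String]): Unit = {
    val (label, indices, values) = parse("1.0 3:4.5  7:2.0")
    println(label)                 // 1.0
    println(indices.mkString(",")) // 2,6
    println(values.mkString(","))  // 4.5,2.0
  }
}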
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.feature.StandardScaler; +import org.apache.spark.sql.api.java.JavaSQLContext; +import org.apache.spark.sql.api.java.JavaSchemaRDD; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite + .generateLogisticInputAsList; + +/** + * Test Pipeline construction and fitting in Java. + */ +public class JavaPipelineSuite { + + private transient JavaSparkContext jsc; + private transient JavaSQLContext jsql; + private transient JavaSchemaRDD dataset; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaPipelineSuite"); + jsql = new JavaSQLContext(jsc); + JavaRDD points = + jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2); + dataset = jsql.applySchema(points, LabeledPoint.class); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void pipeline() { + StandardScaler scaler = new StandardScaler() + .setInputCol("features") + .setOutputCol("scaledFeatures"); + LogisticRegression lr = new LogisticRegression() + .setFeaturesCol("scaledFeatures"); + Pipeline pipeline = new Pipeline() + .setStages(new PipelineStage[] {scaler, lr}); + PipelineModel model = pipeline.fit(dataset); + model.transform(dataset).registerTempTable("prediction"); + JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + predictions.collect(); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java new file mode 100644 index 0000000000000..76eb7f00329f2 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
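For reference alongside JavaPipelineSuite above, a hedged Scala sketch of the same scaler-plus-logistic-regression pipeline; the dataset argument and helper name are assumptions, and the column names follow the suites:

import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.StandardScaler
import org.apache.spark.sql.SchemaRDD

// `dataset` is assumed to carry "label" and "features" columns, as built in the suites.
def fitScaledLogisticRegression(dataset: SchemaRDD): SchemaRDD = {
  val scaler = new StandardScaler()
    .setInputCol("features")
    .setOutputCol("scaledFeatures")
  val lr = new LogisticRegression()
    .setFeaturesCol("scaledFeatures")
  val pipeline = new Pipeline()
    .setStages(Array[PipelineStage](scaler, lr))
  val model = pipeline.fit(dataset) // fits the scaler, then the logistic regression
  model.transform(dataset)          // appends scaled features, score and prediction columns
}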
+ */ + +package org.apache.spark.ml.classification; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.api.java.JavaSQLContext; +import org.apache.spark.sql.api.java.JavaSchemaRDD; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite + .generateLogisticInputAsList; + +public class JavaLogisticRegressionSuite implements Serializable { + + private transient JavaSparkContext jsc; + private transient JavaSQLContext jsql; + private transient JavaSchemaRDD dataset; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite"); + jsql = new JavaSQLContext(jsc); + List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); + dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void logisticRegression() { + LogisticRegression lr = new LogisticRegression(); + LogisticRegressionModel model = lr.fit(dataset); + model.transform(dataset).registerTempTable("prediction"); + JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + predictions.collect(); + } + + @Test + public void logisticRegressionWithSetters() { + LogisticRegression lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(1.0); + LogisticRegressionModel model = lr.fit(dataset); + model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold + .registerTempTable("prediction"); + JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + predictions.collect(); + } + + @Test + public void logisticRegressionFitWithVarargs() { + LogisticRegression lr = new LogisticRegression(); + lr.fit(dataset, lr.maxIter().w(10), lr.regParam().w(1.0)); + } +} diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java new file mode 100644 index 0000000000000..a266ebd2071a1 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.tuning; + +import java.io.Serializable; +import java.util.List; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.ml.classification.LogisticRegression; +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; +import org.apache.spark.ml.param.ParamMap; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.sql.api.java.JavaSQLContext; +import org.apache.spark.sql.api.java.JavaSchemaRDD; +import static org.apache.spark.mllib.classification.LogisticRegressionSuite + .generateLogisticInputAsList; + +public class JavaCrossValidatorSuite implements Serializable { + + private transient JavaSparkContext jsc; + private transient JavaSQLContext jsql; + private transient JavaSchemaRDD dataset; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite"); + jsql = new JavaSQLContext(jsc); + List points = generateLogisticInputAsList(1.0, 1.0, 100, 42); + dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void crossValidationWithLogisticRegression() { + LogisticRegression lr = new LogisticRegression(); + ParamMap[] lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam(), new double[] {0.001, 1000.0}) + .addGrid(lr.maxIter(), new int[] {0, 10}) + .build(); + BinaryClassificationEvaluator eval = new BinaryClassificationEvaluator(); + CrossValidator cv = new CrossValidator() + .setEstimator(lr) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setNumFolds(3); + CrossValidatorModel cvModel = cv.fit(dataset); + ParamMap bestParamMap = cvModel.bestModel().fittingParamMap(); + Assert.assertEquals(0.001, bestParamMap.apply(lr.regParam())); + Assert.assertEquals(10, bestParamMap.apply(lr.maxIter())); + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java index f6ca9643227f8..af688c504cf1e 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -23,13 +23,14 @@ import scala.Tuple2; import scala.Tuple3; +import com.google.common.collect.Lists; import org.jblas.DoubleMatrix; - import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -47,61 +48,48 @@ public void tearDown() { sc = null; } - static void validatePrediction( + void validatePrediction( MatrixFactorizationModel model, int users, int products, - int features, DoubleMatrix trueRatings, double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) { - DoubleMatrix predictedU = new DoubleMatrix(users, features); - List> userFeatures = model.userFeatures().toJavaRDD().collect(); - for (int i = 0; i < features; ++i) { - for (Tuple2 userFeature : userFeatures) { - predictedU.put((Integer)userFeature._1(), i, userFeature._2()[i]); - } - } - DoubleMatrix predictedP = new DoubleMatrix(products, features); - - List> productFeatures = - model.productFeatures().toJavaRDD().collect(); - for (int i = 0; i < features; ++i) { - for (Tuple2 productFeature : productFeatures) { 
- predictedP.put((Integer)productFeature._1(), i, productFeature._2()[i]); + List> localUsersProducts = + Lists.newArrayListWithCapacity(users * products); + for (int u=0; u < users; ++u) { + for (int p=0; p < products; ++p) { + localUsersProducts.add(new Tuple2(u, p)); } } - - DoubleMatrix predictedRatings = predictedU.mmul(predictedP.transpose()); - + JavaPairRDD usersProducts = sc.parallelizePairs(localUsersProducts); + List predictedRatings = model.predict(usersProducts).collect(); + Assert.assertEquals(users * products, predictedRatings.size()); if (!implicitPrefs) { - for (int u = 0; u < users; ++u) { - for (int p = 0; p < products; ++p) { - double prediction = predictedRatings.get(u, p); - double correct = trueRatings.get(u, p); - Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f", - prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold); - } + for (Rating r: predictedRatings) { + double prediction = r.rating(); + double correct = trueRatings.get(r.user(), r.product()); + Assert.assertTrue(String.format("Prediction=%2.4f not below match threshold of %2.2f", + prediction, matchThreshold), Math.abs(prediction - correct) < matchThreshold); } } else { // For implicit prefs we use the confidence-weighted RMSE to test // (ref Mahout's implicit ALS tests) double sqErr = 0.0; double denom = 0.0; - for (int u = 0; u < users; ++u) { - for (int p = 0; p < products; ++p) { - double prediction = predictedRatings.get(u, p); - double truePref = truePrefs.get(u, p); - double confidence = 1.0 + /* alpha = */ 1.0 * Math.abs(trueRatings.get(u, p)); - double err = confidence * (truePref - prediction) * (truePref - prediction); - sqErr += err; - denom += confidence; - } + for (Rating r: predictedRatings) { + double prediction = r.rating(); + double truePref = truePrefs.get(r.user(), r.product()); + double confidence = 1.0 + + /* alpha = */ 1.0 * Math.abs(trueRatings.get(r.user(), r.product())); + double err = confidence * (truePref - prediction) * (truePref - prediction); + sqErr += err; + denom += confidence; } double rmse = Math.sqrt(sqErr / denom); Assert.assertTrue(String.format("Confidence-weighted RMSE=%2.4f above threshold of %2.2f", - rmse, matchThreshold), rmse < matchThreshold); + rmse, matchThreshold), rmse < matchThreshold); } } @@ -116,7 +104,7 @@ public void runALSUsingStaticMethods() { JavaRDD data = sc.parallelize(testData._1()); MatrixFactorizationModel model = ALS.train(data.rdd(), features, iterations); - validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3()); + validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3()); } @Test @@ -132,8 +120,8 @@ public void runALSUsingConstructor() { MatrixFactorizationModel model = new ALS().setRank(features) .setIterations(iterations) - .run(data.rdd()); - validatePrediction(model, users, products, features, testData._2(), 0.3, false, testData._3()); + .run(data); + validatePrediction(model, users, products, testData._2(), 0.3, false, testData._3()); } @Test @@ -147,7 +135,7 @@ public void runImplicitALSUsingStaticMethods() { JavaRDD data = sc.parallelize(testData._1()); MatrixFactorizationModel model = ALS.trainImplicit(data.rdd(), features, iterations); - validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); + validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3()); } @Test @@ -165,7 +153,7 @@ public void runImplicitALSUsingConstructor() { 
.setIterations(iterations) .setImplicitPrefs(true) .run(data.rdd()); - validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); + validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3()); } @Test @@ -183,7 +171,7 @@ public void runImplicitALSWithNegativeWeight() { .setImplicitPrefs(true) .setSeed(8675309L) .run(data.rdd()); - validatePrediction(model, users, products, features, testData._2(), 0.4, true, testData._3()); + validatePrediction(model, users, products, testData._2(), 0.4, true, testData._3()); } @Test diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java index 2c281a1ee7157..9925aae441af9 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java @@ -74,7 +74,7 @@ public void runDTUsingConstructor() { maxBins, categoricalFeaturesInfo); DecisionTree learner = new DecisionTree(strategy); - DecisionTreeModel model = learner.train(rdd.rdd()); + DecisionTreeModel model = learner.run(rdd.rdd()); int numCorrect = validatePrediction(arr, model); Assert.assertTrue(numCorrect == rdd.count()); diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala new file mode 100644 index 0000000000000..4515084bc7ae9 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
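The implicit-preference branch of validatePrediction above scores predictions with a confidence-weighted RMSE (the Mahout-style check it references); a standalone sketch of that metric over plain (rating, truePreference, prediction) triples:

// Confidence-weighted RMSE: each squared error is weighted by 1 + alpha * |rating|.
object ConfidenceWeightedRmseSketch {
  def rmse(triples: Seq[(Double, Double, Double)], alpha: Double = 1.0): Double = {
    val (sqErr, denom) = triples.foldLeft((0.0, 0.0)) {
      case ((errAcc, confAcc), (rating, truePref, prediction)) =>
        val confidence = 1.0 + alpha * math.abs(rating)
        (errAcc + confidence * (truePref - prediction) * (truePref - prediction),
         confAcc + confidence)
    }
    math.sqrt(sqErr / denom)
  }

  def main(args: Array[String]): Unit = {
    println(rmse(Seq((4.0, 1.0, 0.9), (0.0, 0.0, 0.2)))) // ~0.12, well under a 0.4 threshold
  }
}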
+ */ + +package org.apache.spark.ml + +import org.mockito.Matchers.{any, eq => meq} +import org.mockito.Mockito.when +import org.scalatest.FunSuite +import org.scalatest.mock.MockitoSugar.mock + +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.sql.SchemaRDD + +class PipelineSuite extends FunSuite { + + abstract class MyModel extends Model[MyModel] + + test("pipeline") { + val estimator0 = mock[Estimator[MyModel]] + val model0 = mock[MyModel] + val transformer1 = mock[Transformer] + val estimator2 = mock[Estimator[MyModel]] + val model2 = mock[MyModel] + val transformer3 = mock[Transformer] + val dataset0 = mock[SchemaRDD] + val dataset1 = mock[SchemaRDD] + val dataset2 = mock[SchemaRDD] + val dataset3 = mock[SchemaRDD] + val dataset4 = mock[SchemaRDD] + + when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0) + when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1) + when(model0.parent).thenReturn(estimator0) + when(transformer1.transform(meq(dataset1), any[ParamMap])).thenReturn(dataset2) + when(estimator2.fit(meq(dataset2), any[ParamMap]())).thenReturn(model2) + when(model2.transform(meq(dataset2), any[ParamMap]())).thenReturn(dataset3) + when(model2.parent).thenReturn(estimator2) + when(transformer3.transform(meq(dataset3), any[ParamMap]())).thenReturn(dataset4) + + val pipeline = new Pipeline() + .setStages(Array(estimator0, transformer1, estimator2, transformer3)) + val pipelineModel = pipeline.fit(dataset0) + + assert(pipelineModel.stages.size === 4) + assert(pipelineModel.stages(0).eq(model0)) + assert(pipelineModel.stages(1).eq(transformer1)) + assert(pipelineModel.stages(2).eq(model2)) + assert(pipelineModel.stages(3).eq(transformer3)) + + assert(pipelineModel.getModel(estimator0).eq(model0)) + assert(pipelineModel.getModel(estimator2).eq(model2)) + intercept[NoSuchElementException] { + pipelineModel.getModel(mock[Estimator[MyModel]]) + } + val output = pipelineModel.transform(dataset0) + assert(output.eq(dataset4)) + } + + test("pipeline with duplicate stages") { + val estimator = mock[Estimator[MyModel]] + val pipeline = new Pipeline() + .setStages(Array(estimator, estimator)) + val dataset = mock[SchemaRDD] + intercept[IllegalArgumentException] { + pipeline.fit(dataset) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala new file mode 100644 index 0000000000000..e8030fef55b1d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.classification + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{SQLContext, SchemaRDD} + +class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { + + @transient var sqlContext: SQLContext = _ + @transient var dataset: SchemaRDD = _ + + override def beforeAll(): Unit = { + super.beforeAll() + sqlContext = new SQLContext(sc) + dataset = sqlContext.createSchemaRDD( + sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) + } + + test("logistic regression") { + val sqlContext = this.sqlContext + import sqlContext._ + val lr = new LogisticRegression + val model = lr.fit(dataset) + model.transform(dataset) + .select('label, 'prediction) + .collect() + } + + test("logistic regression with setters") { + val sqlContext = this.sqlContext + import sqlContext._ + val lr = new LogisticRegression() + .setMaxIter(10) + .setRegParam(1.0) + val model = lr.fit(dataset) + model.transform(dataset, model.threshold -> 0.8) // overwrite threshold + .select('label, 'score, 'prediction) + .collect() + } + + test("logistic regression fit and transform with varargs") { + val sqlContext = this.sqlContext + import sqlContext._ + val lr = new LogisticRegression + val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0) + model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability") + .select('label, 'probability, 'prediction) + .collect() + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala new file mode 100644 index 0000000000000..1ce2987612378 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.param + +import org.scalatest.FunSuite + +class ParamsSuite extends FunSuite { + + val solver = new TestParams() + import solver.{inputCol, maxIter} + + test("param") { + assert(maxIter.name === "maxIter") + assert(maxIter.doc === "max number of iterations") + assert(maxIter.defaultValue.get === 100) + assert(maxIter.parent.eq(solver)) + assert(maxIter.toString === "maxIter: max number of iterations (default: 100)") + assert(inputCol.defaultValue === None) + } + + test("param pair") { + val pair0 = maxIter -> 5 + val pair1 = maxIter.w(5) + val pair2 = ParamPair(maxIter, 5) + for (pair <- Seq(pair0, pair1, pair2)) { + assert(pair.param.eq(maxIter)) + assert(pair.value === 5) + } + } + + test("param map") { + val map0 = ParamMap.empty + + assert(!map0.contains(maxIter)) + assert(map0(maxIter) === maxIter.defaultValue.get) + map0.put(maxIter, 10) + assert(map0.contains(maxIter)) + assert(map0(maxIter) === 10) + + assert(!map0.contains(inputCol)) + intercept[NoSuchElementException] { + map0(inputCol) + } + map0.put(inputCol -> "input") + assert(map0.contains(inputCol)) + assert(map0(inputCol) === "input") + + val map1 = map0.copy + val map2 = ParamMap(maxIter -> 10, inputCol -> "input") + val map3 = new ParamMap() + .put(maxIter, 10) + .put(inputCol, "input") + val map4 = ParamMap.empty ++ map0 + val map5 = ParamMap.empty + map5 ++= map0 + + for (m <- Seq(map1, map2, map3, map4, map5)) { + assert(m.contains(maxIter)) + assert(m(maxIter) === 10) + assert(m.contains(inputCol)) + assert(m(inputCol) === "input") + } + } + + test("params") { + val params = solver.params + assert(params.size === 2) + assert(params(0).eq(inputCol), "params must be ordered by name") + assert(params(1).eq(maxIter)) + assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n")) + assert(solver.getParam("inputCol").eq(inputCol)) + assert(solver.getParam("maxIter").eq(maxIter)) + intercept[NoSuchMethodException] { + solver.getParam("abc") + } + assert(!solver.isSet(inputCol)) + intercept[IllegalArgumentException] { + solver.validate() + } + solver.validate(ParamMap(inputCol -> "input")) + solver.setInputCol("input") + assert(solver.isSet(inputCol)) + assert(solver.getInputCol === "input") + solver.validate() + intercept[IllegalArgumentException] { + solver.validate(ParamMap(maxIter -> -10)) + } + solver.setMaxIter(-10) + intercept[IllegalArgumentException] { + solver.validate() + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala new file mode 100644 index 0000000000000..1a65883d78a71 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.param + +/** A subclass of Params for testing. */ +class TestParams extends Params { + + val maxIter = new IntParam(this, "maxIter", "max number of iterations", Some(100)) + def setMaxIter(value: Int): this.type = { set(maxIter, value); this } + def getMaxIter: Int = get(maxIter) + + val inputCol = new Param[String](this, "inputCol", "input column name") + def setInputCol(value: String): this.type = { set(inputCol, value); this } + def getInputCol: String = get(inputCol) + + override def validate(paramMap: ParamMap) = { + val m = this.paramMap ++ paramMap + require(m(maxIter) >= 0) + require(m.contains(inputCol)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala new file mode 100644 index 0000000000000..41cc13da4d5b1 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.tuning + +import org.scalatest.FunSuite + +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator +import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{SQLContext, SchemaRDD} + +class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { + + @transient var dataset: SchemaRDD = _ + + override def beforeAll(): Unit = { + super.beforeAll() + val sqlContext = new SQLContext(sc) + dataset = sqlContext.createSchemaRDD( + sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) + } + + test("cross validation with logistic regression") { + val lr = new LogisticRegression + val lrParamMaps = new ParamGridBuilder() + .addGrid(lr.regParam, Array(0.001, 1000.0)) + .addGrid(lr.maxIter, Array(0, 10)) + .build() + val eval = new BinaryClassificationEvaluator + val cv = new CrossValidator() + .setEstimator(lr) + .setEstimatorParamMaps(lrParamMaps) + .setEvaluator(eval) + .setNumFolds(3) + val cvModel = cv.fit(dataset) + val bestParamMap = cvModel.bestModel.fittingParamMap + assert(bestParamMap(lr.regParam) === 0.001) + assert(bestParamMap(lr.maxIter) === 10) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala new file mode 100644 index 0000000000000..20aa100112bfe --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/ParamGridBuilderSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
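The 2x2 grid built in CrossValidatorSuite above, and the grids in ParamGridBuilderSuite below, are cartesian products of the per-parameter value lists. A plain-Scala sketch of that expansion (illustrative only, no MLlib types; the parameter names are just labels):

// Expands a named grid into one settings map per combination, mirroring the four
// candidate models produced by regParam in {0.001, 1000.0} x maxIter in {0, 10}.
object GridExpansionSketch {
  def expand(grid: Map[String, Seq[Any]]): Seq[Map[String, Any]] =
    grid.foldLeft(Seq(Map.empty[String, Any])) { case (acc, (name, values)) =>
      for (settings <- acc; value <- values) yield settings + (name -> value)
    }

  def main(args: Array[String]): Unit = {
    val grid = Map("regParam" -> Seq(0.001, 1000.0), "maxIter" -> Seq(0, 10))
    expand(grid).foreach(println)   // four maps, one per candidate model
  }
}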
+ */ + +package org.apache.spark.ml.tuning + +import scala.collection.mutable + +import org.scalatest.FunSuite + +import org.apache.spark.ml.param.{ParamMap, TestParams} + +class ParamGridBuilderSuite extends FunSuite { + + val solver = new TestParams() + import solver.{inputCol, maxIter} + + test("param grid builder") { + def validateGrid(maps: Array[ParamMap], expected: mutable.Set[(Int, String)]): Unit = { + assert(maps.size === expected.size) + maps.foreach { m => + val tuple = (m(maxIter), m(inputCol)) + assert(expected.contains(tuple)) + expected.remove(tuple) + } + assert(expected.isEmpty) + } + + val maps0 = new ParamGridBuilder() + .baseOn(maxIter -> 10) + .addGrid(inputCol, Array("input0", "input1")) + .build() + val expected0 = mutable.Set( + (10, "input0"), + (10, "input1")) + validateGrid(maps0, expected0) + + val maps1 = new ParamGridBuilder() + .baseOn(ParamMap(maxIter -> 5, inputCol -> "input")) // will be overwritten + .addGrid(maxIter, Array(10, 20)) + .addGrid(inputCol, Array("input0", "input1")) + .build() + val expected1 = mutable.Set( + (10, "input0"), + (20, "input0"), + (10, "input1"), + (20, "input1")) + validateGrid(maps1, expected1) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index e954baaf7d91e..4e812994405b3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.Matchers import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ object LogisticRegressionSuite { @@ -57,7 +57,7 @@ object LogisticRegressionSuite { } } -class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Matchers { +class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { def validatePrediction( predictions: Seq[Double], input: Seq[LabeledPoint], @@ -80,13 +80,16 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match val testRDD = sc.parallelize(testData, 2) testRDD.cache() val lr = new LogisticRegressionWithSGD().setIntercept(true) - lr.optimizer.setStepSize(10.0).setNumIterations(20) + lr.optimizer + .setStepSize(10.0) + .setRegParam(0.0) + .setNumIterations(20) val model = lr.run(testRDD) // Test the weights - assert(model.weights(0) ~== -1.52 relTol 0.01) - assert(model.intercept ~== 2.00 relTol 0.01) + assert(model.weights(0) ~== B relTol 0.02) + assert(model.intercept ~== A relTol 0.02) val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) @@ -112,10 +115,8 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match val model = lr.run(testRDD) // Test the weights - assert(model.weights(0) ~== -1.52 relTol 0.01) - assert(model.intercept ~== 2.00 relTol 0.01) - assert(model.weights(0) ~== model.weights(0) relTol 0.01) - assert(model.intercept ~== model.intercept relTol 0.01) + assert(model.weights(0) ~== B relTol 0.02) + assert(model.intercept ~== A relTol 0.02) val validationData = LogisticRegressionSuite.generateLogisticInput(A, 
B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) @@ -141,13 +142,16 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match // Use half as many iterations as the previous test. val lr = new LogisticRegressionWithSGD().setIntercept(true) - lr.optimizer.setStepSize(10.0).setNumIterations(10) + lr.optimizer + .setStepSize(10.0) + .setRegParam(0.0) + .setNumIterations(10) val model = lr.run(testRDD, initialWeights) // Test the weights - assert(model.weights(0) ~== -1.50 relTol 0.01) - assert(model.intercept ~== 1.97 relTol 0.01) + assert(model.weights(0) ~== B relTol 0.02) + assert(model.intercept ~== A relTol 0.02) val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) @@ -212,8 +216,8 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match val model = lr.run(testRDD, initialWeights) // Test the weights - assert(model.weights(0) ~== -1.50 relTol 0.02) - assert(model.intercept ~== 1.97 relTol 0.02) + assert(model.weights(0) ~== B relTol 0.02) + assert(model.intercept ~== A relTol 0.02) val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 80989bc074e84..e68fe89d6ccea 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.FunSuite import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} object NaiveBayesSuite { @@ -60,7 +60,7 @@ object NaiveBayesSuite { } } -class NaiveBayesSuite extends FunSuite with LocalSparkContext { +class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOfPredictions = predictions.zip(input).count { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index 65e5df58db4c7..a2de7fbd41383 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -26,7 +26,7 @@ import org.scalatest.FunSuite import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} object SVMSuite { @@ -58,7 +58,7 @@ object SVMSuite { } -class SVMSuite extends FunSuite with LocalSparkContext { +class SVMSuite extends FunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala 
b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index afa1f79b95a12..9ebef8466c831 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -22,10 +22,10 @@ import scala.util.Random import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.{Vector, Vectors} -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -class KMeansSuite extends FunSuite with LocalSparkContext { +class KMeansSuite extends FunSuite with MLlibTestSparkContext { import org.apache.spark.mllib.clustering.KMeans.{K_MEANS_PARALLEL, RANDOM} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala new file mode 100644 index 0000000000000..850c9fce507cd --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
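For context on the StreamingKMeans assertions below: each center is updated as a weighted average of the old center and the new batch, with the old weight discounted by the decay factor (a half-life of h corresponds to a decay of 0.5^(1/h), so old evidence is halved every h batches or points). A one-dimensional sketch of that rule follows; it is an assumption about the documented decay semantics, not the MLlib implementation, and it shows why decayFactor = 1.0 reproduces the grand average asserted in the first test.

// 1-D sketch (not the MLlib code): with decayFactor = 1.0 nothing is forgotten,
// so the center equals the running mean of every point seen so far.
object StreamingKMeansUpdateSketch {
  def update(center: Double, weight: Double,
             batchMean: Double, batchCount: Double,
             decayFactor: Double): (Double, Double) = {
    val discounted = weight * decayFactor            // old evidence, possibly decayed
    val newWeight = discounted + batchCount
    val newCenter = (center * discounted + batchMean * batchCount) / newWeight
    (newCenter, newWeight)
  }

  def main(args: Array[String]): Unit = {
    val (c1, w1) = update(0.0, 0.0, 2.0, 3.0, 1.0)   // batch of 3 points with mean 2.0
    val (c2, _)  = update(c1, w1, 6.0, 1.0, 1.0)     // batch of 1 point at 6.0
    println(c2)                                       // 3.0, the mean of all 4 points
  }
}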
+ */ + +package org.apache.spark.mllib.clustering + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.random.XORShiftRandom + +class StreamingKMeansSuite extends FunSuite with TestSuiteBase { + + override def maxWaitTimeMillis = 30000 + + test("accuracy for single center and equivalence to grand average") { + // set parameters + val numBatches = 10 + val numPoints = 50 + val k = 1 + val d = 5 + val r = 0.1 + + // create model with one cluster + val model = new StreamingKMeans() + .setK(1) + .setDecayFactor(1.0) + .setInitialCenters(Array(Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0)), Array(0.0)) + + // generate random data for k-means + val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42) + + // setup and run the model training + val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + model.trainOn(inputDStream) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) + + // estimated center should be close to true center + assert(centers(0) ~== model.latestModel().clusterCenters(0) absTol 1E-1) + + // estimated center from streaming should exactly match the arithmetic mean of all data points + // because the decay factor is set to 1.0 + val grandMean = + input.flatten.map(x => x.toBreeze).reduce(_+_) / (numBatches * numPoints).toDouble + assert(model.latestModel().clusterCenters(0) ~== Vectors.dense(grandMean.toArray) absTol 1E-5) + } + + test("accuracy for two centers") { + val numBatches = 10 + val numPoints = 5 + val k = 2 + val d = 5 + val r = 0.1 + + // create model with two clusters + val kMeans = new StreamingKMeans() + .setK(2) + .setHalfLife(2, "batches") + .setInitialCenters( + Array(Vectors.dense(-0.1, 0.1, -0.2, -0.3, -0.1), + Vectors.dense(0.1, -0.2, 0.0, 0.2, 0.1)), + Array(5.0, 5.0)) + + // generate random data for k-means + val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42) + + // setup and run the model training + val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + kMeans.trainOn(inputDStream) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) + + // check that estimated centers are close to true centers + // NOTE exact assignment depends on the initialization! + assert(centers(0) ~== kMeans.latestModel().clusterCenters(0) absTol 1E-1) + assert(centers(1) ~== kMeans.latestModel().clusterCenters(1) absTol 1E-1) + } + + test("detecting dying clusters") { + val numBatches = 10 + val numPoints = 5 + val k = 1 + val d = 1 + val r = 1.0 + + // create model with two clusters + val kMeans = new StreamingKMeans() + .setK(2) + .setHalfLife(0.5, "points") + .setInitialCenters( + Array(Vectors.dense(0.0), Vectors.dense(1000.0)), + Array(1.0, 1.0)) + + // new data are all around the first cluster 0.0 + val (input, _) = + StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42, Array(Vectors.dense(0.0))) + + // setup and run the model training + val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => { + kMeans.trainOn(inputDStream) + inputDStream.count() + }) + runStreams(ssc, numBatches, numBatches) + + // check that estimated centers are close to true centers + // NOTE exact assignment depends on the initialization! 
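The 0.8 constant asserted just below is the mean of a standard half-normal distribution: the input here is N(0, 1) data (center 0.0, r = 1.0) split by sign between the two surviving centers, so each center settles near sqrt(2/pi), consistent with the test's own comment. A quick illustrative check:

// Prints the mean of the standard half-normal distribution, E[|X|] for X ~ N(0, 1).
object HalfNormalMeanCheck {
  def main(args: Array[String]): Unit =
    println(math.sqrt(2.0 / math.Pi))   // 0.7978845608028654
}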
+ val model = kMeans.latestModel() + val c0 = model.clusterCenters(0)(0) + val c1 = model.clusterCenters(1)(0) + + assert(c0 * c1 < 0.0, "should have one positive center and one negative center") + // 0.8 is the mean of half-normal distribution + assert(math.abs(c0) ~== 0.8 absTol 0.6) + assert(math.abs(c1) ~== 0.8 absTol 0.6) + } + + def StreamingKMeansDataGenerator( + numPoints: Int, + numBatches: Int, + k: Int, + d: Int, + r: Double, + seed: Int, + initCenters: Array[Vector] = null): (IndexedSeq[IndexedSeq[Vector]], Array[Vector]) = { + val rand = new XORShiftRandom(seed) + val centers = initCenters match { + case null => Array.fill(k)(Vectors.dense(Array.fill(d)(rand.nextGaussian()))) + case _ => initCenters + } + val data = (0 until numBatches).map { i => + (0 until numPoints).map { idx => + val center = centers(idx % k) + Vectors.dense(Array.tabulate(d)(x => center(x) + rand.nextGaussian() * r)) + } + } + (data, centers) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala index 994e0feb8629e..79847633ff0dc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class AreaUnderCurveSuite extends FunSuite with LocalSparkContext { +class AreaUnderCurveSuite extends FunSuite with MLlibTestSparkContext { test("auc computation") { val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0)) val auc = 4.0 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala index a733f88b60b80..8a18e2971cab6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala @@ -19,44 +19,109 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { +class BinaryClassificationMetricsSuite extends FunSuite with MLlibTestSparkContext { - def cond1(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5 + private def areWithinEpsilon(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5 - def cond2(x: ((Double, Double), (Double, Double))): Boolean = + private def pairsWithinEpsilon(x: ((Double, Double), (Double, Double))): Boolean = (x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5) + private def assertSequencesMatch(left: Seq[Double], right: Seq[Double]): Unit = { + assert(left.zip(right).forall(areWithinEpsilon)) + } + + private def assertTupleSequencesMatch(left: Seq[(Double, Double)], + right: Seq[(Double, Double)]): Unit = { + assert(left.zip(right).forall(pairsWithinEpsilon)) + } + + private def validateMetrics(metrics: BinaryClassificationMetrics, + expectedThresholds: Seq[Double], + expectedROCCurve: Seq[(Double, Double)], + 
expectedPRCurve: Seq[(Double, Double)], + expectedFMeasures1: Seq[Double], + expectedFmeasures2: Seq[Double], + expectedPrecisions: Seq[Double], + expectedRecalls: Seq[Double]) = { + + assertSequencesMatch(metrics.thresholds().collect(), expectedThresholds) + assertTupleSequencesMatch(metrics.roc().collect(), expectedROCCurve) + assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(expectedROCCurve) absTol 1E-5) + assertTupleSequencesMatch(metrics.pr().collect(), expectedPRCurve) + assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(expectedPRCurve) absTol 1E-5) + assertTupleSequencesMatch(metrics.fMeasureByThreshold().collect(), + expectedThresholds.zip(expectedFMeasures1)) + assertTupleSequencesMatch(metrics.fMeasureByThreshold(2.0).collect(), + expectedThresholds.zip(expectedFmeasures2)) + assertTupleSequencesMatch(metrics.precisionByThreshold().collect(), + expectedThresholds.zip(expectedPrecisions)) + assertTupleSequencesMatch(metrics.recallByThreshold().collect(), + expectedThresholds.zip(expectedRecalls)) + } + test("binary evaluation metrics") { val scoreAndLabels = sc.parallelize( Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2) val metrics = new BinaryClassificationMetrics(scoreAndLabels) - val threshold = Seq(0.8, 0.6, 0.4, 0.1) + val thresholds = Seq(0.8, 0.6, 0.4, 0.1) val numTruePositives = Seq(1, 3, 3, 4) val numFalsePositives = Seq(0, 1, 2, 3) val numPositives = 4 val numNegatives = 3 - val precision = numTruePositives.zip(numFalsePositives).map { case (t, f) => + val precisions = numTruePositives.zip(numFalsePositives).map { case (t, f) => t.toDouble / (t + f) } - val recall = numTruePositives.map(t => t.toDouble / numPositives) + val recalls = numTruePositives.map(t => t.toDouble / numPositives) val fpr = numFalsePositives.map(f => f.toDouble / numNegatives) - val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0)) - val pr = recall.zip(precision) + val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0)) + val pr = recalls.zip(precisions) val prCurve = Seq((0.0, 1.0)) ++ pr val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)} val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)} - assert(metrics.thresholds().collect().zip(threshold).forall(cond1)) - assert(metrics.roc().collect().zip(rocCurve).forall(cond2)) - assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5) - assert(metrics.pr().collect().zip(prCurve).forall(cond2)) - assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5) - assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2)) - assert(metrics.fMeasureByThreshold(2.0).collect().zip(threshold.zip(f2)).forall(cond2)) - assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2)) - assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2)) + validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls) + } + + test("binary evaluation metrics for RDD where all examples have positive label") { + val scoreAndLabels = sc.parallelize(Seq((0.5, 1.0), (0.5, 1.0)), 2) + val metrics = new BinaryClassificationMetrics(scoreAndLabels) + + val thresholds = Seq(0.5) + val precisions = Seq(1.0) + val recalls = Seq(1.0) + val fpr = Seq(0.0) + val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0)) + val pr = recalls.zip(precisions) + val prCurve = Seq((0.0, 1.0)) ++ pr + val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)} + val f2 = 
pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)} + + validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls) + } + + test("binary evaluation metrics for RDD where all examples have negative label") { + val scoreAndLabels = sc.parallelize(Seq((0.5, 0.0), (0.5, 0.0)), 2) + val metrics = new BinaryClassificationMetrics(scoreAndLabels) + + val thresholds = Seq(0.5) + val precisions = Seq(0.0) + val recalls = Seq(0.0) + val fpr = Seq(1.0) + val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recalls) ++ Seq((1.0, 1.0)) + val pr = recalls.zip(precisions) + val prCurve = Seq((0.0, 1.0)) ++ pr + val f1 = pr.map { + case (0, 0) => 0.0 + case (r, p) => 2.0 * (p * r) / (p + r) + } + val f2 = pr.map { + case (0, 0) => 0.0 + case (r, p) => 5.0 * (p * r) / (4.0 * p + r) + } + + validateMetrics(metrics, thresholds, rocCurve, prCurve, f1, f2, precisions, recalls) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index 1ea503971c864..7dc4f3cfbc4e4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.mllib.evaluation import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Matrices -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext -class MulticlassMetricsSuite extends FunSuite with LocalSparkContext { +class MulticlassMetricsSuite extends FunSuite with MLlibTestSparkContext { test("Multiclass evaluation metrics") { /* * Confusion matrix for 3-class classification with total 9 instances: diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala new file mode 100644 index 0000000000000..2537dd62c92f2 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
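The expected values in the binary-metrics tests above, and the per-class values in the multilabel test that follows, come from the same counting identities. Below is a small illustrative sketch of that arithmetic (names made up), including the general F-beta form behind the 5pr/(4p + r) expression used for fMeasureByThreshold(2.0); the example counts are taken from the threshold-0.6 row of the first test.

// Sketch of the precision/recall/F-beta arithmetic used to derive expected values.
// F_beta = (1 + beta^2) * p * r / (beta^2 * p + r); beta = 2 gives 5pr / (4p + r).
object ThresholdMetricsSketch {
  def precision(tp: Int, fp: Int): Double = tp.toDouble / (tp + fp)
  def recall(tp: Int, fn: Int): Double = tp.toDouble / (tp + fn)
  def fMeasure(p: Double, r: Double, beta: Double = 1.0): Double = {
    val b2 = beta * beta
    if (p == 0.0 && r == 0.0) 0.0 else (1 + b2) * p * r / (b2 * p + r)
  }

  def main(args: Array[String]): Unit = {
    // At threshold 0.6: 3 true positives, 1 false positive, 1 false negative.
    val p = precision(3, 1)   // 0.75
    val r = recall(3, 1)      // 0.75
    println((p, r, fMeasure(p, r), fMeasure(p, r, 2.0)))   // all 0.75 here
  }
}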
+ */ + +package org.apache.spark.mllib.evaluation + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD + +class MultilabelMetricsSuite extends FunSuite with MLlibTestSparkContext { + test("Multilabel evaluation metrics") { + /* + * Documents true labels (5x class0, 3x class1, 4x class2): + * doc 0 - predict 0, 1 - class 0, 2 + * doc 1 - predict 0, 2 - class 0, 1 + * doc 2 - predict none - class 0 + * doc 3 - predict 2 - class 2 + * doc 4 - predict 2, 0 - class 2, 0 + * doc 5 - predict 0, 1, 2 - class 0, 1 + * doc 6 - predict 1 - class 1, 2 + * + * predicted classes + * class 0 - doc 0, 1, 4, 5 (total 4) + * class 1 - doc 0, 5, 6 (total 3) + * class 2 - doc 1, 3, 4, 5 (total 4) + * + * true classes + * class 0 - doc 0, 1, 2, 4, 5 (total 5) + * class 1 - doc 1, 5, 6 (total 3) + * class 2 - doc 0, 3, 4, 6 (total 4) + * + */ + val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize( + Seq((Array(0.0, 1.0), Array(0.0, 2.0)), + (Array(0.0, 2.0), Array(0.0, 1.0)), + (Array(), Array(0.0)), + (Array(2.0), Array(2.0)), + (Array(2.0, 0.0), Array(2.0, 0.0)), + (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)), + (Array(1.0), Array(1.0, 2.0))), 2) + val metrics = new MultilabelMetrics(scoreAndLabels) + val delta = 0.00001 + val precision0 = 4.0 / (4 + 0) + val precision1 = 2.0 / (2 + 1) + val precision2 = 2.0 / (2 + 2) + val recall0 = 4.0 / (4 + 1) + val recall1 = 2.0 / (2 + 1) + val recall2 = 2.0 / (2 + 2) + val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0) + val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1) + val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2) + val sumTp = 4 + 2 + 2 + assert(sumTp == (1 + 1 + 0 + 1 + 2 + 2 + 1)) + val microPrecisionClass = sumTp.toDouble / (4 + 0 + 2 + 1 + 2 + 2) + val microRecallClass = sumTp.toDouble / (4 + 1 + 2 + 1 + 2 + 2) + val microF1MeasureClass = 2.0 * sumTp.toDouble / + (2 * sumTp.toDouble + (1 + 1 + 2) + (0 + 1 + 2)) + val macroPrecisionDoc = 1.0 / 7 * + (1.0 / 2 + 1.0 / 2 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 1.0) + val macroRecallDoc = 1.0 / 7 * + (1.0 / 2 + 1.0 / 2 + 0 / 1 + 1.0 / 1 + 2.0 / 2 + 2.0 / 2 + 1.0 / 2) + val macroF1MeasureDoc = (1.0 / 7) * + 2 * ( 1.0 / (2 + 2) + 1.0 / (2 + 2) + 0 + 1.0 / (1 + 1) + + 2.0 / (2 + 2) + 2.0 / (3 + 2) + 1.0 / (1 + 2) ) + val hammingLoss = (1.0 / (7 * 3)) * (2 + 2 + 1 + 0 + 0 + 1 + 1) + val strictAccuracy = 2.0 / 7 + val accuracy = 1.0 / 7 * (1.0 / 3 + 1.0 /3 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 2) + assert(math.abs(metrics.precision(0.0) - precision0) < delta) + assert(math.abs(metrics.precision(1.0) - precision1) < delta) + assert(math.abs(metrics.precision(2.0) - precision2) < delta) + assert(math.abs(metrics.recall(0.0) - recall0) < delta) + assert(math.abs(metrics.recall(1.0) - recall1) < delta) + assert(math.abs(metrics.recall(2.0) - recall2) < delta) + assert(math.abs(metrics.f1Measure(0.0) - f1measure0) < delta) + assert(math.abs(metrics.f1Measure(1.0) - f1measure1) < delta) + assert(math.abs(metrics.f1Measure(2.0) - f1measure2) < delta) + assert(math.abs(metrics.microPrecision - microPrecisionClass) < delta) + assert(math.abs(metrics.microRecall - microRecallClass) < delta) + assert(math.abs(metrics.microF1Measure - microF1MeasureClass) < delta) + assert(math.abs(metrics.precision - macroPrecisionDoc) < delta) + assert(math.abs(metrics.recall - macroRecallDoc) < delta) + assert(math.abs(metrics.f1Measure - macroF1MeasureDoc) < delta) + 
assert(math.abs(metrics.hammingLoss - hammingLoss) < delta) + assert(math.abs(metrics.subsetAccuracy - strictAccuracy) < delta) + assert(math.abs(metrics.accuracy - accuracy) < delta) + assert(metrics.labels.sameElements(Array(0.0, 1.0, 2.0))) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala new file mode 100644 index 0000000000000..609eed983ff4e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class RankingMetricsSuite extends FunSuite with MLlibTestSparkContext { + test("Ranking metrics: map, ndcg") { + val predictionAndLabels = sc.parallelize( + Seq( + (Array[Int](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Int](1, 2, 3, 4, 5)), + (Array[Int](4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array[Int](1, 2, 3)), + (Array[Int](1, 2, 3, 4, 5), Array[Int]()) + ), 2) + val eps: Double = 1E-5 + + val metrics = new RankingMetrics(predictionAndLabels) + val map = metrics.meanAveragePrecision + + assert(metrics.precisionAt(1) ~== 1.0/3 absTol eps) + assert(metrics.precisionAt(2) ~== 1.0/3 absTol eps) + assert(metrics.precisionAt(3) ~== 1.0/3 absTol eps) + assert(metrics.precisionAt(4) ~== 0.75/3 absTol eps) + assert(metrics.precisionAt(5) ~== 0.8/3 absTol eps) + assert(metrics.precisionAt(10) ~== 0.8/3 absTol eps) + assert(metrics.precisionAt(15) ~== 8.0/45 absTol eps) + + assert(map ~== 0.355026 absTol eps) + + assert(metrics.ndcgAt(3) ~== 1.0/3 absTol eps) + assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps) + assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps) + assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps) + + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala new file mode 100644 index 0000000000000..670b4c34e6095 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class RegressionMetricsSuite extends FunSuite with MLlibTestSparkContext { + + test("regression metrics") { + val predictionAndObservations = sc.parallelize( + Seq((2.5,3.0),(0.0,-0.5),(2.0,2.0),(8.0,7.0)), 2) + val metrics = new RegressionMetrics(predictionAndObservations) + assert(metrics.explainedVariance ~== 0.95717 absTol 1E-5, + "explained variance regression score mismatch") + assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.375 absTol 1E-5, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5, + "root mean squared error mismatch") + assert(metrics.r2 ~== 0.94861 absTol 1E-5, "r2 score mismatch") + } + + test("regression metrics with complete fitting") { + val predictionAndObservations = sc.parallelize( + Seq((3.0,3.0),(0.0,0.0),(2.0,2.0),(8.0,8.0)), 2) + val metrics = new RegressionMetrics(predictionAndObservations) + assert(metrics.explainedVariance ~== 1.0 absTol 1E-5, + "explained variance regression score mismatch") + assert(metrics.meanAbsoluteError ~== 0.0 absTol 1E-5, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.0 absTol 1E-5, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.0 absTol 1E-5, + "root mean squared error mismatch") + assert(metrics.r2 ~== 1.0 absTol 1E-5, "r2 score mismatch") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala index a599e0d938569..0c4dfb7b97c7f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.mllib.feature import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext -class HashingTFSuite extends FunSuite with LocalSparkContext { +class HashingTFSuite extends FunSuite with MLlibTestSparkContext { test("hashing tf on a single doc") { val hashingTF = new HashingTF(1000) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala index 43974f84e3ca8..30147e7fd948f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala @@ -21,10 +21,10 @@ import org.scalatest.FunSuite import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class IDFSuite extends FunSuite with LocalSparkContext { +class 
IDFSuite extends FunSuite with MLlibTestSparkContext { test("idf") { val n = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala index fb76dccfdf79e..85fdd271b5ed1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/NormalizerSuite.scala @@ -19,11 +19,13 @@ package org.apache.spark.mllib.feature import org.scalatest.FunSuite +import breeze.linalg.{norm => brzNorm} + import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors} -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class NormalizerSuite extends FunSuite with LocalSparkContext { +class NormalizerSuite extends FunSuite with MLlibTestSparkContext { val data = Array( Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), @@ -50,10 +52,10 @@ class NormalizerSuite extends FunSuite with LocalSparkContext { assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) - assert(data1(0).toBreeze.norm(1) ~== 1.0 absTol 1E-5) - assert(data1(2).toBreeze.norm(1) ~== 1.0 absTol 1E-5) - assert(data1(3).toBreeze.norm(1) ~== 1.0 absTol 1E-5) - assert(data1(4).toBreeze.norm(1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(0).toBreeze, 1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(2).toBreeze, 1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(3).toBreeze, 1) ~== 1.0 absTol 1E-5) + assert(brzNorm(data1(4).toBreeze, 1) ~== 1.0 absTol 1E-5) assert(data1(0) ~== Vectors.sparse(3, Seq((0, -0.465116279), (1, 0.53488372))) absTol 1E-5) assert(data1(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) @@ -77,10 +79,10 @@ class NormalizerSuite extends FunSuite with LocalSparkContext { assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) - assert(data2(0).toBreeze.norm(2) ~== 1.0 absTol 1E-5) - assert(data2(2).toBreeze.norm(2) ~== 1.0 absTol 1E-5) - assert(data2(3).toBreeze.norm(2) ~== 1.0 absTol 1E-5) - assert(data2(4).toBreeze.norm(2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(0).toBreeze, 2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(2).toBreeze, 2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(3).toBreeze, 2) ~== 1.0 absTol 1E-5) + assert(brzNorm(data2(4).toBreeze, 2) ~== 1.0 absTol 1E-5) assert(data2(0) ~== Vectors.sparse(3, Seq((0, -0.65617871), (1, 0.75460552))) absTol 1E-5) assert(data2(1) ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala index e217b93cebbdb..4c93c0ca4f86c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.mllib.feature import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} import org.apache.spark.rdd.RDD -class StandardScalerSuite extends FunSuite with LocalSparkContext { +class 
StandardScalerSuite extends FunSuite with MLlibTestSparkContext { private def computeSummary(data: RDD[Vector]): MultivariateStatisticalSummary = { data.treeAggregate(new MultivariateOnlineSummarizer)( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index e34335d89eb75..52278690dbd89 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.mllib.feature import org.scalatest.FunSuite -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext -class Word2VecSuite extends FunSuite with LocalSparkContext { +class Word2VecSuite extends FunSuite with MLlibTestSparkContext { // TODO: add more tests diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index 5f8b8c4b72697..322a0e9242918 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -17,7 +17,11 @@ package org.apache.spark.mllib.linalg +import java.util.Random + +import org.mockito.Mockito.when import org.scalatest.FunSuite +import org.scalatest.mock.MockitoSugar._ class MatricesSuite extends FunSuite { test("dense matrix construction") { @@ -112,4 +116,50 @@ class MatricesSuite extends FunSuite { assert(sparseMat(0, 1) === 10.0) assert(sparseMat.values(2) === 10.0) } + + test("zeros") { + val mat = Matrices.zeros(2, 3).asInstanceOf[DenseMatrix] + assert(mat.numRows === 2) + assert(mat.numCols === 3) + assert(mat.values.forall(_ == 0.0)) + } + + test("ones") { + val mat = Matrices.ones(2, 3).asInstanceOf[DenseMatrix] + assert(mat.numRows === 2) + assert(mat.numCols === 3) + assert(mat.values.forall(_ == 1.0)) + } + + test("eye") { + val mat = Matrices.eye(2).asInstanceOf[DenseMatrix] + assert(mat.numCols === 2) + assert(mat.numCols === 2) + assert(mat.values.toSeq === Seq(1.0, 0.0, 0.0, 1.0)) + } + + test("rand") { + val rng = mock[Random] + when(rng.nextDouble()).thenReturn(1.0, 2.0, 3.0, 4.0) + val mat = Matrices.rand(2, 2, rng).asInstanceOf[DenseMatrix] + assert(mat.numRows === 2) + assert(mat.numCols === 2) + assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0)) + } + + test("randn") { + val rng = mock[Random] + when(rng.nextGaussian()).thenReturn(1.0, 2.0, 3.0, 4.0) + val mat = Matrices.randn(2, 2, rng).asInstanceOf[DenseMatrix] + assert(mat.numRows === 2) + assert(mat.numCols === 2) + assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0)) + } + + test("diag") { + val mat = Matrices.diag(Vectors.dense(1.0, 2.0)).asInstanceOf[DenseMatrix] + assert(mat.numRows === 2) + assert(mat.numCols === 2) + assert(mat.values.toSeq === Seq(1.0, 0.0, 0.0, 2.0)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index cd651fe2d2ddf..9492f604af4d5 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.linalg +import breeze.linalg.{DenseMatrix => BDM} import org.scalatest.FunSuite import org.apache.spark.SparkException @@ -155,4 +156,45 @@ class VectorsSuite extends FunSuite { throw 
new RuntimeException(s"copy returned ${dvCopy.getClass} on ${dv.getClass}.") } } + + test("VectorUDT") { + val dv0 = Vectors.dense(Array.empty[Double]) + val dv1 = Vectors.dense(1.0, 2.0) + val sv0 = Vectors.sparse(2, Array.empty, Array.empty) + val sv1 = Vectors.sparse(2, Array(1), Array(2.0)) + val udt = new VectorUDT() + for (v <- Seq(dv0, dv1, sv0, sv1)) { + assert(v === udt.deserialize(udt.serialize(v))) + } + } + + test("fromBreeze") { + val x = BDM.zeros[Double](10, 10) + val v = Vectors.fromBreeze(x(::, 0)) + assert(v.size === x.rows) + } + + test("foreachActive") { + val dv = Vectors.dense(0.0, 1.2, 3.1, 0.0) + val sv = Vectors.sparse(4, Seq((1, 1.2), (2, 3.1), (3, 0.0))) + + val dvMap = scala.collection.mutable.Map[Int, Double]() + dv.foreachActive { (index, value) => + dvMap.put(index, value) + } + assert(dvMap.size === 4) + assert(dvMap.get(0) === Some(0.0)) + assert(dvMap.get(1) === Some(1.2)) + assert(dvMap.get(2) === Some(3.1)) + assert(dvMap.get(3) === Some(0.0)) + + val svMap = scala.collection.mutable.Map[Int, Double]() + sv.foreachActive { (index, value) => + svMap.put(index, value) + } + assert(svMap.size === 3) + assert(svMap.get(1) === Some(1.2)) + assert(svMap.get(2) === Some(3.1)) + assert(svMap.get(3) === Some(0.0)) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala index cd45438fb628f..f8709751efce6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala @@ -21,10 +21,10 @@ import org.scalatest.FunSuite import breeze.linalg.{DenseMatrix => BDM} -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.linalg.Vectors -class CoordinateMatrixSuite extends FunSuite with LocalSparkContext { +class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext { val m = 5 val n = 4 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index f7c46f23b746d..e25bc02b06c9a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -21,11 +21,11 @@ import org.scalatest.FunSuite import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV} -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Matrices, Vectors} -class IndexedRowMatrixSuite extends FunSuite with LocalSparkContext { +class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext { val m = 4 val n = 3 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 63f3ed58c0d4d..dbf55ff81ca99 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -23,9 +23,9 @@ import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, s import 
org.scalatest.FunSuite import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector} -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} -class RowMatrixSuite extends FunSuite with LocalSparkContext { +class RowMatrixSuite extends FunSuite with MLlibTestSparkContext { val m = 4 val n = 3 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index bf040110e228b..86481c6e66200 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.{FunSuite, Matchers} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ object GradientDescentSuite { @@ -61,7 +61,7 @@ object GradientDescentSuite { } } -class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers { +class GradientDescentSuite extends FunSuite with MLlibTestSparkContext with Matchers { test("Assert the loss is decreasing.") { val nPoints = 10000 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index ccba004baa007..70c64775e4c04 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -23,10 +23,10 @@ import org.scalatest.{FunSuite, Matchers} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ -class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers { +class LBFGSSuite extends FunSuite with MLlibTestSparkContext with Matchers { val nPoints = 10000 val A = 2.0 diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala index b781a6aed9a8c..82c327bd49fcd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala @@ -37,6 +37,12 @@ class NNLSSuite extends FunSuite { (ata, atb) } + /** Compute the objective value */ + def computeObjectiveValue(ata: DoubleMatrix, atb: DoubleMatrix, x: DoubleMatrix): Double = { + val res = (x.transpose().mmul(ata).mmul(x)).mul(0.5).sub(atb.dot(x)) + res.get(0) + } + test("NNLS: exact solution cases") { val n = 20 val rand = new Random(12346) @@ -79,4 +85,28 @@ class NNLSSuite extends FunSuite { assert(x(i) >= 0) } } + + test("NNLS: objective value test") { + val n = 5 + val ata = new DoubleMatrix(5, 5 + , 517399.13534, 242529.67289, -153644.98976, 130802.84503, -798452.29283 + , 242529.67289, 126017.69765, -75944.21743, 81785.36128, -405290.60884 + , -153644.98976, -75944.21743, 46986.44577, -45401.12659, 247059.51049 + , 130802.84503, 
81785.36128, -45401.12659, 67457.31310, -253747.03819 + , -798452.29283, -405290.60884, 247059.51049, -253747.03819, 1310939.40814 + ) + val atb = new DoubleMatrix(5, 1, + -31755.05710, 13047.14813, -20191.24443, 25993.77580, 11963.55017) + + /** reference solution obtained from matlab function quadprog */ + val refx = new DoubleMatrix(Array(34.90751, 103.96254, 0.00000, 27.82094, 58.79627)) + val refObj = computeObjectiveValue(ata, atb, refx) + + + val ws = NNLS.createWorkspace(n) + val x = new DoubleMatrix(NNLS.solve(ata, atb, ws)) + val obj = computeObjectiveValue(ata, atb, x) + + assert(obj < refObj + 1E-5) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala index c50b78bcbcc61..ea5889b3ecd5e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.FunSuite import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.rdd.{RandomRDDPartition, RandomRDD} -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.util.StatCounter @@ -34,7 +34,7 @@ import org.apache.spark.util.StatCounter * * TODO update tests to use TestingUtils for floating point comparison after PR 1367 is merged */ -class RandomRDDsSuite extends FunSuite with LocalSparkContext with Serializable { +class RandomRDDsSuite extends FunSuite with MLlibTestSparkContext with Serializable { def testGeneratedRDD(rdd: RDD[Double], expectedSize: Long, diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index 27a19f793242b..681ce9263933b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.rdd import org.scalatest.FunSuite -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.RDDFunctions._ -class RDDFunctionsSuite extends FunSuite with LocalSparkContext { +class RDDFunctionsSuite extends FunSuite with MLlibTestSparkContext { test("sliding") { val data = 0 until 6 @@ -42,9 +42,9 @@ class RDDFunctionsSuite extends FunSuite with LocalSparkContext { val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7)) val rdd = sc.parallelize(data, data.length).flatMap(s => s) assert(rdd.partitions.size === data.length) - val sliding = rdd.sliding(3) - val expected = data.flatMap(x => x).sliding(3).toList - assert(sliding.collect().toList === expected) + val sliding = rdd.sliding(3).collect().toSeq.map(_.toSeq) + val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq) + assert(sliding === expected) } test("treeAggregate") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala index 017c39edb185f..603d0ad127b86 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.FunSuite import 
org.jblas.DoubleMatrix import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.recommendation.ALS.BlockStats object ALSSuite { @@ -85,7 +85,7 @@ object ALSSuite { } -class ALSSuite extends FunSuite with LocalSparkContext { +class ALSSuite extends FunSuite with MLlibTestSparkContext { test("rank-1 matrices") { testALS(50, 100, 1, 15, 0.7, 0.3) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala new file mode 100644 index 0000000000000..b9caecc904a23 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.recommendation + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.rdd.RDD + +class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext { + + val rank = 2 + var userFeatures: RDD[(Int, Array[Double])] = _ + var prodFeatures: RDD[(Int, Array[Double])] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + userFeatures = sc.parallelize(Seq((0, Array(1.0, 2.0)), (1, Array(3.0, 4.0)))) + prodFeatures = sc.parallelize(Seq((2, Array(5.0, 6.0)))) + } + + test("constructor") { + val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures) + assert(model.predict(0, 2) ~== 17.0 relTol 1e-14) + + intercept[IllegalArgumentException] { + new MatrixFactorizationModel(1, userFeatures, prodFeatures) + } + + val userFeatures1 = sc.parallelize(Seq((0, Array(1.0)), (1, Array(3.0)))) + intercept[IllegalArgumentException] { + new MatrixFactorizationModel(rank, userFeatures1, prodFeatures) + } + + val prodFeatures1 = sc.parallelize(Seq((2, Array(5.0)))) + intercept[IllegalArgumentException] { + new MatrixFactorizationModel(rank, userFeatures, prodFeatures1) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index 7aa96421aed87..2668dcc14a842 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -23,9 +23,9 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, - LocalSparkContext} + MLlibTestSparkContext} -class LassoSuite extends FunSuite with 
LocalSparkContext { +class LassoSuite extends FunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index 4f89112b650c5..864622a9296a6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -23,9 +23,9 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, - LocalSparkContext} + MLlibTestSparkContext} -class LinearRegressionSuite extends FunSuite with LocalSparkContext { +class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext { def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index 727bbd051ff15..18d3bf5ea4eca 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -24,9 +24,9 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, - LocalSparkContext} + MLlibTestSparkContext} -class RidgeRegressionSuite extends FunSuite with LocalSparkContext { +class RidgeRegressionSuite extends FunSuite with MLlibTestSparkContext { def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = { predictions.zip(input).map { case (prediction, expected) => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala index 34548c86ebc14..d20a09b4b4925 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/CorrelationSuite.scala @@ -24,9 +24,9 @@ import breeze.linalg.{DenseMatrix => BDM, Matrix => BM} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.stat.correlation.{Correlations, PearsonCorrelation, SpearmanCorrelation} -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext -class CorrelationSuite extends FunSuite with LocalSparkContext { +class CorrelationSuite extends FunSuite with MLlibTestSparkContext { // test input data val xData = Array(1.0, 0.0, -2.0) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala index 6de3840b3f198..15418e6035965 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala @@ -25,10 +25,10 @@ import org.apache.spark.SparkException import org.apache.spark.mllib.linalg.{DenseVector, Matrices, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import 
org.apache.spark.mllib.stat.test.ChiSqTest -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class HypothesisTestSuite extends FunSuite with LocalSparkContext { +class HypothesisTestSuite extends FunSuite with MLlibTestSparkContext { test("chi squared pearson goodness of fit") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala index 1e9415249104b..23b0eec865de6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizerSuite.scala @@ -208,4 +208,15 @@ class MultivariateOnlineSummarizerSuite extends FunSuite { assert(summarizer2.variance ~== Vectors.dense(0, 0, 0) absTol 1E-5, "variance mismatch") } + + test("merging summarizer when one side has zero mean (SPARK-4355)") { + val s0 = new MultivariateOnlineSummarizer() + .add(Vectors.dense(2.0)) + .add(Vectors.dense(2.0)) + val s1 = new MultivariateOnlineSummarizer() + .add(Vectors.dense(1.0)) + .add(Vectors.dense(-1.0)) + s0.merge(s1) + assert(s0.mean(0) ~== 1.0 absTol 1e-14) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index a48ed71a1c5fc..972c905ec9ffa 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -26,13 +26,13 @@ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ -import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy} import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node} -import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.util.MLlibTestSparkContext -class DecisionTreeSuite extends FunSuite with LocalSparkContext { +class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext { test("Binary classification with continuous features: split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() @@ -102,6 +102,72 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq) } + test("find splits for a continuous feature") { + // find splits for normal case + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(6), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array.fill(200000)(math.random) + val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits.length === 5) + assert(fakeMetadata.numSplits(0) === 5) + assert(fakeMetadata.numBins(0) === 6) + // check returned splits are distinct + assert(splits.distinct.length === splits.length) + } + + // find splits should not return identical splits + // when there are not enough split candidates, reduce the number of 
splits in metadata + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(5), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble) + val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits.length === 3) + assert(fakeMetadata.numSplits(0) === 3) + assert(fakeMetadata.numBins(0) === 4) + // check returned splits are distinct + assert(splits.distinct.length === splits.length) + } + + // find splits when most samples close to the minimum + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble) + val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits.length === 2) + assert(fakeMetadata.numSplits(0) === 2) + assert(fakeMetadata.numBins(0) === 3) + assert(splits(0) === 2.0) + assert(splits(1) === 3.0) + } + + // find splits when most samples close to the maximum + { + val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0, + Map(), Set(), + Array(3), Gini, QuantileStrategy.Sort, + 0, 0, 0.0, 0, 0 + ) + val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble) + val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0) + assert(splits.length === 1) + assert(fakeMetadata.numSplits(0) === 1) + assert(fakeMetadata.numBins(0) === 2) + assert(splits(0) === 1.0) + } + } + test("Multiclass classification with unordered categorical features:" + " split and bin calculations") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() @@ -253,7 +319,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = rootNode.stats.get assert(stats.gain > 0) - assert(rootNode.predict === 1) + assert(rootNode.predict.predict === 1) assert(stats.impurity > 0.2) } @@ -282,7 +348,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = rootNode.stats.get assert(stats.gain > 0) - assert(rootNode.predict === 0.6) + assert(rootNode.predict.predict === 0.6) assert(stats.impurity > 0.2) } @@ -352,7 +418,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.gain === 0) assert(stats.leftImpurity === 0) assert(stats.rightImpurity === 0) - assert(rootNode.predict === 1) + assert(rootNode.predict.predict === 1) } test("Binary classification stump with fixed label 0 for Entropy") { @@ -377,7 +443,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.gain === 0) assert(stats.leftImpurity === 0) assert(stats.rightImpurity === 0) - assert(rootNode.predict === 0) + assert(rootNode.predict.predict === 0) } test("Binary classification stump with fixed label 1 for Entropy") { @@ -402,7 +468,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.gain === 0) assert(stats.leftImpurity === 0) assert(stats.rightImpurity === 0) - assert(rootNode.predict === 1) + assert(rootNode.predict.predict === 1) } test("Second level node building with vs. 
without groups") { @@ -427,7 +493,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(rootNode1.rightNode.nonEmpty) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) // Single group second level tree construction. val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get))) @@ -471,7 +537,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats1.impurity === stats2.impurity) assert(stats1.leftImpurity === stats2.leftImpurity) assert(stats1.rightImpurity === stats2.rightImpurity) - assert(children1(i).predict === children2(i).predict) + assert(children1(i).predict.predict === children2(i).predict.predict) } } @@ -646,7 +712,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val model = DecisionTree.train(rdd, strategy) assert(model.topNode.isLeaf) - assert(model.topNode.predict == 0.0) + assert(model.topNode.predict.predict == 0.0) val predicts = rdd.map(p => model.predict(p.features)).collect() predicts.foreach { predict => assert(predict == 0.0) @@ -693,7 +759,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val model = DecisionTree.train(input, strategy) assert(model.topNode.isLeaf) - assert(model.topNode.predict == 0.0) + assert(model.topNode.predict.predict == 0.0) val predicts = input.map(p => model.predict(p.features)).collect() predicts.foreach { predict => assert(predict == 0.0) @@ -705,6 +771,92 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val gain = rootNode.stats.get assert(gain == InformationGainStats.invalidInformationGainStats) } + + test("Avoid aggregation on the last level") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)) + arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)) + arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1, + numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) + + val topNode = Node.emptyNode(nodeIndex = 1) + assert(topNode.predict.predict === Double.MinValue) + assert(topNode.impurity === -1.0) + assert(topNode.isLeaf === false) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + + // don't enqueue leaf nodes into node queue + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.predict.predict !== Double.MinValue) + assert(topNode.impurity !== -1.0) + + // set impurity and predict for child nodes + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) + assert(topNode.leftNode.get.impurity === 0.0) + 
assert(topNode.rightNode.get.impurity === 0.0) + } + + test("Avoid aggregation if impurity is 0.0") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)) + arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)) + arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 2, categoricalFeaturesInfo = Map(0 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) + val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) + + val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) + val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false) + + val topNode = Node.emptyNode(nodeIndex = 1) + assert(topNode.predict.predict === Double.MinValue) + assert(topNode.impurity === -1.0) + assert(topNode.isLeaf === false) + + val nodesForGroup = Map((0, Array(topNode))) + val treeToNodeToIndexInfo = Map((0, Map( + (topNode.id, new RandomForest.NodeIndexInfo(0, None)) + ))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) + + // don't enqueue a node into node queue if its impurity is 0.0 + assert(nodeQueue.isEmpty) + + // set impurity and predict for topNode + assert(topNode.predict.predict !== Double.MinValue) + assert(topNode.impurity !== -1.0) + + // set impurity and predict for child nodes + assert(topNode.leftNode.get.predict.predict === 0.0) + assert(topNode.rightNode.get.predict.predict === 1.0) + assert(topNode.leftNode.get.impurity === 0.0) + assert(topNode.rightNode.get.impurity === 0.0) + } } object DecisionTreeSuite { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala new file mode 100644 index 0000000000000..8972c229b7ecb --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.model.TreeEnsembleModel +import org.apache.spark.util.StatCounter + +import scala.collection.mutable + +object EnsembleTestHelper { + + /** + * Aggregates all values in data, and tests whether the empirical mean and stddev are within + * epsilon of the expected values. + * @param data Every element of the data should be an i.i.d. 
sample from some distribution. + */ + def testRandomArrays( + data: Array[Array[Double]], + numCols: Int, + expectedMean: Double, + expectedStddev: Double, + epsilon: Double) { + val values = new mutable.ArrayBuffer[Double]() + data.foreach { row => + assert(row.size == numCols) + values ++= row + } + val stats = new StatCounter(values) + assert(math.abs(stats.mean - expectedMean) < epsilon) + assert(math.abs(stats.stdev - expectedStddev) < epsilon) + } + + def validateClassifier( + model: TreeEnsembleModel, + input: Seq[LabeledPoint], + requiredAccuracy: Double) { + val predictions = input.map(x => model.predict(x.features)) + val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => + prediction != expected.label + } + val accuracy = (input.length - numOffPredictions).toDouble / input.length + assert(accuracy >= requiredAccuracy, + s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") + } + + /** + * Validates a tree ensemble model for regression. + */ + def validateRegressor( + model: TreeEnsembleModel, + input: Seq[LabeledPoint], + required: Double, + metricName: String = "mse") { + val predictions = input.map(x => model.predict(x.features)) + val errors = predictions.zip(input.map(_.label)).map { case (prediction, label) => + prediction - label + } + val metric = metricName match { + case "mse" => + errors.map(err => err * err).sum / errors.size + case "mae" => + errors.map(math.abs).sum / errors.size + } + + assert(metric <= required, + s"validateRegressor calculated $metricName $metric but required $required.") + } + + def generateOrderedLabeledPoints(numFeatures: Int, numInstances: Int): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](numInstances) + for (i <- 0 until numInstances) { + val label = if (i < numInstances / 10) { + 0.0 + } else if (i < numInstances / 2) { + 1.0 + } else if (i < numInstances * 0.9) { + 0.0 + } else { + 1.0 + } + val features = Array.fill[Double](numFeatures)(i.toDouble) + arr(i) = new LabeledPoint(label, Vectors.dense(features)) + } + arr + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala new file mode 100644 index 0000000000000..d4d54cf4c9e2a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
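An aside on the EnsembleTestHelper object added above, since the suites that follow lean on it: a minimal sketch of the intended call pattern in Scala (the trained TreeEnsembleModel is assumed to come from an earlier training step and is left commented out).

    import org.apache.spark.mllib.tree.EnsembleTestHelper

    // validateRegressor averages err * err for "mse" and |err| for "mae"; e.g. errors
    // (0.5, -1.0, 0.5) give mse = (0.25 + 1.0 + 0.25) / 3 = 0.5 and mae = 2.0 / 3 ≈ 0.67.
    val data = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, numInstances = 100)
    // EnsembleTestHelper.validateRegressor(model, data, 0.05, "mae")
    // EnsembleTestHelper.validateClassifier(model, data, 0.9)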
+ */ + +package org.apache.spark.mllib.tree + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} +import org.apache.spark.mllib.tree.impurity.Variance +import org.apache.spark.mllib.tree.loss.{AbsoluteError, SquaredError, LogLoss} + +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** + * Test suite for [[GradientBoostedTrees]]. + */ +class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { + + test("Regression with continuous features: SquaredError") { + GradientBoostedTreesSuite.testCombinations.foreach { + case (numIterations, learningRate, subsamplingRate) => + GradientBoostedTreesSuite.randomSeeds.foreach { randomSeed => + val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) + + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate) + val boostingStrategy = + new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate) + + val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) + + assert(gbt.trees.size === numIterations) + try { + EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06) + } catch { + case e: java.lang.AssertionError => + println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + + s" subsamplingRate=$subsamplingRate") + throw e + } + + val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + val dt = DecisionTree.train(remappedInput, treeStrategy) + + // Make sure trees are the same. + assert(gbt.trees.head.toString == dt.toString) + } + } + } + + test("Regression with continuous features: Absolute Error") { + GradientBoostedTreesSuite.testCombinations.foreach { + case (numIterations, learningRate, subsamplingRate) => + val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) + + val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate) + val boostingStrategy = + new BoostingStrategy(treeStrategy, AbsoluteError, numIterations, learningRate) + + val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) + + assert(gbt.trees.size === numIterations) + try { + EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.85, "mae") + } catch { + case e: java.lang.AssertionError => + println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + + s" subsamplingRate=$subsamplingRate") + throw e + } + + val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + val dt = DecisionTree.train(remappedInput, treeStrategy) + + // Make sure trees are the same. 
+ assert(gbt.trees.head.toString == dt.toString) + } + } + + test("Binary classification with continuous features: Log Loss") { + GradientBoostedTreesSuite.testCombinations.foreach { + case (numIterations, learningRate, subsamplingRate) => + val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2) + + val treeStrategy = new Strategy(algo = Classification, impurity = Variance, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = Map.empty, + subsamplingRate = subsamplingRate) + val boostingStrategy = + new BoostingStrategy(treeStrategy, LogLoss, numIterations, learningRate) + + val gbt = GradientBoostedTrees.train(rdd, boostingStrategy) + + assert(gbt.trees.size === numIterations) + try { + EnsembleTestHelper.validateClassifier(gbt, GradientBoostedTreesSuite.data, 0.9) + } catch { + case e: java.lang.AssertionError => + println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," + + s" subsamplingRate=$subsamplingRate") + throw e + } + + val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) + val ensembleStrategy = treeStrategy.copy + ensembleStrategy.algo = Regression + ensembleStrategy.impurity = Variance + val dt = DecisionTree.train(remappedInput, ensembleStrategy) + + // Make sure trees are the same. + assert(gbt.trees.head.toString == dt.toString) + } + } + +} + +object GradientBoostedTreesSuite { + + // Combinations for estimators, learning rates and subsamplingRate + val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75)) + + val randomSeeds = Array(681283, 4398) + + val data = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100) +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala index 20d372dc1d3ca..90a8c2dfdab80 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -25,52 +25,20 @@ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata} +import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata import org.apache.spark.mllib.tree.impurity.{Gini, Variance} -import org.apache.spark.mllib.tree.model.{Node, RandomForestModel} -import org.apache.spark.mllib.util.LocalSparkContext -import org.apache.spark.util.StatCounter +import org.apache.spark.mllib.tree.model.Node +import org.apache.spark.mllib.util.MLlibTestSparkContext /** * Test suite for [[RandomForest]]. 
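For orientation, one of the (numIterations, learningRate, subsamplingRate) combinations exercised by GradientBoostedTreesSuite above, spelled out as a standalone configuration; this restates constructors already used in the suite, not new API.

    import org.apache.spark.mllib.tree.configuration.Algo.Regression
    import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
    import org.apache.spark.mllib.tree.impurity.Variance
    import org.apache.spark.mllib.tree.loss.SquaredError

    // testCombinations entry (10, 0.1, 0.75): 10 iterations, learning rate 0.1, 75% subsampling
    val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
      categoricalFeaturesInfo = Map.empty, subsamplingRate = 0.75)
    val boostingStrategy = new BoostingStrategy(treeStrategy, SquaredError, 10, 0.1)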
*/ -class RandomForestSuite extends FunSuite with LocalSparkContext { - - test("BaggedPoint RDD: without subsampling") { - val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1) - val rdd = sc.parallelize(arr) - val baggedRDD = BaggedPoint.convertToBaggedRDDWithoutSampling(rdd) - baggedRDD.collect().foreach { baggedPoint => - assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1) - } - } - - test("BaggedPoint RDD: with subsampling") { - val numSubsamples = 100 - val (expectedMean, expectedStddev) = (1.0, 1.0) - - val seeds = Array(123, 5354, 230, 349867, 23987) - val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1) +class RandomForestSuite extends FunSuite with MLlibTestSparkContext { + def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) { + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) val rdd = sc.parallelize(arr) - seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, numSubsamples, seed = seed) - val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() - RandomForestSuite.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, - expectedStddev, epsilon = 0.01) - } - } - - test("Binary classification with continuous features:" + - " comparing DecisionTree vs. RandomForest(numTrees = 1)") { - - val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50) - val rdd = sc.parallelize(arr) - val categoricalFeaturesInfo = Map.empty[Int, Int] val numTrees = 1 - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, - numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) - val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees, featureSubsetStrategy = "auto", seed = 123) assert(rf.trees.size === 1) @@ -78,23 +46,34 @@ class RandomForestSuite extends FunSuite with LocalSparkContext { val dt = DecisionTree.train(rdd, strategy) - RandomForestSuite.validateClassifier(rf, arr, 0.9) + EnsembleTestHelper.validateClassifier(rf, arr, 0.9) DecisionTreeSuite.validateClassifier(dt, arr, 0.9) // Make sure trees are the same. assert(rfTree.toString == dt.toString) } - test("Regression with continuous features:" + + test("Binary classification with continuous features:" + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + binaryClassificationTestWithContinuousFeatures(strategy) + } - val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50) - val rdd = sc.parallelize(arr) + test("Binary classification with continuous features and node Id cache :" + + " comparing DecisionTree vs. 
RandomForest(numTrees = 1)") { val categoricalFeaturesInfo = Map.empty[Int, Int] - val numTrees = 1 + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, + useNodeIdCache = true) + binaryClassificationTestWithContinuousFeatures(strategy) + } - val strategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, - numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + def regressionTestWithContinuousFeatures(strategy: Strategy) { + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000) + val rdd = sc.parallelize(arr) + val numTrees = 1 val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees, featureSubsetStrategy = "auto", seed = 123) @@ -103,21 +82,35 @@ class RandomForestSuite extends FunSuite with LocalSparkContext { val dt = DecisionTree.train(rdd, strategy) - RandomForestSuite.validateRegressor(rf, arr, 0.01) + EnsembleTestHelper.validateRegressor(rf, arr, 0.01) DecisionTreeSuite.validateRegressor(dt, arr, 0.01) // Make sure trees are the same. assert(rfTree.toString == dt.toString) } - test("Binary classification with continuous features: subsampling features") { - val numFeatures = 50 - val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures) - val rdd = sc.parallelize(arr) + test("Regression with continuous features:" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { val categoricalFeaturesInfo = Map.empty[Int, Int] + val strategy = new Strategy(algo = Regression, impurity = Variance, + maxDepth = 2, maxBins = 10, numClassesForClassification = 2, + categoricalFeaturesInfo = categoricalFeaturesInfo) + regressionTestWithContinuousFeatures(strategy) + } - val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, - numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + test("Regression with continuous features and node Id cache :" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val strategy = new Strategy(algo = Regression, impurity = Variance, + maxDepth = 2, maxBins = 10, numClassesForClassification = 2, + categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true) + regressionTestWithContinuousFeatures(strategy) + } + + def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: Strategy) { + val numFeatures = 50 + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000) + val rdd = sc.parallelize(arr) // Select feature subset for top nodes. Return true if OK. def checkFeatureSubsetStrategy( @@ -173,74 +166,36 @@ class RandomForestSuite extends FunSuite with LocalSparkContext { checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt) } -} - -object RandomForestSuite { - - /** - * Aggregates all values in data, and tests whether the empirical mean and stddev are within - * epsilon of the expected values. - * @param data Every element of the data should be an i.i.d. sample from some distribution. 
- */ - def testRandomArrays( - data: Array[Array[Double]], - numCols: Int, - expectedMean: Double, - expectedStddev: Double, - epsilon: Double) { - val values = new mutable.ArrayBuffer[Double]() - data.foreach { row => - assert(row.size == numCols) - values ++= row - } - val stats = new StatCounter(values) - assert(math.abs(stats.mean - expectedMean) < epsilon) - assert(math.abs(stats.stdev - expectedStddev) < epsilon) + test("Binary classification with continuous features: subsampling features") { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) } - def validateClassifier( - model: RandomForestModel, - input: Seq[LabeledPoint], - requiredAccuracy: Double) { - val predictions = input.map(x => model.predict(x.features)) - val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => - prediction != expected.label - } - val accuracy = (input.length - numOffPredictions).toDouble / input.length - assert(accuracy >= requiredAccuracy, - s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") + test("Binary classification with continuous features and node Id cache: subsampling features") { + val categoricalFeaturesInfo = Map.empty[Int, Int] + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, + useNodeIdCache = true) + binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy) } - def validateRegressor( - model: RandomForestModel, - input: Seq[LabeledPoint], - requiredMSE: Double) { - val predictions = input.map(x => model.predict(x.features)) - val squaredError = predictions.zip(input).map { case (prediction, expected) => - val err = prediction - expected.label - err * err - }.sum - val mse = squaredError / input.length - assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") + test("alternating categorical and continuous features with multiclass labels to test indexing") { + val arr = new Array[LabeledPoint](4) + arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0)) + arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0, 1.0, 2.0)) + arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0, 6.0, 3.0)) + arr(3) = new LabeledPoint(2.0, Vectors.dense(0.0, 2.0, 1.0, 3.0, 2.0)) + val categoricalFeaturesInfo = Map(0 -> 3, 2 -> 2, 4 -> 4) + val input = sc.parallelize(arr) + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 3, categoricalFeaturesInfo = categoricalFeaturesInfo) + val model = RandomForest.trainClassifier(input, strategy, numTrees = 2, + featureSubsetStrategy = "sqrt", seed = 12345) + EnsembleTestHelper.validateClassifier(model, arr, 1.0) } +} - def generateOrderedLabeledPoints(numFeatures: Int): Array[LabeledPoint] = { - val numInstances = 1000 - val arr = new Array[LabeledPoint](numInstances) - for (i <- 0 until numInstances) { - val label = if (i < numInstances / 10) { - 0.0 - } else if (i < numInstances / 2) { - 1.0 - } else if (i < numInstances * 0.9) { - 0.0 - } else { - 1.0 - } - val features = Array.fill[Double](numFeatures)(i.toDouble) - arr(i) = new LabeledPoint(label, Vectors.dense(features)) - } - arr - } -} diff 
--git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala new file mode 100644 index 0000000000000..b184e936672ca --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.tree.EnsembleTestHelper +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** + * Test suite for [[BaggedPoint]]. + */ +class BaggedPointSuite extends FunSuite with MLlibTestSparkContext { + + test("BaggedPoint RDD: without subsampling") { + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42) + baggedRDD.collect().foreach { baggedPoint => + assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1) + } + } + + test("BaggedPoint RDD: with subsampling with replacement (fraction = 1.0)") { + val numSubsamples = 100 + val (expectedMean, expectedStddev) = (1.0, 1.0) + + val seeds = Array(123, 5354, 230, 349867, 23987) + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + seeds.foreach { seed => + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed) + val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, + expectedStddev, epsilon = 0.01) + } + } + + test("BaggedPoint RDD: with subsampling with replacement (fraction = 0.5)") { + val numSubsamples = 100 + val subsample = 0.5 + val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample)) + + val seeds = Array(123, 5354, 230, 349867, 23987) + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + seeds.foreach { seed => + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed) + val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, + expectedStddev, epsilon = 0.01) + } + } + + test("BaggedPoint RDD: with subsampling without replacement (fraction = 1.0)") { + val numSubsamples = 100 + val (expectedMean, expectedStddev) = (1.0, 0) + + val seeds = Array(123, 5354, 230, 349867, 23987) + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + seeds.foreach { seed => + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 
numSubsamples, false, seed) + val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, + expectedStddev, epsilon = 0.01) + } + } + + test("BaggedPoint RDD: with subsampling without replacement (fraction = 0.5)") { + val numSubsamples = 100 + val subsample = 0.5 + val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample * (1 - subsample))) + + val seeds = Array(123, 5354, 230, 349867, 23987) + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + seeds.foreach { seed => + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed) + val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, + expectedStddev, epsilon = 0.01) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala deleted file mode 100644 index 7857d9e5ee5c4..0000000000000 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalSparkContext.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
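A note on the expected statistics asserted in BaggedPointSuite above: with replacement, each point's per-subsample weight behaves like a Poisson(fraction) count, so the mean is fraction and the stddev is sqrt(fraction); without replacement the weight is Bernoulli(fraction), so the mean is fraction and the stddev is sqrt(fraction * (1 - fraction)). The constants used in the tests follow directly:

    val fraction = 0.5
    val withReplacementStddev = math.sqrt(fraction)                     // ≈ 0.707
    val withoutReplacementStddev = math.sqrt(fraction * (1 - fraction)) // = 0.5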
- */ - -package org.apache.spark.mllib.util - -import org.scalatest.Suite -import org.scalatest.BeforeAndAfterAll - -import org.apache.spark.{SparkConf, SparkContext} - -trait LocalSparkContext extends BeforeAndAfterAll { self: Suite => - @transient var sc: SparkContext = _ - - override def beforeAll() { - val conf = new SparkConf() - .setMaster("local") - .setAppName("test") - sc = new SparkContext(conf) - super.beforeAll() - } - - override def afterAll() { - if (sc != null) { - sc.stop() - } - super.afterAll() - } -} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 8ef2bb1bf6a78..88bc49cc61f94 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils._ import org.apache.spark.util.Utils -class MLUtilsSuite extends FunSuite with LocalSparkContext { +class MLUtilsSuite extends FunSuite with MLlibTestSparkContext { test("epsilon computation") { assert(1.0 + EPSILON > 1.0, s"EPSILON is too small: $EPSILON.") @@ -67,8 +67,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { |0 |0 2:4.0 4:5.0 6:6.0 """.stripMargin - val tempDir = Files.createTempDir() - tempDir.deleteOnExit() + val tempDir = Utils.createTempDir() val file = new File(tempDir.getPath, "part-00000") Files.write(lines, file, Charsets.US_ASCII) val path = tempDir.toURI.toString @@ -100,7 +99,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { LabeledPoint(1.1, Vectors.sparse(3, Seq((0, 1.23), (2, 4.56)))), LabeledPoint(0.0, Vectors.dense(1.01, 2.02, 3.03)) ), 2) - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() val outputDir = new File(tempDir, "output") MLUtils.saveAsLibSVMFile(examples, outputDir.toURI.toString) val lines = outputDir.listFiles() @@ -166,7 +165,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { Vectors.sparse(2, Array(1), Array(-1.0)), Vectors.dense(0.0, 1.0) ), 2) - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() val outputDir = new File(tempDir, "vectors") val path = outputDir.toURI.toString vectors.saveAsTextFile(path) @@ -181,7 +180,7 @@ class MLUtilsSuite extends FunSuite with LocalSparkContext { LabeledPoint(0.0, Vectors.sparse(2, Array(1), Array(-1.0))), LabeledPoint(1.0, Vectors.dense(0.0, 1.0)) ), 2) - val tempDir = Files.createTempDir() + val tempDir = Utils.createTempDir() val outputDir = new File(tempDir, "points") val path = outputDir.toURI.toString points.saveAsTextFile(path) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala new file mode 100644 index 0000000000000..b658889476d37 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
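On the MLUtilsSuite changes just above: Guava's Files.createTempDir() plus deleteOnExit() is replaced by Spark's Utils.createTempDir(), which registers the directory with Spark's shutdown hook so no explicit cleanup call is needed. The new pattern, in brief (the directory name is illustrative):

    import java.io.File
    import org.apache.spark.util.Utils

    val tempDir = Utils.createTempDir()          // deleted by Spark at JVM shutdown
    val outputDir = new File(tempDir, "output")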
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import org.scalatest.Suite +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.{SparkConf, SparkContext} + +trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite => + @transient var sc: SparkContext = _ + + override def beforeAll() { + super.beforeAll() + val conf = new SparkConf() + .setMaster("local[2]") + .setAppName("MLlibUnitTest") + sc = new SparkContext(conf) + } + + override def afterAll() { + if (sc != null) { + sc.stop() + } + super.afterAll() + } +} diff --git a/network/common/pom.xml b/network/common/pom.xml new file mode 100644 index 0000000000000..baca859fa5011 --- /dev/null +++ b/network/common/pom.xml @@ -0,0 +1,111 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.3.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-network-common_2.10 + jar + Spark Project Networking + http://spark.apache.org/ + + network-common + + + + + + io.netty + netty-all + + + + + org.slf4j + slf4j-api + provided + + + com.google.guava + guava + provided + + + + + junit + junit + test + + + com.novocode + junit-interface + test + + + log4j + log4j + test + + + org.mockito + mockito-all + test + + + org.scalatest + scalatest_${scala.binary.version} + test + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + + org.apache.maven.plugins + maven-jar-plugin + 2.2 + + + + test-jar + + + + test-jar-on-test-compile + test-compile + + test-jar + + + + + + + diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java new file mode 100644 index 0000000000000..5bc6e5a2418a9 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
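For reference, a minimal sketch of how the MLlibTestSparkContext trait added above is meant to be mixed in; the suite and test names are illustrative, not part of the patch.

    import org.scalatest.FunSuite
    import org.apache.spark.mllib.util.MLlibTestSparkContext

    class ExampleSuite extends FunSuite with MLlibTestSparkContext {
      test("shares the suite-wide local[2] SparkContext") {
        val rdd = sc.parallelize(1 to 4, 2)
        assert(rdd.count() === 4L)
      }
    }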
+ */ + +package org.apache.spark.network; + +import java.util.List; + +import com.google.common.collect.Lists; +import io.netty.channel.Channel; +import io.netty.channel.socket.SocketChannel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.client.TransportResponseHandler; +import org.apache.spark.network.protocol.MessageDecoder; +import org.apache.spark.network.protocol.MessageEncoder; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportChannelHandler; +import org.apache.spark.network.server.TransportRequestHandler; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to + * setup Netty Channel pipelines with a {@link org.apache.spark.network.server.TransportChannelHandler}. + * + * There are two communication protocols that the TransportClient provides, control-plane RPCs and + * data-plane "chunk fetching". The handling of the RPCs is performed outside of the scope of the + * TransportContext (i.e., by a user-provided handler), and it is responsible for setting up streams + * which can be streamed through the data plane in chunks using zero-copy IO. + * + * The TransportServer and TransportClientFactory both create a TransportChannelHandler for each + * channel. As each TransportChannelHandler contains a TransportClient, this enables server + * processes to send messages back to the client on an existing channel. + */ +public class TransportContext { + private final Logger logger = LoggerFactory.getLogger(TransportContext.class); + + private final TransportConf conf; + private final RpcHandler rpcHandler; + + private final MessageEncoder encoder; + private final MessageDecoder decoder; + + public TransportContext(TransportConf conf, RpcHandler rpcHandler) { + this.conf = conf; + this.rpcHandler = rpcHandler; + this.encoder = new MessageEncoder(); + this.decoder = new MessageDecoder(); + } + + /** + * Initializes a ClientFactory which runs the given TransportClientBootstraps prior to returning + * a new Client. Bootstraps will be executed synchronously, and must run successfully in order + * to create a Client. + */ + public TransportClientFactory createClientFactory(List bootstraps) { + return new TransportClientFactory(this, bootstraps); + } + + public TransportClientFactory createClientFactory() { + return createClientFactory(Lists.newArrayList()); + } + + /** Create a server which will attempt to bind to a specific port. */ + public TransportServer createServer(int port) { + return new TransportServer(this, port); + } + + /** Creates a new server, binding to any available ephemeral port. */ + public TransportServer createServer() { + return new TransportServer(this, 0); + } + + /** + * Initializes a client or server Netty Channel Pipeline which encodes/decodes messages and + * has a {@link org.apache.spark.network.server.TransportChannelHandler} to handle request or + * response messages. + * + * @return Returns the created TransportChannelHandler, which includes a TransportClient that can + * be used to communicate on this channel. 
The TransportClient is directly associated with a + * ChannelHandler to ensure all users of the same channel get the same TransportClient object. + */ + public TransportChannelHandler initializePipeline(SocketChannel channel) { + try { + TransportChannelHandler channelHandler = createChannelHandler(channel); + channel.pipeline() + .addLast("encoder", encoder) + .addLast("frameDecoder", NettyUtils.createFrameDecoder()) + .addLast("decoder", decoder) + // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this + // would require more logic to guarantee if this were not part of the same event loop. + .addLast("handler", channelHandler); + return channelHandler; + } catch (RuntimeException e) { + logger.error("Error while initializing Netty pipeline", e); + throw e; + } + } + + /** + * Creates the server- and client-side handler which is used to handle both RequestMessages and + * ResponseMessages. The channel is expected to have been successfully created, though certain + * properties (such as the remoteAddress()) may not be available yet. + */ + private TransportChannelHandler createChannelHandler(Channel channel) { + TransportResponseHandler responseHandler = new TransportResponseHandler(channel); + TransportClient client = new TransportClient(channel, responseHandler); + TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client, + rpcHandler); + return new TransportChannelHandler(client, responseHandler, requestHandler); + } + + public TransportConf getConf() { return conf; } +} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java new file mode 100644 index 0000000000000..844eff4f4c701 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; + +import com.google.common.base.Objects; +import com.google.common.io.ByteStreams; +import io.netty.channel.DefaultFileRegion; + +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.LimitedInputStream; +import org.apache.spark.network.util.TransportConf; + +/** + * A {@link ManagedBuffer} backed by a segment in a file. 
+ */ +public final class FileSegmentManagedBuffer extends ManagedBuffer { + private final TransportConf conf; + private final File file; + private final long offset; + private final long length; + + public FileSegmentManagedBuffer(TransportConf conf, File file, long offset, long length) { + this.conf = conf; + this.file = file; + this.offset = offset; + this.length = length; + } + + @Override + public long size() { + return length; + } + + @Override + public ByteBuffer nioByteBuffer() throws IOException { + FileChannel channel = null; + try { + channel = new RandomAccessFile(file, "r").getChannel(); + // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead. + if (length < conf.memoryMapBytes()) { + ByteBuffer buf = ByteBuffer.allocate((int) length); + channel.position(offset); + while (buf.remaining() != 0) { + if (channel.read(buf) == -1) { + throw new IOException(String.format("Reached EOF before filling buffer\n" + + "offset=%s\nfile=%s\nbuf.remaining=%s", + offset, file.getAbsoluteFile(), buf.remaining())); + } + } + buf.flip(); + return buf; + } else { + return channel.map(FileChannel.MapMode.READ_ONLY, offset, length); + } + } catch (IOException e) { + try { + if (channel != null) { + long size = channel.size(); + throw new IOException("Error in reading " + this + " (actual file length " + size + ")", + e); + } + } catch (IOException ignored) { + // ignore + } + throw new IOException("Error in opening " + this, e); + } finally { + JavaUtils.closeQuietly(channel); + } + } + + @Override + public InputStream createInputStream() throws IOException { + FileInputStream is = null; + try { + is = new FileInputStream(file); + ByteStreams.skipFully(is, offset); + return new LimitedInputStream(is, length); + } catch (IOException e) { + try { + if (is != null) { + long size = file.length(); + throw new IOException("Error in reading " + this + " (actual file length " + size + ")", + e); + } + } catch (IOException ignored) { + // ignore + } finally { + JavaUtils.closeQuietly(is); + } + throw new IOException("Error in opening " + this, e); + } catch (RuntimeException e) { + JavaUtils.closeQuietly(is); + throw e; + } + } + + @Override + public ManagedBuffer retain() { + return this; + } + + @Override + public ManagedBuffer release() { + return this; + } + + @Override + public Object convertToNetty() throws IOException { + if (conf.lazyFileDescriptor()) { + return new LazyFileRegion(file, offset, length); + } else { + FileChannel fileChannel = new FileInputStream(file).getChannel(); + return new DefaultFileRegion(fileChannel, offset, length); + } + } + + public File getFile() { return file; } + + public long getOffset() { return offset; } + + public long getLength() { return length; } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("file", file) + .add("offset", offset) + .add("length", length) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java b/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java new file mode 100644 index 0000000000000..81bc8ec40fc82 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/LazyFileRegion.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
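A hedged usage sketch for FileSegmentManagedBuffer above, written in Scala for brevity; the TransportConf instance is assumed to exist already, and the file and byte range are illustrative. The stream returned by createInputStream() is bounded by LimitedInputStream, so callers cannot read past the segment's length.

    import java.io.File
    import org.apache.spark.network.buffer.FileSegmentManagedBuffer
    import org.apache.spark.network.util.TransportConf

    def openSegment(conf: TransportConf, file: File): Unit = {
      val buf = new FileSegmentManagedBuffer(conf, file, 0L, 1024L)
      val in = buf.createInputStream()   // at most 1024 bytes are readable
      try {
        // ... consume the segment ...
      } finally {
        in.close()
      }
    }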
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.FileInputStream; +import java.io.File; +import java.io.IOException; +import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; + +import com.google.common.base.Objects; +import io.netty.channel.FileRegion; +import io.netty.util.AbstractReferenceCounted; + +import org.apache.spark.network.util.JavaUtils; + +/** + * A FileRegion implementation that only creates the file descriptor when the region is being + * transferred. This cannot be used with Epoll because there is no native support for it. + * + * This is mostly copied from DefaultFileRegion implementation in Netty. In the future, we + * should push this into Netty so the native Epoll transport can support this feature. + */ +public final class LazyFileRegion extends AbstractReferenceCounted implements FileRegion { + + private final File file; + private final long position; + private final long count; + + private FileChannel channel; + + private long numBytesTransferred = 0L; + + /** + * @param file file to transfer. + * @param position start position for the transfer. + * @param count number of bytes to transfer starting from position. + */ + public LazyFileRegion(File file, long position, long count) { + this.file = file; + this.position = position; + this.count = count; + } + + @Override + protected void deallocate() { + JavaUtils.closeQuietly(channel); + } + + @Override + public long position() { + return position; + } + + @Override + public long transfered() { + return numBytesTransferred; + } + + @Override + public long count() { + return count; + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + if (channel == null) { + channel = new FileInputStream(file).getChannel(); + } + + long count = this.count - position; + if (count < 0 || position < 0) { + throw new IllegalArgumentException( + "position out of range: " + position + " (expected: 0 - " + (count - 1) + ')'); + } + + if (count == 0) { + return 0L; + } + + long written = channel.transferTo(this.position + position, count, target); + if (written > 0) { + numBytesTransferred += written; + } + return written; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("file", file) + .add("position", position) + .add("count", count) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java new file mode 100644 index 0000000000000..a415db593a788 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
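[Editor's illustrative sketch, not part of this patch] Driving LazyFileRegion directly against a plain WritableByteChannel shows that the file descriptor is only opened on the first transferTo() call, which is the point of the "lazy" variant. The destination file is an arbitrary example.

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.channels.WritableByteChannel;

import org.apache.spark.network.buffer.LazyFileRegion;

public class LazyFileRegionExample {
  public static void copySegment(File src, long offset, long length, File dst) throws IOException {
    LazyFileRegion region = new LazyFileRegion(src, offset, length);
    WritableByteChannel out = new FileOutputStream(dst).getChannel();
    try {
      long written = 0;
      // transferTo() takes a position relative to the start of the region.
      while (written < region.count()) {
        written += region.transferTo(out, written);
      }
    } finally {
      out.close();
      region.release(); // closes the underlying FileChannel via deallocate()
    }
  }
}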
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; + +/** + * This interface provides an immutable view for data in the form of bytes. The implementation + * should specify how the data is provided: + * + * - {@link FileSegmentManagedBuffer}: data backed by part of a file + * - {@link NioManagedBuffer}: data backed by a NIO ByteBuffer + * - {@link NettyManagedBuffer}: data backed by a Netty ByteBuf + * + * The concrete buffer implementation might be managed outside the JVM garbage collector. + * For example, in the case of {@link NettyManagedBuffer}, the buffers are reference counted. + * In that case, if the buffer is going to be passed around to a different thread, retain/release + * should be called. + */ +public abstract class ManagedBuffer { + + /** Number of bytes of the data. */ + public abstract long size(); + + /** + * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the + * returned ByteBuffer should not affect the content of this buffer. + */ + // TODO: Deprecate this, usage may require expensive memory mapping or allocation. + public abstract ByteBuffer nioByteBuffer() throws IOException; + + /** + * Exposes this buffer's data as an InputStream. The underlying implementation does not + * necessarily check for the length of bytes read, so the caller is responsible for making sure + * it does not go over the limit. + */ + public abstract InputStream createInputStream() throws IOException; + + /** + * Increment the reference count by one if applicable. + */ + public abstract ManagedBuffer retain(); + + /** + * If applicable, decrement the reference count by one and deallocates the buffer if the + * reference count reaches zero. + */ + public abstract ManagedBuffer release(); + + /** + * Convert the buffer into an Netty object, used to write the data out. + */ + public abstract Object convertToNetty() throws IOException; +} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java new file mode 100644 index 0000000000000..c806bfa45bef3 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
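[Editor's illustrative sketch, not part of this patch] The retain()/release() contract described above matters mainly for NettyManagedBuffer, whose backing ByteBuf is reference counted. A small sketch of the hand-off pattern when a buffer escapes to another thread, assuming a caller-supplied Executor:

import java.util.concurrent.Executor;

import org.apache.spark.network.buffer.ManagedBuffer;

public class ManagedBufferHandOff {
  public static void handOff(ManagedBuffer buffer, Executor executor) {
    // Retain before the buffer escapes the current thread / call frame.
    final ManagedBuffer retained = buffer.retain();
    executor.execute(new Runnable() {
      @Override
      public void run() {
        try {
          long size = retained.size(); // consume the buffer here; size() is just a stand-in
        } finally {
          retained.release();          // balance the retain() above
        }
      }
    });
  }
}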
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufInputStream; + +/** + * A {@link ManagedBuffer} backed by a Netty {@link ByteBuf}. + */ +public final class NettyManagedBuffer extends ManagedBuffer { + private final ByteBuf buf; + + public NettyManagedBuffer(ByteBuf buf) { + this.buf = buf; + } + + @Override + public long size() { + return buf.readableBytes(); + } + + @Override + public ByteBuffer nioByteBuffer() throws IOException { + return buf.nioBuffer(); + } + + @Override + public InputStream createInputStream() throws IOException { + return new ByteBufInputStream(buf); + } + + @Override + public ManagedBuffer retain() { + buf.retain(); + return this; + } + + @Override + public ManagedBuffer release() { + buf.release(); + return this; + } + + @Override + public Object convertToNetty() throws IOException { + return buf.duplicate(); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("buf", buf) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java new file mode 100644 index 0000000000000..f55b884bc45ce --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.buffer; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBufInputStream; +import io.netty.buffer.Unpooled; + +/** + * A {@link ManagedBuffer} backed by {@link ByteBuffer}. 
+ */ +public final class NioManagedBuffer extends ManagedBuffer { + private final ByteBuffer buf; + + public NioManagedBuffer(ByteBuffer buf) { + this.buf = buf; + } + + @Override + public long size() { + return buf.remaining(); + } + + @Override + public ByteBuffer nioByteBuffer() throws IOException { + return buf.duplicate(); + } + + @Override + public InputStream createInputStream() throws IOException { + return new ByteBufInputStream(Unpooled.wrappedBuffer(buf)); + } + + @Override + public ManagedBuffer retain() { + return this; + } + + @Override + public ManagedBuffer release() { + return this; + } + + @Override + public Object convertToNetty() throws IOException { + return Unpooled.wrappedBuffer(buf); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("buf", buf) + .toString(); + } +} + diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java new file mode 100644 index 0000000000000..1fbdcd6780785 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +/** + * General exception caused by a remote exception while fetching a chunk. + */ +public class ChunkFetchFailureException extends RuntimeException { + public ChunkFetchFailureException(String errorMsg, Throwable cause) { + super(errorMsg, cause); + } + + public ChunkFetchFailureException(String errorMsg) { + super(errorMsg); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java new file mode 100644 index 0000000000000..519e6cb470d0d --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
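[Editor's illustrative sketch, not part of this patch] A round trip through the NIO-backed variant defined above: wrap an in-memory payload and read it back through the generic ManagedBuffer API.

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;

public class NioManagedBufferExample {
  public static void main(String[] args) throws IOException {
    byte[] payload = "hello".getBytes(StandardCharsets.UTF_8);
    ManagedBuffer buffer = new NioManagedBuffer(ByteBuffer.wrap(payload));

    System.out.println("size = " + buffer.size()); // 5
    InputStream in = buffer.createInputStream();
    byte[] copy = new byte[(int) buffer.size()];
    int read = in.read(copy);                      // small payload, one read suffices here
    System.out.println(read + " bytes: " + new String(copy, StandardCharsets.UTF_8));
    in.close();
  }
}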
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * Callback for the result of a single chunk result. For a single stream, the callbacks are + * guaranteed to be called by the same thread in the same order as the requests for chunks were + * made. + * + * Note that if a general stream failure occurs, all outstanding chunk requests may be failed. + */ +public interface ChunkReceivedCallback { + /** + * Called upon receipt of a particular chunk. + * + * The given buffer will initially have a refcount of 1, but will be release()'d as soon as this + * call returns. You must therefore either retain() the buffer or copy its contents before + * returning. + */ + void onSuccess(int chunkIndex, ManagedBuffer buffer); + + /** + * Called upon failure to fetch a particular chunk. Note that this may actually be called due + * to failure to fetch a prior chunk in this stream. + * + * After receiving a failure, the stream may or may not be valid. The client should not assume + * that the server's side of the stream has been closed. + */ + void onFailure(int chunkIndex, Throwable e); +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java new file mode 100644 index 0000000000000..6ec960d795420 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +/** + * Callback for the result of a single RPC. This will be invoked once with either success or + * failure. + */ +public interface RpcResponseCallback { + /** Successful serialized result from server. */ + void onSuccess(byte[] response); + + /** Exception either propagated from server or raised on client side. */ + void onFailure(Throwable e); +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java new file mode 100644 index 0000000000000..37f2e34ceb24d --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
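[Editor's illustrative sketch, not part of this patch] The buffer handed to onSuccess() is release()'d as soon as the callback returns, so implementations must copy (or retain) it. A sketch of a copying implementation with deliberately minimal error handling:

import java.nio.ByteBuffer;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;

public class CopyingChunkCallback implements ChunkReceivedCallback {
  private final Map<Integer, byte[]> chunks = new ConcurrentHashMap<Integer, byte[]>();

  @Override
  public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
    try {
      ByteBuffer data = buffer.nioByteBuffer();
      byte[] copy = new byte[data.remaining()];
      data.get(copy);                 // copy out: the buffer is not valid after we return
      chunks.put(chunkIndex, copy);
    } catch (Exception e) {
      onFailure(chunkIndex, e);
    }
  }

  @Override
  public void onFailure(int chunkIndex, Throwable e) {
    System.err.println("Chunk " + chunkIndex + " failed: " + e);
  }

  public Map<Integer, byte[]> received() { return chunks; }
}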
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import java.io.Closeable; +import java.io.IOException; +import java.util.UUID; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +import com.google.common.base.Objects; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import com.google.common.util.concurrent.SettableFuture; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.RpcRequest; +import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.util.NettyUtils; + +/** + * Client for fetching consecutive chunks of a pre-negotiated stream. This API is intended to allow + * efficient transfer of a large amount of data, broken up into chunks with size ranging from + * hundreds of KB to a few MB. + * + * Note that while this client deals with the fetching of chunks from a stream (i.e., data plane), + * the actual setup of the streams is done outside the scope of the transport layer. The convenience + * method "sendRPC" is provided to enable control plane communication between the client and server + * to perform this setup. + * + * For example, a typical workflow might be: + * client.sendRPC(new OpenFile("/foo")) --> returns StreamId = 100 + * client.fetchChunk(streamId = 100, chunkIndex = 0, callback) + * client.fetchChunk(streamId = 100, chunkIndex = 1, callback) + * ... + * client.sendRPC(new CloseStream(100)) + * + * Construct an instance of TransportClient using {@link TransportClientFactory}. A single + * TransportClient may be used for multiple streams, but any given stream must be restricted to a + * single client, in order to avoid out-of-order responses. + * + * NB: This class is used to make requests to the server, while {@link TransportResponseHandler} is + * responsible for handling responses from the server. + * + * Concurrency: thread safe and can be called from multiple threads. + */ +public class TransportClient implements Closeable { + private final Logger logger = LoggerFactory.getLogger(TransportClient.class); + + private final Channel channel; + private final TransportResponseHandler handler; + + public TransportClient(Channel channel, TransportResponseHandler handler) { + this.channel = Preconditions.checkNotNull(channel); + this.handler = Preconditions.checkNotNull(handler); + } + + public boolean isActive() { + return channel.isOpen() || channel.isActive(); + } + + /** + * Requests a single chunk from the remote side, from the pre-negotiated streamId. + * + * Chunk indices go from 0 onwards. It is valid to request the same chunk multiple times, though + * some streams may not support this. 
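[Editor's illustrative sketch, not part of this patch] The control-plane/data-plane split described in the class comment above, in code. The "open stream" RPC payload and the shape of the server's reply are hypothetical, application-defined details; only sendRpc() and fetchChunk() come from this patch.

import java.nio.ByteBuffer;

import com.google.common.base.Charsets;

import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;

public class FetchWorkflowExample {
  public static void fetchAllChunks(
      final TransportClient client, final int numChunks, final ChunkReceivedCallback callback) {
    byte[] openStreamRequest = "OpenFile:/foo".getBytes(Charsets.UTF_8); // hypothetical encoding
    client.sendRpc(openStreamRequest, new RpcResponseCallback() {
      @Override
      public void onSuccess(byte[] response) {
        long streamId = ByteBuffer.wrap(response).getLong(); // hypothetical server reply format
        for (int i = 0; i < numChunks; i++) {
          client.fetchChunk(streamId, i, callback);          // chunks arrive in request order
        }
      }

      @Override
      public void onFailure(Throwable e) {
        System.err.println("Failed to open stream: " + e);
      }
    });
  }
}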
+ * + * Multiple fetchChunk requests may be outstanding simultaneously, and the chunks are guaranteed + * to be returned in the same order that they were requested, assuming only a single + * TransportClient is used to fetch the chunks. + * + * @param streamId Identifier that refers to a stream in the remote StreamManager. This should + * be agreed upon by client and server beforehand. + * @param chunkIndex 0-based index of the chunk to fetch + * @param callback Callback invoked upon successful receipt of chunk, or upon any failure. + */ + public void fetchChunk( + long streamId, + final int chunkIndex, + final ChunkReceivedCallback callback) { + final String serverAddr = NettyUtils.getRemoteAddress(channel); + final long startTime = System.currentTimeMillis(); + logger.debug("Sending fetch chunk request {} to {}", chunkIndex, serverAddr); + + final StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex); + handler.addFetchRequest(streamChunkId, callback); + + channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener( + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + long timeTaken = System.currentTimeMillis() - startTime; + logger.trace("Sending request {} to {} took {} ms", streamChunkId, serverAddr, + timeTaken); + } else { + String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId, + serverAddr, future.cause()); + logger.error(errorMsg, future.cause()); + handler.removeFetchRequest(streamChunkId); + channel.close(); + try { + callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause())); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } + } + } + }); + } + + /** + * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked + * with the server's response or upon any failure. + */ + public void sendRpc(byte[] message, final RpcResponseCallback callback) { + final String serverAddr = NettyUtils.getRemoteAddress(channel); + final long startTime = System.currentTimeMillis(); + logger.trace("Sending RPC to {}", serverAddr); + + final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits()); + handler.addRpcRequest(requestId, callback); + + channel.writeAndFlush(new RpcRequest(requestId, message)).addListener( + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + long timeTaken = System.currentTimeMillis() - startTime; + logger.trace("Sending request {} to {} took {} ms", requestId, serverAddr, timeTaken); + } else { + String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId, + serverAddr, future.cause()); + logger.error(errorMsg, future.cause()); + handler.removeRpcRequest(requestId); + channel.close(); + try { + callback.onFailure(new IOException(errorMsg, future.cause())); + } catch (Exception e) { + logger.error("Uncaught exception in RPC response callback handler!", e); + } + } + } + }); + } + + /** + * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to + * a specified timeout for a response. 
+ */ + public byte[] sendRpcSync(byte[] message, long timeoutMs) { + final SettableFuture result = SettableFuture.create(); + + sendRpc(message, new RpcResponseCallback() { + @Override + public void onSuccess(byte[] response) { + result.set(response); + } + + @Override + public void onFailure(Throwable e) { + result.setException(e); + } + }); + + try { + return result.get(timeoutMs, TimeUnit.MILLISECONDS); + } catch (ExecutionException e) { + throw Throwables.propagate(e.getCause()); + } catch (Exception e) { + throw Throwables.propagate(e); + } + } + + @Override + public void close() { + // close is a local operation and should finish with milliseconds; timeout just to be safe + channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("remoteAdress", channel.remoteAddress()) + .add("isActive", isActive()) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java new file mode 100644 index 0000000000000..65e8020e34121 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +/** + * A bootstrap which is executed on a TransportClient before it is returned to the user. + * This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per- + * connection basis. + * + * Since connections (and TransportClients) are reused as much as possible, it is generally + * reasonable to perform an expensive bootstrapping operation, as they often share a lifespan with + * the JVM itself. + */ +public interface TransportClientBootstrap { + /** Performs the bootstrapping operation, throwing an exception on failure. */ + public void doBootstrap(TransportClient client) throws RuntimeException; +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java new file mode 100644 index 0000000000000..9afd5decd5e6b --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
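[Editor's illustrative sketch, not part of this patch] A TransportClientBootstrap built on sendRpcSync(): a hypothetical token handshake run once per connection before the client is handed out. The token format and the 10-second timeout are assumptions, not part of the patch.

import com.google.common.base.Charsets;

import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;

public class TokenHandshakeBootstrap implements TransportClientBootstrap {
  private final String token;

  public TokenHandshakeBootstrap(String token) {
    this.token = token;
  }

  @Override
  public void doBootstrap(TransportClient client) throws RuntimeException {
    // Hypothetical handshake: send the token and expect a non-empty acknowledgement.
    byte[] response = client.sendRpcSync(token.getBytes(Charsets.UTF_8), 10000 /* ms */);
    if (response.length == 0) {
      throw new RuntimeException("Handshake rejected by server");
    }
  }
}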
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import java.io.Closeable; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.List; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; + +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import com.google.common.collect.Lists; +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.server.TransportChannelHandler; +import org.apache.spark.network.util.IOMode; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * Factory for creating {@link TransportClient}s by using createClient. + * + * The factory maintains a connection pool to other hosts and should return the same + * TransportClient for the same remote host. It also shares a single worker thread pool for + * all TransportClients. + * + * TransportClients will be reused whenever possible. Prior to completing the creation of a new + * TransportClient, all given {@link TransportClientBootstrap}s will be run. + */ +public class TransportClientFactory implements Closeable { + private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class); + + private final TransportContext context; + private final TransportConf conf; + private final List clientBootstraps; + private final ConcurrentHashMap connectionPool; + + private final Class socketChannelClass; + private EventLoopGroup workerGroup; + private PooledByteBufAllocator pooledAllocator; + + public TransportClientFactory( + TransportContext context, + List clientBootstraps) { + this.context = Preconditions.checkNotNull(context); + this.conf = context.getConf(); + this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps)); + this.connectionPool = new ConcurrentHashMap(); + + IOMode ioMode = IOMode.valueOf(conf.ioMode()); + this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode); + // TODO: Make thread pool name configurable. + this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client"); + this.pooledAllocator = NettyUtils.createPooledByteBufAllocator( + conf.preferDirectBufs(), false /* allowCache */, conf.clientThreads()); + } + + /** + * Create a new {@link TransportClient} connecting to the given remote host / port. This will + * reuse TransportClients if they are still active and are for the same remote address. 
Prior + * to the creation of a new TransportClient, we will execute all {@link TransportClientBootstrap}s + * that are registered with this factory. + * + * This blocks until a connection is successfully established and fully bootstrapped. + * + * Concurrency: This method is safe to call from multiple threads. + */ + public TransportClient createClient(String remoteHost, int remotePort) throws IOException { + // Get connection from the connection pool first. + // If it is not found or not active, create a new one. + final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); + TransportClient cachedClient = connectionPool.get(address); + if (cachedClient != null) { + if (cachedClient.isActive()) { + logger.trace("Returning cached connection to {}: {}", address, cachedClient); + return cachedClient; + } else { + logger.info("Found inactive connection to {}, closing it.", address); + connectionPool.remove(address, cachedClient); // Remove inactive clients. + } + } + + logger.debug("Creating new connection to " + address); + + Bootstrap bootstrap = new Bootstrap(); + bootstrap.group(workerGroup) + .channel(socketChannelClass) + // Disable Nagle's Algorithm since we don't want packets to wait + .option(ChannelOption.TCP_NODELAY, true) + .option(ChannelOption.SO_KEEPALIVE, true) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs()) + .option(ChannelOption.ALLOCATOR, pooledAllocator); + + final AtomicReference clientRef = new AtomicReference(); + + bootstrap.handler(new ChannelInitializer() { + @Override + public void initChannel(SocketChannel ch) { + TransportChannelHandler clientHandler = context.initializePipeline(ch); + clientRef.set(clientHandler.getClient()); + } + }); + + // Connect to the remote server + long preConnect = System.currentTimeMillis(); + ChannelFuture cf = bootstrap.connect(address); + if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) { + throw new IOException( + String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs())); + } else if (cf.cause() != null) { + throw new IOException(String.format("Failed to connect to %s", address), cf.cause()); + } + + TransportClient client = clientRef.get(); + assert client != null : "Channel future completed successfully with null client"; + + // Execute any client bootstraps synchronously before marking the Client as successful. + long preBootstrap = System.currentTimeMillis(); + logger.debug("Connection to {} successful, running bootstraps...", address); + try { + for (TransportClientBootstrap clientBootstrap : clientBootstraps) { + clientBootstrap.doBootstrap(client); + } + } catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala + long bootstrapTime = System.currentTimeMillis() - preBootstrap; + logger.error("Exception while bootstrapping client after " + bootstrapTime + " ms", e); + client.close(); + throw Throwables.propagate(e); + } + long postBootstrap = System.currentTimeMillis(); + + // Successful connection & bootstrap -- in the event that two threads raced to create a client, + // use the first one that was put into the connectionPool and close the one we made here. 
+ TransportClient oldClient = connectionPool.putIfAbsent(address, client); + if (oldClient == null) { + logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)", + address, postBootstrap - preConnect, postBootstrap - preBootstrap); + return client; + } else { + logger.debug("Two clients were created concurrently after {} ms, second will be disposed.", + postBootstrap - preConnect); + client.close(); + return oldClient; + } + } + + /** Close all connections in the connection pool, and shutdown the worker thread pool. */ + @Override + public void close() { + for (TransportClient client : connectionPool.values()) { + try { + client.close(); + } catch (RuntimeException e) { + logger.warn("Ignoring exception during close", e); + } + } + connectionPool.clear(); + + if (workerGroup != null) { + workerGroup.shutdownGracefully(); + workerGroup = null; + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java new file mode 100644 index 0000000000000..2044afb0d85db --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +import java.io.IOException; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import com.google.common.annotations.VisibleForTesting; +import io.netty.channel.Channel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.ResponseMessage; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.server.MessageHandler; +import org.apache.spark.network.util.NettyUtils; + +/** + * Handler that processes server responses, in response to requests issued from a + * [[TransportClient]]. It works by tracking the list of outstanding requests (and their callbacks). + * + * Concurrency: thread safe and can be called from multiple threads. 
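[Editor's illustrative sketch, not part of this patch] Obtaining a pooled client from the factory defined above. Constructing the TransportContext (with its RpcHandler and TransportConf) is assumed to happen elsewhere; in real use the factory is long-lived and closed on shutdown rather than created per request.

import java.io.IOException;
import java.util.Collections;

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
import org.apache.spark.network.client.TransportClientFactory;

public class ClientFactoryExample {
  private final TransportClientFactory factory;

  public ClientFactoryExample(TransportContext context) {
    // No bootstraps in this sketch; a handshake bootstrap would be added to this list.
    this.factory = new TransportClientFactory(
        context, Collections.<TransportClientBootstrap>emptyList());
  }

  public TransportClient connect(String host, int port) throws IOException {
    // Repeated calls for the same host/port reuse the pooled connection while it is active.
    return factory.createClient(host, port);
  }

  public void shutdown() {
    factory.close(); // closes pooled clients and the shared event loop
  }
}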
+ */ +public class TransportResponseHandler extends MessageHandler { + private final Logger logger = LoggerFactory.getLogger(TransportResponseHandler.class); + + private final Channel channel; + + private final Map outstandingFetches; + + private final Map outstandingRpcs; + + public TransportResponseHandler(Channel channel) { + this.channel = channel; + this.outstandingFetches = new ConcurrentHashMap(); + this.outstandingRpcs = new ConcurrentHashMap(); + } + + public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) { + outstandingFetches.put(streamChunkId, callback); + } + + public void removeFetchRequest(StreamChunkId streamChunkId) { + outstandingFetches.remove(streamChunkId); + } + + public void addRpcRequest(long requestId, RpcResponseCallback callback) { + outstandingRpcs.put(requestId, callback); + } + + public void removeRpcRequest(long requestId) { + outstandingRpcs.remove(requestId); + } + + /** + * Fire the failure callback for all outstanding requests. This is called when we have an + * uncaught exception or pre-mature connection termination. + */ + private void failOutstandingRequests(Throwable cause) { + for (Map.Entry entry : outstandingFetches.entrySet()) { + entry.getValue().onFailure(entry.getKey().chunkIndex, cause); + } + for (Map.Entry entry : outstandingRpcs.entrySet()) { + entry.getValue().onFailure(cause); + } + + // It's OK if new fetches appear, as they will fail immediately. + outstandingFetches.clear(); + outstandingRpcs.clear(); + } + + @Override + public void channelUnregistered() { + if (numOutstandingRequests() > 0) { + String remoteAddress = NettyUtils.getRemoteAddress(channel); + logger.error("Still have {} requests outstanding when connection from {} is closed", + numOutstandingRequests(), remoteAddress); + failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed")); + } + } + + @Override + public void exceptionCaught(Throwable cause) { + if (numOutstandingRequests() > 0) { + String remoteAddress = NettyUtils.getRemoteAddress(channel); + logger.error("Still have {} requests outstanding when connection from {} is closed", + numOutstandingRequests(), remoteAddress); + failOutstandingRequests(cause); + } + } + + @Override + public void handle(ResponseMessage message) { + String remoteAddress = NettyUtils.getRemoteAddress(channel); + if (message instanceof ChunkFetchSuccess) { + ChunkFetchSuccess resp = (ChunkFetchSuccess) message; + ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); + if (listener == null) { + logger.warn("Ignoring response for block {} from {} since it is not outstanding", + resp.streamChunkId, remoteAddress); + resp.buffer.release(); + } else { + outstandingFetches.remove(resp.streamChunkId); + listener.onSuccess(resp.streamChunkId.chunkIndex, resp.buffer); + resp.buffer.release(); + } + } else if (message instanceof ChunkFetchFailure) { + ChunkFetchFailure resp = (ChunkFetchFailure) message; + ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId); + if (listener == null) { + logger.warn("Ignoring response for block {} from {} ({}) since it is not outstanding", + resp.streamChunkId, remoteAddress, resp.errorString); + } else { + outstandingFetches.remove(resp.streamChunkId); + listener.onFailure(resp.streamChunkId.chunkIndex, new ChunkFetchFailureException( + "Failure while fetching " + resp.streamChunkId + ": " + resp.errorString)); + } + } else if (message instanceof RpcResponse) { + RpcResponse resp = (RpcResponse) message; + 
RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); + if (listener == null) { + logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding", + resp.requestId, remoteAddress, resp.response.length); + } else { + outstandingRpcs.remove(resp.requestId); + listener.onSuccess(resp.response); + } + } else if (message instanceof RpcFailure) { + RpcFailure resp = (RpcFailure) message; + RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); + if (listener == null) { + logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding", + resp.requestId, remoteAddress, resp.errorString); + } else { + outstandingRpcs.remove(resp.requestId); + listener.onFailure(new RuntimeException(resp.errorString)); + } + } else { + throw new IllegalStateException("Unknown response type: " + message.type()); + } + } + + /** Returns total number of outstanding requests (fetch requests + rpcs) */ + @VisibleForTesting + public int numOutstandingRequests() { + return outstandingFetches.size() + outstandingRpcs.size(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java new file mode 100644 index 0000000000000..986957c1509fd --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Charsets; +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** + * Response to {@link ChunkFetchRequest} when there is an error fetching the chunk. 
+ */ +public final class ChunkFetchFailure implements ResponseMessage { + public final StreamChunkId streamChunkId; + public final String errorString; + + public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) { + this.streamChunkId = streamChunkId; + this.errorString = errorString; + } + + @Override + public Type type() { return Type.ChunkFetchFailure; } + + @Override + public int encodedLength() { + return streamChunkId.encodedLength() + Encoders.Strings.encodedLength(errorString); + } + + @Override + public void encode(ByteBuf buf) { + streamChunkId.encode(buf); + Encoders.Strings.encode(buf, errorString); + } + + public static ChunkFetchFailure decode(ByteBuf buf) { + StreamChunkId streamChunkId = StreamChunkId.decode(buf); + String errorString = Encoders.Strings.decode(buf); + return new ChunkFetchFailure(streamChunkId, errorString); + } + + @Override + public boolean equals(Object other) { + if (other instanceof ChunkFetchFailure) { + ChunkFetchFailure o = (ChunkFetchFailure) other; + return streamChunkId.equals(o.streamChunkId) && errorString.equals(o.errorString); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamChunkId", streamChunkId) + .add("errorString", errorString) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java new file mode 100644 index 0000000000000..980947cf13f6b --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** + * Request to fetch a sequence of a single chunk of a stream. This will correspond to a single + * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). 
+ */ +public final class ChunkFetchRequest implements RequestMessage { + public final StreamChunkId streamChunkId; + + public ChunkFetchRequest(StreamChunkId streamChunkId) { + this.streamChunkId = streamChunkId; + } + + @Override + public Type type() { return Type.ChunkFetchRequest; } + + @Override + public int encodedLength() { + return streamChunkId.encodedLength(); + } + + @Override + public void encode(ByteBuf buf) { + streamChunkId.encode(buf); + } + + public static ChunkFetchRequest decode(ByteBuf buf) { + return new ChunkFetchRequest(StreamChunkId.decode(buf)); + } + + @Override + public boolean equals(Object other) { + if (other instanceof ChunkFetchRequest) { + ChunkFetchRequest o = (ChunkFetchRequest) other; + return streamChunkId.equals(o.streamChunkId); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamChunkId", streamChunkId) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java new file mode 100644 index 0000000000000..ff4936470c697 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * Response to {@link ChunkFetchRequest} when a chunk exists and has been successfully fetched. + * + * Note that the server-side encoding of this messages does NOT include the buffer itself, as this + * may be written by Netty in a more efficient manner (i.e., zero-copy write). + * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer. + */ +public final class ChunkFetchSuccess implements ResponseMessage { + public final StreamChunkId streamChunkId; + public final ManagedBuffer buffer; + + public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) { + this.streamChunkId = streamChunkId; + this.buffer = buffer; + } + + @Override + public Type type() { return Type.ChunkFetchSuccess; } + + @Override + public int encodedLength() { + return streamChunkId.encodedLength(); + } + + /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */ + @Override + public void encode(ByteBuf buf) { + streamChunkId.encode(buf); + } + + /** Decoding uses the given ByteBuf as our data, and will retain() it. 
*/ + public static ChunkFetchSuccess decode(ByteBuf buf) { + StreamChunkId streamChunkId = StreamChunkId.decode(buf); + buf.retain(); + NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate()); + return new ChunkFetchSuccess(streamChunkId, managedBuf); + } + + @Override + public boolean equals(Object other) { + if (other instanceof ChunkFetchSuccess) { + ChunkFetchSuccess o = (ChunkFetchSuccess) other; + return streamChunkId.equals(o.streamChunkId) && buffer.equals(o.buffer); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamChunkId", streamChunkId) + .add("buffer", buffer) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java new file mode 100644 index 0000000000000..b4e299471b41a --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import io.netty.buffer.ByteBuf; + +/** + * Interface for an object which can be encoded into a ByteBuf. Multiple Encodable objects are + * stored in a single, pre-allocated ByteBuf, so Encodables must also provide their length. + * + * Encodable objects should provide a static "decode(ByteBuf)" method which is invoked by + * {@link MessageDecoder}. During decoding, if the object uses the ByteBuf as its data (rather than + * just copying data from it), then you must retain() the ByteBuf. + * + * Additionally, when adding a new Encodable Message, add it to {@link Message.Type}. + */ +public interface Encodable { + /** Number of bytes of the encoded form of this object. */ + int encodedLength(); + + /** + * Serializes this object by writing into the given ByteBuf. + * This method must write exactly encodedLength() bytes. + */ + void encode(ByteBuf buf); +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java new file mode 100644 index 0000000000000..873c694250942 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encoders.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + + +import com.google.common.base.Charsets; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +/** Provides a canonical set of Encoders for simple types. */ +public class Encoders { + + /** Strings are encoded with their length followed by UTF-8 bytes. */ + public static class Strings { + public static int encodedLength(String s) { + return 4 + s.getBytes(Charsets.UTF_8).length; + } + + public static void encode(ByteBuf buf, String s) { + byte[] bytes = s.getBytes(Charsets.UTF_8); + buf.writeInt(bytes.length); + buf.writeBytes(bytes); + } + + public static String decode(ByteBuf buf) { + int length = buf.readInt(); + byte[] bytes = new byte[length]; + buf.readBytes(bytes); + return new String(bytes, Charsets.UTF_8); + } + } + + /** Byte arrays are encoded with their length followed by bytes. */ + public static class ByteArrays { + public static int encodedLength(byte[] arr) { + return 4 + arr.length; + } + + public static void encode(ByteBuf buf, byte[] arr) { + buf.writeInt(arr.length); + buf.writeBytes(arr); + } + + public static byte[] decode(ByteBuf buf) { + int length = buf.readInt(); + byte[] bytes = new byte[length]; + buf.readBytes(bytes); + return bytes; + } + } + + /** String arrays are encoded with the number of strings followed by per-String encoding. */ + public static class StringArrays { + public static int encodedLength(String[] strings) { + int totalLength = 4; + for (String s : strings) { + totalLength += Strings.encodedLength(s); + } + return totalLength; + } + + public static void encode(ByteBuf buf, String[] strings) { + buf.writeInt(strings.length); + for (String s : strings) { + Strings.encode(buf, s); + } + } + + public static String[] decode(ByteBuf buf) { + int numStrings = buf.readInt(); + String[] strings = new String[numStrings]; + for (int i = 0; i < strings.length; i ++) { + strings[i] = Strings.decode(buf); + } + return strings; + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java new file mode 100644 index 0000000000000..d568370125fd4 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
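[Editor's illustrative sketch, not part of this patch] Round-tripping values through the Encoders helpers above using an unpooled Netty buffer.

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

import org.apache.spark.network.protocol.Encoders;

public class EncodersExample {
  public static void main(String[] args) {
    String error = "block not found";
    String[] blockIds = { "shuffle_0_0_0", "shuffle_0_1_0" };

    ByteBuf buf = Unpooled.buffer(
        Encoders.Strings.encodedLength(error) + Encoders.StringArrays.encodedLength(blockIds));
    Encoders.Strings.encode(buf, error);
    Encoders.StringArrays.encode(buf, blockIds);

    // Decoding reads the values back in the same order they were written.
    System.out.println(Encoders.Strings.decode(buf));             // "block not found"
    System.out.println(Encoders.StringArrays.decode(buf).length); // 2
    buf.release();
  }
}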
+ */ + +package org.apache.spark.network.protocol; + +import io.netty.buffer.ByteBuf; + +/** An on-the-wire transmittable message. */ +public interface Message extends Encodable { + /** Used to identify this request type. */ + Type type(); + + /** Preceding every serialized Message is its type, which allows us to deserialize it. */ + public static enum Type implements Encodable { + ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2), + RpcRequest(3), RpcResponse(4), RpcFailure(5); + + private final byte id; + + private Type(int id) { + assert id < 128 : "Cannot have more than 128 message types"; + this.id = (byte) id; + } + + public byte id() { return id; } + + @Override public int encodedLength() { return 1; } + + @Override public void encode(ByteBuf buf) { buf.writeByte(id); } + + public static Type decode(ByteBuf buf) { + byte id = buf.readByte(); + switch (id) { + case 0: return ChunkFetchRequest; + case 1: return ChunkFetchSuccess; + case 2: return ChunkFetchFailure; + case 3: return RpcRequest; + case 4: return RpcResponse; + case 5: return RpcFailure; + default: throw new IllegalArgumentException("Unknown message type: " + id); + } + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java new file mode 100644 index 0000000000000..81f8d7f96350f --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageDecoder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Decoder used by the client side to encode server-to-client responses. + * This encoder is stateless so it is safe to be shared by multiple threads. 
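[Editor's illustrative sketch, not part of this patch] Every serialized message is preceded by its one-byte Message.Type, which round-trips through encode()/decode() as shown here.

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

import org.apache.spark.network.protocol.Message;

public class MessageTypeExample {
  public static void main(String[] args) {
    ByteBuf buf = Unpooled.buffer(1);
    Message.Type.RpcRequest.encode(buf);             // writes the single id byte (3)
    Message.Type decoded = Message.Type.decode(buf);
    System.out.println(decoded);                     // RpcRequest
    buf.release();
  }
}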
+ */ +@ChannelHandler.Sharable +public final class MessageDecoder extends MessageToMessageDecoder { + + private final Logger logger = LoggerFactory.getLogger(MessageDecoder.class); + @Override + public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { + Message.Type msgType = Message.Type.decode(in); + Message decoded = decode(msgType, in); + assert decoded.type() == msgType; + logger.trace("Received message " + msgType + ": " + decoded); + out.add(decoded); + } + + private Message decode(Message.Type msgType, ByteBuf in) { + switch (msgType) { + case ChunkFetchRequest: + return ChunkFetchRequest.decode(in); + + case ChunkFetchSuccess: + return ChunkFetchSuccess.decode(in); + + case ChunkFetchFailure: + return ChunkFetchFailure.decode(in); + + case RpcRequest: + return RpcRequest.decode(in); + + case RpcResponse: + return RpcResponse.decode(in); + + case RpcFailure: + return RpcFailure.decode(in); + + default: + throw new IllegalArgumentException("Unexpected message type: " + msgType); + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java new file mode 100644 index 0000000000000..91d1e8a538a77 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageEncoder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Encoder used by the server side to encode server-to-client responses. + * This encoder is stateless so it is safe to be shared by multiple threads. + */ +@ChannelHandler.Sharable +public final class MessageEncoder extends MessageToMessageEncoder { + + private final Logger logger = LoggerFactory.getLogger(MessageEncoder.class); + + /*** + * Encodes a Message by invoking its encode() method. For non-data messages, we will add one + * ByteBuf to 'out' containing the total frame length, the message type, and the message itself. + * In the case of a ChunkFetchSuccess, we will also add the ManagedBuffer corresponding to the + * data to 'out', in order to enable zero-copy transfer. + */ + @Override + public void encode(ChannelHandlerContext ctx, Message in, List out) { + Object body = null; + long bodyLength = 0; + + // Only ChunkFetchSuccesses have data besides the header. + // The body is used in order to enable zero-copy transfer for the payload. 
+ if (in instanceof ChunkFetchSuccess) { + ChunkFetchSuccess resp = (ChunkFetchSuccess) in; + try { + bodyLength = resp.buffer.size(); + body = resp.buffer.convertToNetty(); + } catch (Exception e) { + // Re-encode this message as BlockFetchFailure. + logger.error(String.format("Error opening block %s for client %s", + resp.streamChunkId, ctx.channel().remoteAddress()), e); + encode(ctx, new ChunkFetchFailure(resp.streamChunkId, e.getMessage()), out); + return; + } + } + + Message.Type msgType = in.type(); + // All messages have the frame length, message type, and message itself. + int headerLength = 8 + msgType.encodedLength() + in.encodedLength(); + long frameLength = headerLength + bodyLength; + ByteBuf header = ctx.alloc().heapBuffer(headerLength); + header.writeLong(frameLength); + msgType.encode(header); + in.encode(header); + assert header.writableBytes() == 0; + + out.add(header); + if (body != null && bodyLength > 0) { + out.add(body); + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java new file mode 100644 index 0000000000000..31b15bb17a327 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RequestMessage.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import org.apache.spark.network.protocol.Message; + +/** Messages from the client to the server. */ +public interface RequestMessage extends Message { + // token interface +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java new file mode 100644 index 0000000000000..6edffd11cf1e2 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/ResponseMessage.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
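MessageEncoder and MessageDecoder together define the wire framing: an 8-byte frame length, a 1-byte Message.Type tag, then the encoded message, with a separately transferred zero-copy body only for ChunkFetchSuccess. The sketch below, not part of the patch, builds and re-reads one frame by hand using RpcRequest (defined later in this patch); in the real pipeline the LengthFieldBasedFrameDecoder from NettyUtils strips the length field before MessageDecoder runs.

```java
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

import org.apache.spark.network.protocol.Message;
import org.apache.spark.network.protocol.RpcRequest;

public class FrameLayoutSketch {
  public static void main(String[] args) {
    RpcRequest request = new RpcRequest(42L, new byte[] {1, 2, 3});

    // Mirror MessageEncoder: frame length, type tag, then the message itself.
    int headerLength = 8 + request.type().encodedLength() + request.encodedLength();
    ByteBuf frame = Unpooled.buffer(headerLength);
    frame.writeLong(headerLength);   // no zero-copy body, so frame length == header length
    request.type().encode(frame);
    request.encode(frame);

    // Mirror the receive path: skip the length (normally stripped by the frame decoder),
    // read the type tag, then dispatch to the matching decode() as MessageDecoder does.
    frame.readLong();
    Message.Type type = Message.Type.decode(frame);
    RpcRequest decoded = RpcRequest.decode(frame);
    if (type != Message.Type.RpcRequest || !decoded.equals(request)) {
      throw new AssertionError("frame did not round-trip");
    }
  }
}
```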
+ */ + +package org.apache.spark.network.protocol; + +import org.apache.spark.network.protocol.Message; + +/** Messages from the server to the client. */ +public interface ResponseMessage extends Message { + // token interface +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java new file mode 100644 index 0000000000000..ebd764eb5eb5f --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcFailure.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Charsets; +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** Response to {@link RpcRequest} for a failed RPC. */ +public final class RpcFailure implements ResponseMessage { + public final long requestId; + public final String errorString; + + public RpcFailure(long requestId, String errorString) { + this.requestId = requestId; + this.errorString = errorString; + } + + @Override + public Type type() { return Type.RpcFailure; } + + @Override + public int encodedLength() { + return 8 + Encoders.Strings.encodedLength(errorString); + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(requestId); + Encoders.Strings.encode(buf, errorString); + } + + public static RpcFailure decode(ByteBuf buf) { + long requestId = buf.readLong(); + String errorString = Encoders.Strings.decode(buf); + return new RpcFailure(requestId, errorString); + } + + @Override + public boolean equals(Object other) { + if (other instanceof RpcFailure) { + RpcFailure o = (RpcFailure) other; + return requestId == o.requestId && errorString.equals(o.errorString); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("requestId", requestId) + .add("errorString", errorString) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java new file mode 100644 index 0000000000000..cdee0b0e0316b --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcRequest.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** + * A generic RPC which is handled by a remote {@link org.apache.spark.network.server.RpcHandler}. + * This will correspond to a single + * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure). + */ +public final class RpcRequest implements RequestMessage { + /** Used to link an RPC request with its response. */ + public final long requestId; + + /** Serialized message to send to remote RpcHandler. */ + public final byte[] message; + + public RpcRequest(long requestId, byte[] message) { + this.requestId = requestId; + this.message = message; + } + + @Override + public Type type() { return Type.RpcRequest; } + + @Override + public int encodedLength() { + return 8 + Encoders.ByteArrays.encodedLength(message); + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(requestId); + Encoders.ByteArrays.encode(buf, message); + } + + public static RpcRequest decode(ByteBuf buf) { + long requestId = buf.readLong(); + byte[] message = Encoders.ByteArrays.decode(buf); + return new RpcRequest(requestId, message); + } + + @Override + public boolean equals(Object other) { + if (other instanceof RpcRequest) { + RpcRequest o = (RpcRequest) other; + return requestId == o.requestId && Arrays.equals(message, o.message); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("requestId", requestId) + .add("message", message) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java new file mode 100644 index 0000000000000..0a62e09a8115c --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/RpcResponse.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** Response to {@link RpcRequest} for a successful RPC. 
*/ +public final class RpcResponse implements ResponseMessage { + public final long requestId; + public final byte[] response; + + public RpcResponse(long requestId, byte[] response) { + this.requestId = requestId; + this.response = response; + } + + @Override + public Type type() { return Type.RpcResponse; } + + @Override + public int encodedLength() { return 8 + Encoders.ByteArrays.encodedLength(response); } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(requestId); + Encoders.ByteArrays.encode(buf, response); + } + + public static RpcResponse decode(ByteBuf buf) { + long requestId = buf.readLong(); + byte[] response = Encoders.ByteArrays.decode(buf); + return new RpcResponse(requestId, response); + } + + @Override + public boolean equals(Object other) { + if (other instanceof RpcResponse) { + RpcResponse o = (RpcResponse) other; + return requestId == o.requestId && Arrays.equals(response, o.response); + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("requestId", requestId) + .add("response", response) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java b/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java new file mode 100644 index 0000000000000..d46a263884807 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/protocol/StreamChunkId.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +/** +* Encapsulates a request for a particular chunk of a stream. 
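The three RPC messages are correlated only by requestId: TransportRequestHandler (further down) copies the id of the request it is answering into either an RpcResponse or an RpcFailure, so the client can match the reply to the right outstanding callback. An illustrative sketch, not part of the patch:

```java
import com.google.common.base.Charsets;

import org.apache.spark.network.protocol.RpcFailure;
import org.apache.spark.network.protocol.RpcRequest;
import org.apache.spark.network.protocol.RpcResponse;

public class RpcCorrelationSketch {
  public static void main(String[] args) {
    RpcRequest request = new RpcRequest(7L, "open stream".getBytes(Charsets.UTF_8));

    // Both reply types echo the originating requestId.
    RpcResponse success = new RpcResponse(request.requestId, "ok".getBytes(Charsets.UTF_8));
    RpcFailure failure = new RpcFailure(request.requestId, "no such stream");

    if (success.requestId != request.requestId || failure.requestId != request.requestId) {
      throw new AssertionError("replies must carry the request's id");
    }
  }
}
```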
+*/ +public final class StreamChunkId implements Encodable { + public final long streamId; + public final int chunkIndex; + + public StreamChunkId(long streamId, int chunkIndex) { + this.streamId = streamId; + this.chunkIndex = chunkIndex; + } + + @Override + public int encodedLength() { + return 8 + 4; + } + + public void encode(ByteBuf buffer) { + buffer.writeLong(streamId); + buffer.writeInt(chunkIndex); + } + + public static StreamChunkId decode(ByteBuf buffer) { + assert buffer.readableBytes() >= 8 + 4; + long streamId = buffer.readLong(); + int chunkIndex = buffer.readInt(); + return new StreamChunkId(streamId, chunkIndex); + } + + @Override + public int hashCode() { + return Objects.hashCode(streamId, chunkIndex); + } + + @Override + public boolean equals(Object other) { + if (other instanceof StreamChunkId) { + StreamChunkId o = (StreamChunkId) other; + return streamId == o.streamId && chunkIndex == o.chunkIndex; + } + return false; + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamId", streamId) + .add("chunkIndex", chunkIndex) + .toString(); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java new file mode 100644 index 0000000000000..b80c15106ecbd --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import org.apache.spark.network.protocol.Message; + +/** + * Handles either request or response messages coming off of Netty. A MessageHandler instance + * is associated with a single Netty Channel (though it may have multiple clients on the same + * Channel.) + */ +public abstract class MessageHandler { + /** Handles the receipt of a single message. */ + public abstract void handle(T message); + + /** Invoked when an exception was caught on the Channel. */ + public abstract void exceptionCaught(Throwable cause); + + /** Invoked when the channel this MessageHandler is on has been unregistered. */ + public abstract void channelUnregistered(); +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java new file mode 100644 index 0000000000000..1502b7489e864 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/NoOpRpcHandler.java @@ -0,0 +1,38 @@ +package org.apache.spark.network.server; + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; + +/** An RpcHandler suitable for a client-only TransportContext, which cannot receive RPCs. */ +public class NoOpRpcHandler extends RpcHandler { + private final StreamManager streamManager; + + public NoOpRpcHandler() { + streamManager = new OneForOneStreamManager(); + } + + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + throw new UnsupportedOperationException("Cannot handle messages"); + } + + @Override + public StreamManager getStreamManager() { return streamManager; } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java new file mode 100644 index 0000000000000..a6d390e13f396 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import java.util.Iterator; +import java.util.Map; +import java.util.Random; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * StreamManager which allows registration of an Iterator<ManagedBuffer>, which are individually + * fetched as chunks by the client. Each registered buffer is one chunk. + */ +public class OneForOneStreamManager extends StreamManager { + private final Logger logger = LoggerFactory.getLogger(OneForOneStreamManager.class); + + private final AtomicLong nextStreamId; + private final Map streams; + + /** State of a single stream. 
*/ + private static class StreamState { + final Iterator buffers; + + // Used to keep track of the index of the buffer that the user has retrieved, just to ensure + // that the caller only requests each chunk one at a time, in order. + int curChunk = 0; + + StreamState(Iterator buffers) { + this.buffers = buffers; + } + } + + public OneForOneStreamManager() { + // For debugging purposes, start with a random stream id to help identifying different streams. + // This does not need to be globally unique, only unique to this class. + nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000); + streams = new ConcurrentHashMap(); + } + + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + StreamState state = streams.get(streamId); + if (chunkIndex != state.curChunk) { + throw new IllegalStateException(String.format( + "Received out-of-order chunk index %s (expected %s)", chunkIndex, state.curChunk)); + } else if (!state.buffers.hasNext()) { + throw new IllegalStateException(String.format( + "Requested chunk index beyond end %s", chunkIndex)); + } + state.curChunk += 1; + ManagedBuffer nextChunk = state.buffers.next(); + + if (!state.buffers.hasNext()) { + logger.trace("Removing stream id {}", streamId); + streams.remove(streamId); + } + + return nextChunk; + } + + @Override + public void connectionTerminated(long streamId) { + // Release all remaining buffers. + StreamState state = streams.remove(streamId); + if (state != null && state.buffers != null) { + while (state.buffers.hasNext()) { + state.buffers.next().release(); + } + } + } + + /** + * Registers a stream of ManagedBuffers which are served as individual chunks one at a time to + * callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a + * client connection is closed before the iterator is fully drained, then the remaining buffers + * will all be release()'d. + */ + public long registerStream(Iterator buffers) { + long myStreamId = nextStreamId.getAndIncrement(); + streams.put(myStreamId, new StreamState(buffers)); + return myStreamId; + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java new file mode 100644 index 0000000000000..2ba92a40f8b0a --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; + +/** + * Handler for sendRPC() messages sent by {@link org.apache.spark.network.client.TransportClient}s. 
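OneForOneStreamManager, shown above, hands out a registered iterator of buffers strictly in order, one chunk per getChunk() call, and forgets the stream once the last chunk has been served. A small usage sketch, assuming the NioManagedBuffer wrapper from this module's buffer package (added earlier in the patch):

```java
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Iterator;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.server.OneForOneStreamManager;

public class StreamManagerSketch {
  public static void main(String[] args) {
    Iterator<ManagedBuffer> chunks = Arrays.<ManagedBuffer>asList(
        new NioManagedBuffer(ByteBuffer.wrap(new byte[] {1, 2})),
        new NioManagedBuffer(ByteBuffer.wrap(new byte[] {3, 4}))).iterator();

    OneForOneStreamManager manager = new OneForOneStreamManager();
    long streamId = manager.registerStream(chunks);

    // Chunks must be fetched in order: index 0, then 1. Out-of-order requests throw.
    ManagedBuffer first = manager.getChunk(streamId, 0);
    ManagedBuffer second = manager.getChunk(streamId, 1);
    System.out.println(first.size() + " + " + second.size() + " bytes served");
    // After the final chunk the stream id is removed from the manager.
  }
}
```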
+ */ +public abstract class RpcHandler { + /** + * Receive a single RPC message. Any exception thrown while in this method will be sent back to + * the client in string form as a standard RPC failure. + * + * This method will not be called in parallel for a single TransportClient (i.e., channel). + * + * @param client A channel client which enables the handler to make requests back to the sender + * of this RPC. This will always be the exact same object for a particular channel. + * @param message The serialized bytes of the RPC. + * @param callback Callback which should be invoked exactly once upon success or failure of the + * RPC. + */ + public abstract void receive( + TransportClient client, + byte[] message, + RpcResponseCallback callback); + + /** + * Returns the StreamManager which contains the state about which streams are currently being + * fetched by a TransportClient. + */ + public abstract StreamManager getStreamManager(); + + /** + * Invoked when the connection associated with the given client has been invalidated. + * No further requests will come from this client. + */ + public void connectionTerminated(TransportClient client) { } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java new file mode 100644 index 0000000000000..5a9a14a180c10 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * The StreamManager is used to fetch individual chunks from a stream. This is used in + * {@link TransportRequestHandler} in order to respond to fetchChunk() requests. Creation of the + * stream is outside the scope of the transport layer, but a given stream is guaranteed to be read + * by only one client connection, meaning that getChunk() for a particular stream will be called + * serially and that once the connection associated with the stream is closed, that stream will + * never be used again. + */ +public abstract class StreamManager { + /** + * Called in response to a fetchChunk() request. The returned buffer will be passed as-is to the + * client. A single stream will be associated with a single TCP connection, so this method + * will not be called in parallel for a particular stream. + * + * Chunks may be requested in any order, and requests may be repeated, but it is not required + * that implementations support this behavior. + * + * The returned ManagedBuffer will be release()'d after being written to the network. 
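A concrete RpcHandler only needs receive() and a StreamManager. The hypothetical echo handler below, not part of the patch, answers each RPC with an upper-cased copy of its payload and reuses OneForOneStreamManager for the (unused) stream side:

```java
import com.google.common.base.Charsets;

import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.StreamManager;

/** Hypothetical handler: replies to every RPC with an upper-cased copy of the request bytes. */
public class EchoRpcHandler extends RpcHandler {
  private final StreamManager streamManager = new OneForOneStreamManager();

  @Override
  public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) {
    String request = new String(message, Charsets.UTF_8);
    callback.onSuccess(request.toUpperCase().getBytes(Charsets.UTF_8));
  }

  @Override
  public StreamManager getStreamManager() { return streamManager; }
}
```

A handler like this is what gets wired into the server, as sketched further below.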
+ * + * @param streamId id of a stream that has been previously registered with the StreamManager. + * @param chunkIndex 0-indexed chunk of the stream that's requested + */ + public abstract ManagedBuffer getChunk(long streamId, int chunkIndex); + + /** + * Indicates that the TCP connection that was tied to the given stream has been terminated. After + * this occurs, we are guaranteed not to read from the stream again, so any state can be cleaned + * up. + */ + public void connectionTerminated(long streamId) { } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java new file mode 100644 index 0000000000000..e491367fa4528 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportResponseHandler; +import org.apache.spark.network.protocol.Message; +import org.apache.spark.network.protocol.RequestMessage; +import org.apache.spark.network.protocol.ResponseMessage; +import org.apache.spark.network.util.NettyUtils; + +/** + * The single Transport-level Channel handler which is used for delegating requests to the + * {@link TransportRequestHandler} and responses to the {@link TransportResponseHandler}. + * + * All channels created in the transport layer are bidirectional. When the Client initiates a Netty + * Channel with a RequestMessage (which gets handled by the Server's RequestHandler), the Server + * will produce a ResponseMessage (handled by the Client's ResponseHandler). However, the Server + * also gets a handle on the same Channel, so it may then begin to send RequestMessages to the + * Client. + * This means that the Client also needs a RequestHandler and the Server needs a ResponseHandler, + * for the Client's responses to the Server's requests. 
+ */ +public class TransportChannelHandler extends SimpleChannelInboundHandler { + private final Logger logger = LoggerFactory.getLogger(TransportChannelHandler.class); + + private final TransportClient client; + private final TransportResponseHandler responseHandler; + private final TransportRequestHandler requestHandler; + + public TransportChannelHandler( + TransportClient client, + TransportResponseHandler responseHandler, + TransportRequestHandler requestHandler) { + this.client = client; + this.responseHandler = responseHandler; + this.requestHandler = requestHandler; + } + + public TransportClient getClient() { + return client; + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + logger.warn("Exception in connection from " + NettyUtils.getRemoteAddress(ctx.channel()), + cause); + requestHandler.exceptionCaught(cause); + responseHandler.exceptionCaught(cause); + ctx.close(); + } + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + try { + requestHandler.channelUnregistered(); + } catch (RuntimeException e) { + logger.error("Exception from request handler while unregistering channel", e); + } + try { + responseHandler.channelUnregistered(); + } catch (RuntimeException e) { + logger.error("Exception from response handler while unregistering channel", e); + } + super.channelUnregistered(ctx); + } + + @Override + public void channelRead0(ChannelHandlerContext ctx, Message request) { + if (request instanceof RequestMessage) { + requestHandler.handle((RequestMessage) request); + } else { + responseHandler.handle((ResponseMessage) request); + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java new file mode 100644 index 0000000000000..1580180cc17e9 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.server; + +import java.util.Set; + +import com.google.common.base.Throwables; +import com.google.common.collect.Sets; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.RequestMessage; +import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.RpcRequest; +import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.util.NettyUtils; + +/** + * A handler that processes requests from clients and writes chunk data back. Each handler is + * attached to a single Netty channel, and keeps track of which streams have been fetched via this + * channel, in order to clean them up if the channel is terminated (see #channelUnregistered). + * + * The messages should have been processed by the pipeline setup by {@link TransportServer}. + */ +public class TransportRequestHandler extends MessageHandler { + private final Logger logger = LoggerFactory.getLogger(TransportRequestHandler.class); + + /** The Netty channel that this handler is associated with. */ + private final Channel channel; + + /** Client on the same channel allowing us to talk back to the requester. */ + private final TransportClient reverseClient; + + /** Handles all RPC messages. */ + private final RpcHandler rpcHandler; + + /** Returns each chunk part of a stream. */ + private final StreamManager streamManager; + + /** List of all stream ids that have been read on this handler, used for cleanup. */ + private final Set streamIds; + + public TransportRequestHandler( + Channel channel, + TransportClient reverseClient, + RpcHandler rpcHandler) { + this.channel = channel; + this.reverseClient = reverseClient; + this.rpcHandler = rpcHandler; + this.streamManager = rpcHandler.getStreamManager(); + this.streamIds = Sets.newHashSet(); + } + + @Override + public void exceptionCaught(Throwable cause) { + } + + @Override + public void channelUnregistered() { + // Inform the StreamManager that these streams will no longer be read from. 
+ for (long streamId : streamIds) { + streamManager.connectionTerminated(streamId); + } + rpcHandler.connectionTerminated(reverseClient); + } + + @Override + public void handle(RequestMessage request) { + if (request instanceof ChunkFetchRequest) { + processFetchRequest((ChunkFetchRequest) request); + } else if (request instanceof RpcRequest) { + processRpcRequest((RpcRequest) request); + } else { + throw new IllegalArgumentException("Unknown request type: " + request); + } + } + + private void processFetchRequest(final ChunkFetchRequest req) { + final String client = NettyUtils.getRemoteAddress(channel); + streamIds.add(req.streamChunkId.streamId); + + logger.trace("Received req from {} to fetch block {}", client, req.streamChunkId); + + ManagedBuffer buf; + try { + buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex); + } catch (Exception e) { + logger.error(String.format( + "Error opening block %s for request from %s", req.streamChunkId, client), e); + respond(new ChunkFetchFailure(req.streamChunkId, Throwables.getStackTraceAsString(e))); + return; + } + + respond(new ChunkFetchSuccess(req.streamChunkId, buf)); + } + + private void processRpcRequest(final RpcRequest req) { + try { + rpcHandler.receive(reverseClient, req.message, new RpcResponseCallback() { + @Override + public void onSuccess(byte[] response) { + respond(new RpcResponse(req.requestId, response)); + } + + @Override + public void onFailure(Throwable e) { + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + } + }); + } catch (Exception e) { + logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); + } + } + + /** + * Responds to a single message with some Encodable object. If a failure occurs while sending, + * it will be logged and the channel closed. + */ + private void respond(final Encodable result) { + final String remoteAddress = channel.remoteAddress().toString(); + channel.writeAndFlush(result).addListener( + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + logger.trace(String.format("Sent result %s to client %s", result, remoteAddress)); + } else { + logger.error(String.format("Error sending result %s to %s; closing connection", + result, remoteAddress), future.cause()); + channel.close(); + } + } + } + ); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java new file mode 100644 index 0000000000000..625c3257d764e --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportServer.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
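TransportRequestHandler above completes the server-side dispatch, and the TransportServer that follows binds it to a port. A rough end-to-end sketch of how the pieces are meant to fit together; TransportContext, TransportClientFactory and TransportClient are added elsewhere in this patch and are not shown in this excerpt, so the createClientFactory()/createClient()/sendRpc() calls below are assumptions inferred from how they are referenced here, not verified signatures. EchoRpcHandler is the hypothetical handler sketched earlier; TransportConf and SystemPropertyConfigProvider appear at the end of this excerpt.

```java
import com.google.common.base.Charsets;

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class TransportWiringSketch {
  public static void main(String[] args) throws Exception {
    TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
    TransportContext context = new TransportContext(conf, new EchoRpcHandler());

    // Port 0 binds to any available port, per the TransportServer constructor contract.
    TransportServer server = new TransportServer(context, 0);
    TransportClient client =
        context.createClientFactory().createClient("localhost", server.getPort());

    client.sendRpc("ping".getBytes(Charsets.UTF_8), new RpcResponseCallback() {
      @Override public void onSuccess(byte[] response) {
        System.out.println(new String(response, Charsets.UTF_8)); // "PING"
      }
      @Override public void onFailure(Throwable e) {
        e.printStackTrace();
      }
    });
  }
}
```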
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.server; + +import java.io.Closeable; +import java.net.InetSocketAddress; +import java.util.concurrent.TimeUnit; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.util.internal.PlatformDependent; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.util.IOMode; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * Server for the efficient, low-level streaming service. + */ +public class TransportServer implements Closeable { + private final Logger logger = LoggerFactory.getLogger(TransportServer.class); + + private final TransportContext context; + private final TransportConf conf; + + private ServerBootstrap bootstrap; + private ChannelFuture channelFuture; + private int port = -1; + + /** Creates a TransportServer that binds to the given port, or to any available if 0. */ + public TransportServer(TransportContext context, int portToBind) { + this.context = context; + this.conf = context.getConf(); + + init(portToBind); + } + + public int getPort() { + if (port == -1) { + throw new IllegalStateException("Server not initialized"); + } + return port; + } + + private void init(int portToBind) { + + IOMode ioMode = IOMode.valueOf(conf.ioMode()); + EventLoopGroup bossGroup = + NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server"); + EventLoopGroup workerGroup = bossGroup; + + PooledByteBufAllocator allocator = NettyUtils.createPooledByteBufAllocator( + conf.preferDirectBufs(), true /* allowCache */, conf.serverThreads()); + + bootstrap = new ServerBootstrap() + .group(bossGroup, workerGroup) + .channel(NettyUtils.getServerChannelClass(ioMode)) + .option(ChannelOption.ALLOCATOR, allocator) + .childOption(ChannelOption.ALLOCATOR, allocator); + + if (conf.backLog() > 0) { + bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog()); + } + + if (conf.receiveBuf() > 0) { + bootstrap.childOption(ChannelOption.SO_RCVBUF, conf.receiveBuf()); + } + + if (conf.sendBuf() > 0) { + bootstrap.childOption(ChannelOption.SO_SNDBUF, conf.sendBuf()); + } + + bootstrap.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) throws Exception { + context.initializePipeline(ch); + } + }); + + channelFuture = bootstrap.bind(new InetSocketAddress(portToBind)); + channelFuture.syncUninterruptibly(); + + port = ((InetSocketAddress) channelFuture.channel().localAddress()).getPort(); + logger.debug("Shuffle server started on port :" + port); + } + + @Override + public void close() { + if (channelFuture != null) { + // close is a local operation and should finish within milliseconds; timeout just to be safe + channelFuture.channel().close().awaitUninterruptibly(10, TimeUnit.SECONDS); + channelFuture = null; + } + if (bootstrap != null && bootstrap.group() != null) { + bootstrap.group().shutdownGracefully(); + } + if (bootstrap != null && bootstrap.childGroup() != null) { + bootstrap.childGroup().shutdownGracefully(); + } + bootstrap = null; + } + +} diff --git 
a/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java new file mode 100644 index 0000000000000..d944d9da1c7f8 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/ConfigProvider.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.util.NoSuchElementException; + +/** + * Provides a mechanism for constructing a {@link TransportConf} using some sort of configuration. + */ +public abstract class ConfigProvider { + /** Obtains the value of the given config, throws NoSuchElementException if it doesn't exist. */ + public abstract String get(String name); + + public String get(String name, String defaultValue) { + try { + return get(name); + } catch (NoSuchElementException e) { + return defaultValue; + } + } + + public int getInt(String name, int defaultValue) { + return Integer.parseInt(get(name, Integer.toString(defaultValue))); + } + + public long getLong(String name, long defaultValue) { + return Long.parseLong(get(name, Long.toString(defaultValue))); + } + + public double getDouble(String name, double defaultValue) { + return Double.parseDouble(get(name, Double.toString(defaultValue))); + } + + public boolean getBoolean(String name, boolean defaultValue) { + return Boolean.parseBoolean(get(name, Boolean.toString(defaultValue))); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/IOMode.java b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java new file mode 100644 index 0000000000000..6b208d95bbfbc --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/IOMode.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +/** + * Selector for which form of low-level IO we should use. + * NIO is always available, while EPOLL is only available on Linux. 
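ConfigProvider, just above, is the only abstraction TransportConf needs, so embedders and tests can back it with whatever they like. A hypothetical map-backed provider, not part of the patch:

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.NoSuchElementException;

import org.apache.spark.network.util.ConfigProvider;

/** Hypothetical ConfigProvider backed by an in-memory map, handy for tests. */
public class MapConfigProvider extends ConfigProvider {
  private final Map<String, String> values;

  public MapConfigProvider(Map<String, String> values) {
    this.values = new HashMap<String, String>(values);
  }

  @Override
  public String get(String name) {
    String value = values.get(name);
    if (value == null) {
      throw new NoSuchElementException(name);
    }
    return value;
  }

  public static void main(String[] args) {
    ConfigProvider conf = new MapConfigProvider(
        Collections.singletonMap("spark.shuffle.io.mode", "NIO"));
    System.out.println(conf.get("spark.shuffle.io.mode"));      // NIO
    System.out.println(conf.getInt("some.unset.key", -1));      // -1, the default kicks in
  }
}
```

Wrapping it as new TransportConf(new MapConfigProvider(settings)) then resolves settings such as spark.shuffle.io.mode with the defaults declared in TransportConf at the end of this excerpt.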
+ * AUTO is used to select EPOLL if it's available, or NIO otherwise. + */ +public enum IOMode { + NIO, EPOLL +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java new file mode 100644 index 0000000000000..bf8a1fc42fc6d --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/JavaUtils.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.nio.ByteBuffer; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.Closeable; +import java.io.File; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; + +import com.google.common.base.Preconditions; +import com.google.common.io.Closeables; +import com.google.common.base.Charsets; +import io.netty.buffer.Unpooled; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * General utilities available in the network package. Many of these are sourced from Spark's + * own Utils, just accessible within this package. + */ +public class JavaUtils { + private static final Logger logger = LoggerFactory.getLogger(JavaUtils.class); + + /** Closes the given object, ignoring IOExceptions. */ + public static void closeQuietly(Closeable closeable) { + try { + if (closeable != null) { + closeable.close(); + } + } catch (IOException e) { + logger.error("IOException should not have been thrown.", e); + } + } + + /** Returns a hash consistent with Spark's Utils.nonNegativeHash(). */ + public static int nonNegativeHash(Object obj) { + if (obj == null) { return 0; } + int hash = obj.hashCode(); + return hash != Integer.MIN_VALUE ? Math.abs(hash) : 0; + } + + /** + * Convert the given string to a byte buffer. The resulting buffer can be + * converted back to the same string through {@link #bytesToString(ByteBuffer)}. + */ + public static ByteBuffer stringToBytes(String s) { + return Unpooled.wrappedBuffer(s.getBytes(Charsets.UTF_8)).nioBuffer(); + } + + /** + * Convert the given byte buffer to a string. The resulting string can be + * converted back to the same byte buffer through {@link #stringToBytes(String)}. + */ + public static String bytesToString(ByteBuffer b) { + return Unpooled.wrappedBuffer(b).toString(Charsets.UTF_8); + } + + /* + * Delete a file or directory and its contents recursively. + * Don't follow directories if they are symlinks. + * Throws an exception if deletion is unsuccessful. 
+ */ + public static void deleteRecursively(File file) throws IOException { + if (file == null) { return; } + + if (file.isDirectory() && !isSymlink(file)) { + IOException savedIOException = null; + for (File child : listFilesSafely(file)) { + try { + deleteRecursively(child); + } catch (IOException e) { + // In case of multiple exceptions, only last one will be thrown + savedIOException = e; + } + } + if (savedIOException != null) { + throw savedIOException; + } + } + + boolean deleted = file.delete(); + // Delete can also fail if the file simply did not exist. + if (!deleted && file.exists()) { + throw new IOException("Failed to delete: " + file.getAbsolutePath()); + } + } + + private static File[] listFilesSafely(File file) throws IOException { + if (file.exists()) { + File[] files = file.listFiles(); + if (files == null) { + throw new IOException("Failed to list files for dir: " + file); + } + return files; + } else { + return new File[0]; + } + } + + private static boolean isSymlink(File file) throws IOException { + Preconditions.checkNotNull(file); + File fileInCanonicalDir = null; + if (file.getParent() == null) { + fileInCanonicalDir = file; + } else { + fileInCanonicalDir = new File(file.getParentFile().getCanonicalFile(), file.getName()); + } + return !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()); + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java b/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java new file mode 100644 index 0000000000000..57113ed12d414 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; + +import com.google.common.base.Preconditions; + +/** + * Wraps a {@link InputStream}, limiting the number of bytes which can be read. + * + * This code is from Guava's 14.0 source code, because there is no compatible way to + * use this functionality in both a Guava 11 environment and a Guava >14 environment. 
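Two small checks on the JavaUtils helpers above, purely illustrative: stringToBytes/bytesToString are inverses, and deleteRecursively removes a directory tree without following symlinks.

```java
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;

import org.apache.spark.network.util.JavaUtils;

public class JavaUtilsSketch {
  public static void main(String[] args) throws IOException {
    // String <-> ByteBuffer round trip.
    ByteBuffer payload = JavaUtils.stringToBytes("app-20141103-0001");
    if (!"app-20141103-0001".equals(JavaUtils.bytesToString(payload))) {
      throw new AssertionError("string round trip failed");
    }

    // Recursive delete of a scratch directory containing one file.
    File dir = new File(System.getProperty("java.io.tmpdir"), "java-utils-sketch");
    dir.mkdirs();
    new FileOutputStream(new File(dir, "data.bin")).close();
    JavaUtils.deleteRecursively(dir);
    if (dir.exists()) {
      throw new AssertionError("directory should be gone");
    }
  }
}
```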
+ */ +public final class LimitedInputStream extends FilterInputStream { + private long left; + private long mark = -1; + + public LimitedInputStream(InputStream in, long limit) { + super(in); + Preconditions.checkNotNull(in); + Preconditions.checkArgument(limit >= 0, "limit must be non-negative"); + left = limit; + } + @Override public int available() throws IOException { + return (int) Math.min(in.available(), left); + } + // it's okay to mark even if mark isn't supported, as reset won't work + @Override public synchronized void mark(int readLimit) { + in.mark(readLimit); + mark = left; + } + @Override public int read() throws IOException { + if (left == 0) { + return -1; + } + int result = in.read(); + if (result != -1) { + --left; + } + return result; + } + @Override public int read(byte[] b, int off, int len) throws IOException { + if (left == 0) { + return -1; + } + len = (int) Math.min(len, left); + int result = in.read(b, off, len); + if (result != -1) { + left -= result; + } + return result; + } + @Override public synchronized void reset() throws IOException { + if (!in.markSupported()) { + throw new IOException("Mark not supported"); + } + if (mark == -1) { + throw new IOException("Mark not set"); + } + in.reset(); + left = mark; + } + @Override public long skip(long n) throws IOException { + n = Math.min(n, left); + long skipped = in.skip(n); + left -= skipped; + return skipped; + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java new file mode 100644 index 0000000000000..2a4b88b64cdc9 --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.lang.reflect.Field; +import java.util.concurrent.ThreadFactory; + +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.ServerChannel; +import io.netty.channel.epoll.Epoll; +import io.netty.channel.epoll.EpollEventLoopGroup; +import io.netty.channel.epoll.EpollServerSocketChannel; +import io.netty.channel.epoll.EpollSocketChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import io.netty.util.internal.PlatformDependent; + +/** + * Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO. 
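LimitedInputStream, shown above and lifted from Guava 14 to stay compatible with both Guava 11 and newer versions, simply caps how many bytes callers may pull from the wrapped stream. A quick illustration:

```java
import java.io.ByteArrayInputStream;
import java.io.IOException;

import org.apache.spark.network.util.LimitedInputStream;

public class LimitedReadSketch {
  public static void main(String[] args) throws IOException {
    byte[] data = {10, 20, 30, 40, 50};

    // Expose only the first three bytes of the underlying stream.
    LimitedInputStream in = new LimitedInputStream(new ByteArrayInputStream(data), 3);
    System.out.println(in.read()); // 10
    System.out.println(in.read()); // 20
    System.out.println(in.read()); // 30
    System.out.println(in.read()); // -1: limit reached although the source has more bytes
    in.close();
  }
}
```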
+ */ +public class NettyUtils { + /** Creates a new ThreadFactory which prefixes each thread with the given name. */ + public static ThreadFactory createThreadFactory(String threadPoolPrefix) { + return new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat(threadPoolPrefix + "-%d") + .build(); + } + + /** Creates a Netty EventLoopGroup based on the IOMode. */ + public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) { + ThreadFactory threadFactory = createThreadFactory(threadPrefix); + + switch (mode) { + case NIO: + return new NioEventLoopGroup(numThreads, threadFactory); + case EPOLL: + return new EpollEventLoopGroup(numThreads, threadFactory); + default: + throw new IllegalArgumentException("Unknown io mode: " + mode); + } + } + + /** Returns the correct (client) SocketChannel class based on IOMode. */ + public static Class getClientChannelClass(IOMode mode) { + switch (mode) { + case NIO: + return NioSocketChannel.class; + case EPOLL: + return EpollSocketChannel.class; + default: + throw new IllegalArgumentException("Unknown io mode: " + mode); + } + } + + /** Returns the correct ServerSocketChannel class based on IOMode. */ + public static Class getServerChannelClass(IOMode mode) { + switch (mode) { + case NIO: + return NioServerSocketChannel.class; + case EPOLL: + return EpollServerSocketChannel.class; + default: + throw new IllegalArgumentException("Unknown io mode: " + mode); + } + } + + /** + * Creates a LengthFieldBasedFrameDecoder where the first 8 bytes are the length of the frame. + * This is used before all decoders. + */ + public static ByteToMessageDecoder createFrameDecoder() { + // maxFrameLength = 2G + // lengthFieldOffset = 0 + // lengthFieldLength = 8 + // lengthAdjustment = -8, i.e. exclude the 8 byte length itself + // initialBytesToStrip = 8, i.e. strip out the length field itself + return new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 8, -8, 8); + } + + /** Returns the remote address on the channel or "<remote address>" if none exists. */ + public static String getRemoteAddress(Channel channel) { + if (channel != null && channel.remoteAddress() != null) { + return channel.remoteAddress().toString(); + } + return ""; + } + + /** + * Create a pooled ByteBuf allocator but disables the thread-local cache. Thread-local caches + * are disabled for TransportClients because the ByteBufs are allocated by the event loop thread, + * but released by the executor thread rather than the event loop thread. Those thread-local + * caches actually delay the recycling of buffers, leading to larger memory usage. + */ + public static PooledByteBufAllocator createPooledByteBufAllocator( + boolean allowDirectBufs, + boolean allowCache, + int numCores) { + if (numCores == 0) { + numCores = Runtime.getRuntime().availableProcessors(); + } + return new PooledByteBufAllocator( + allowDirectBufs && PlatformDependent.directBufferPreferred(), + Math.min(getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"), numCores), + Math.min(getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"), allowDirectBufs ? numCores : 0), + getPrivateStaticField("DEFAULT_PAGE_SIZE"), + getPrivateStaticField("DEFAULT_MAX_ORDER"), + allowCache ? getPrivateStaticField("DEFAULT_TINY_CACHE_SIZE") : 0, + allowCache ? getPrivateStaticField("DEFAULT_SMALL_CACHE_SIZE") : 0, + allowCache ? getPrivateStaticField("DEFAULT_NORMAL_CACHE_SIZE") : 0 + ); + } + + /** Used to get defaults from Netty's private static fields. 
*/ + private static int getPrivateStaticField(String name) { + try { + Field f = PooledByteBufAllocator.DEFAULT.getClass().getDeclaredField(name); + f.setAccessible(true); + return f.getInt(null); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java b/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java new file mode 100644 index 0000000000000..5f20b70678d1e --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/SystemPropertyConfigProvider.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +import java.util.NoSuchElementException; + +import org.apache.spark.network.util.ConfigProvider; + +/** Uses System properties to obtain config values. */ +public class SystemPropertyConfigProvider extends ConfigProvider { + @Override + public String get(String name) { + String value = System.getProperty(name); + if (value == null) { + throw new NoSuchElementException(name); + } + return value; + } +} diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java new file mode 100644 index 0000000000000..1af40acf8b4af --- /dev/null +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.util; + +/** + * A central location that tracks all the settings we expose to users. + */ +public class TransportConf { + private final ConfigProvider conf; + + public TransportConf(ConfigProvider conf) { + this.conf = conf; + } + + /** IO mode: nio or epoll */ + public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); } + + /** If true, we will prefer allocating off-heap byte buffers within Netty. 
*/ + public boolean preferDirectBufs() { + return conf.getBoolean("spark.shuffle.io.preferDirectBufs", true); + } + + /** Connect timeout in secs. Default 120 secs. */ + public int connectionTimeoutMs() { + return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000; + } + + /** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. */ + public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); } + + /** Number of threads used in the server thread pool. Default to 0, which is 2x#cores. */ + public int serverThreads() { return conf.getInt("spark.shuffle.io.serverThreads", 0); } + + /** Number of threads used in the client thread pool. Default to 0, which is 2x#cores. */ + public int clientThreads() { return conf.getInt("spark.shuffle.io.clientThreads", 0); } + + /** + * Receive buffer size (SO_RCVBUF). + * Note: the optimal size for receive buffer and send buffer should be + * latency * network_bandwidth. + * Assuming latency = 1ms, network_bandwidth = 10Gbps + * buffer size should be ~ 1.25MB + */ + public int receiveBuf() { return conf.getInt("spark.shuffle.io.receiveBuffer", -1); } + + /** Send buffer size (SO_SNDBUF). */ + public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); } + + /** Timeout for a single round trip of SASL token exchange, in milliseconds. */ + public int saslRTTimeout() { return conf.getInt("spark.shuffle.sasl.timeout", 30000); } + + /** + * Max number of times we will try IO exceptions (such as connection timeouts) per request. + * If set to 0, we will not do any retries. + */ + public int maxIORetries() { return conf.getInt("spark.shuffle.io.maxRetries", 3); } + + /** + * Time (in milliseconds) that we will wait in order to perform a retry after an IOException. + * Only relevant if maxIORetries > 0. + */ + public int ioRetryWaitTime() { return conf.getInt("spark.shuffle.io.retryWaitMs", 5000); } + + /** + * Minimum size of a block that we should start using memory map rather than reading in through + * normal IO operations. This prevents Spark from memory mapping very small blocks. In general, + * memory mapping has high overhead for blocks close to or below the page size of the OS. + */ + public int memoryMapBytes() { + return conf.getInt("spark.storage.memoryMapThreshold", 2 * 1024 * 1024); + } + + /** + * Whether to initialize shuffle FileDescriptor lazily or not. If true, file descriptors are + * created only when data is going to be transferred. This can reduce the number of open files. + */ + public boolean lazyFileDescriptor() { + return conf.getBoolean("spark.shuffle.io.lazyFD", true); + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java new file mode 100644 index 0000000000000..dfb7740344ed0 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.io.File; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class ChunkFetchIntegrationSuite { + static final long STREAM_ID = 1; + static final int BUFFER_CHUNK_INDEX = 0; + static final int FILE_CHUNK_INDEX = 1; + + static TransportServer server; + static TransportClientFactory clientFactory; + static StreamManager streamManager; + static File testFile; + + static ManagedBuffer bufferChunk; + static ManagedBuffer fileChunk; + + private TransportConf transportConf; + + @BeforeClass + public static void setUp() throws Exception { + int bufSize = 100000; + final ByteBuffer buf = ByteBuffer.allocate(bufSize); + for (int i = 0; i < bufSize; i ++) { + buf.put((byte) i); + } + buf.flip(); + bufferChunk = new NioManagedBuffer(buf); + + testFile = File.createTempFile("shuffle-test-file", "txt"); + testFile.deleteOnExit(); + RandomAccessFile fp = new RandomAccessFile(testFile, "rw"); + byte[] fileContent = new byte[1024]; + new Random().nextBytes(fileContent); + fp.write(fileContent); + fp.close(); + + final TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + fileChunk = new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); + + streamManager = new StreamManager() { + @Override + public ManagedBuffer getChunk(long streamId, int chunkIndex) { + assertEquals(STREAM_ID, streamId); + if (chunkIndex == BUFFER_CHUNK_INDEX) { + return new NioManagedBuffer(buf); + } else if (chunkIndex == FILE_CHUNK_INDEX) { + return new FileSegmentManagedBuffer(conf, testFile, 10, testFile.length() - 25); + } else { + throw new IllegalArgumentException("Invalid chunk index: " + chunkIndex); + } + } + }; + RpcHandler handler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + throw new UnsupportedOperationException(); + } + + @Override + public StreamManager getStreamManager() { 
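+        // Return the shared StreamManager defined above, which serves the in-memory buffer
+        // chunk and the on-disk file chunk under STREAM_ID.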
+ return streamManager; + } + }; + TransportContext context = new TransportContext(conf, handler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + } + + @AfterClass + public static void tearDown() { + server.close(); + clientFactory.close(); + testFile.delete(); + } + + class FetchResult { + public Set successChunks; + public Set failedChunks; + public List buffers; + + public void releaseBuffers() { + for (ManagedBuffer buffer : buffers) { + buffer.release(); + } + } + } + + private FetchResult fetchChunks(List chunkIndices) throws Exception { + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + final Semaphore sem = new Semaphore(0); + + final FetchResult res = new FetchResult(); + res.successChunks = Collections.synchronizedSet(new HashSet()); + res.failedChunks = Collections.synchronizedSet(new HashSet()); + res.buffers = Collections.synchronizedList(new LinkedList()); + + ChunkReceivedCallback callback = new ChunkReceivedCallback() { + @Override + public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + buffer.retain(); + res.successChunks.add(chunkIndex); + res.buffers.add(buffer); + sem.release(); + } + + @Override + public void onFailure(int chunkIndex, Throwable e) { + res.failedChunks.add(chunkIndex); + sem.release(); + } + }; + + for (int chunkIndex : chunkIndices) { + client.fetchChunk(STREAM_ID, chunkIndex, callback); + } + if (!sem.tryAcquire(chunkIndices.size(), 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } + client.close(); + return res; + } + + @Test + public void fetchBufferChunk() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX)); + assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX)); + assertTrue(res.failedChunks.isEmpty()); + assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk)); + res.releaseBuffers(); + } + + @Test + public void fetchFileChunk() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(FILE_CHUNK_INDEX)); + assertEquals(res.successChunks, Sets.newHashSet(FILE_CHUNK_INDEX)); + assertTrue(res.failedChunks.isEmpty()); + assertBufferListsEqual(res.buffers, Lists.newArrayList(fileChunk)); + res.releaseBuffers(); + } + + @Test + public void fetchNonExistentChunk() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(12345)); + assertTrue(res.successChunks.isEmpty()); + assertEquals(res.failedChunks, Sets.newHashSet(12345)); + assertTrue(res.buffers.isEmpty()); + } + + @Test + public void fetchBothChunks() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); + assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX, FILE_CHUNK_INDEX)); + assertTrue(res.failedChunks.isEmpty()); + assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk, fileChunk)); + res.releaseBuffers(); + } + + @Test + public void fetchChunkAndNonExistent() throws Exception { + FetchResult res = fetchChunks(Lists.newArrayList(BUFFER_CHUNK_INDEX, 12345)); + assertEquals(res.successChunks, Sets.newHashSet(BUFFER_CHUNK_INDEX)); + assertEquals(res.failedChunks, Sets.newHashSet(12345)); + assertBufferListsEqual(res.buffers, Lists.newArrayList(bufferChunk)); + res.releaseBuffers(); + } + + private void assertBufferListsEqual(List list0, List list1) + throws Exception { + assertEquals(list0.size(), list1.size()); + for (int i = 0; i < list0.size(); i ++) { + assertBuffersEqual(list0.get(i), 
list1.get(i)); + } + } + + private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception { + ByteBuffer nio0 = buffer0.nioByteBuffer(); + ByteBuffer nio1 = buffer1.nioByteBuffer(); + + int len = nio0.remaining(); + assertEquals(nio0.remaining(), nio1.remaining()); + for (int i = 0; i < len; i ++) { + assertEquals(nio0.get(), nio1.get()); + } + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java new file mode 100644 index 0000000000000..43dc0cf8c7194 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/ProtocolSuite.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +import org.apache.spark.network.protocol.Message; +import org.apache.spark.network.protocol.StreamChunkId; +import org.apache.spark.network.protocol.ChunkFetchRequest; +import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.RpcRequest; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.MessageDecoder; +import org.apache.spark.network.protocol.MessageEncoder; +import org.apache.spark.network.util.NettyUtils; + +public class ProtocolSuite { + private void testServerToClient(Message msg) { + EmbeddedChannel serverChannel = new EmbeddedChannel(new MessageEncoder()); + serverChannel.writeOutbound(msg); + + EmbeddedChannel clientChannel = new EmbeddedChannel( + NettyUtils.createFrameDecoder(), new MessageDecoder()); + + while (!serverChannel.outboundMessages().isEmpty()) { + clientChannel.writeInbound(serverChannel.readOutbound()); + } + + assertEquals(1, clientChannel.inboundMessages().size()); + assertEquals(msg, clientChannel.readInbound()); + } + + private void testClientToServer(Message msg) { + EmbeddedChannel clientChannel = new EmbeddedChannel(new MessageEncoder()); + clientChannel.writeOutbound(msg); + + EmbeddedChannel serverChannel = new EmbeddedChannel( + NettyUtils.createFrameDecoder(), new MessageDecoder()); + + while (!clientChannel.outboundMessages().isEmpty()) { + serverChannel.writeInbound(clientChannel.readOutbound()); + } + + assertEquals(1, serverChannel.inboundMessages().size()); + assertEquals(msg, serverChannel.readInbound()); + } + + @Test + public void requests() { + testClientToServer(new ChunkFetchRequest(new StreamChunkId(1, 2))); + testClientToServer(new RpcRequest(12345, new byte[0])); + 
testClientToServer(new RpcRequest(12345, new byte[100])); + } + + @Test + public void responses() { + testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(10))); + testServerToClient(new ChunkFetchSuccess(new StreamChunkId(1, 2), new TestManagedBuffer(0))); + testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "this is an error")); + testServerToClient(new ChunkFetchFailure(new StreamChunkId(1, 2), "")); + testServerToClient(new RpcResponse(12345, new byte[0])); + testServerToClient(new RpcResponse(12345, new byte[1000])); + testServerToClient(new RpcFailure(0, "this is an error")); + testServerToClient(new RpcFailure(0, "")); + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java new file mode 100644 index 0000000000000..64b457b4b3f01 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Set; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +import com.google.common.base.Charsets; +import com.google.common.collect.Sets; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class RpcIntegrationSuite { + static TransportServer server; + static TransportClientFactory clientFactory; + static RpcHandler rpcHandler; + + @BeforeClass + public static void setUp() throws Exception { + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + rpcHandler = new RpcHandler() { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + String msg = new String(message, Charsets.UTF_8); + String[] parts = msg.split("/"); + if (parts[0].equals("hello")) { + callback.onSuccess(("Hello, " + parts[1] + "!").getBytes(Charsets.UTF_8)); + } else if (parts[0].equals("return error")) { + callback.onFailure(new RuntimeException("Returned: " + parts[1])); + } else if (parts[0].equals("throw error")) { + throw new RuntimeException("Thrown: " + parts[1]); + } + } + + @Override + public StreamManager getStreamManager() { return new OneForOneStreamManager(); } + }; + TransportContext context = new TransportContext(conf, rpcHandler); + server = context.createServer(); + clientFactory = context.createClientFactory(); + } + + @AfterClass + public static void tearDown() { + server.close(); + clientFactory.close(); + } + + class RpcResult { + public Set successMessages; + public Set errorMessages; + } + + private RpcResult sendRPC(String ... 
commands) throws Exception { + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + final Semaphore sem = new Semaphore(0); + + final RpcResult res = new RpcResult(); + res.successMessages = Collections.synchronizedSet(new HashSet()); + res.errorMessages = Collections.synchronizedSet(new HashSet()); + + RpcResponseCallback callback = new RpcResponseCallback() { + @Override + public void onSuccess(byte[] message) { + res.successMessages.add(new String(message, Charsets.UTF_8)); + sem.release(); + } + + @Override + public void onFailure(Throwable e) { + res.errorMessages.add(e.getMessage()); + sem.release(); + } + }; + + for (String command : commands) { + client.sendRpc(command.getBytes(Charsets.UTF_8), callback); + } + + if (!sem.tryAcquire(commands.length, 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } + client.close(); + return res; + } + + @Test + public void singleRPC() throws Exception { + RpcResult res = sendRPC("hello/Aaron"); + assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!")); + assertTrue(res.errorMessages.isEmpty()); + } + + @Test + public void doubleRPC() throws Exception { + RpcResult res = sendRPC("hello/Aaron", "hello/Reynold"); + assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!", "Hello, Reynold!")); + assertTrue(res.errorMessages.isEmpty()); + } + + @Test + public void returnErrorRPC() throws Exception { + RpcResult res = sendRPC("return error/OK"); + assertTrue(res.successMessages.isEmpty()); + assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK")); + } + + @Test + public void throwErrorRPC() throws Exception { + RpcResult res = sendRPC("throw error/uh-oh"); + assertTrue(res.successMessages.isEmpty()); + assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: uh-oh")); + } + + @Test + public void doubleTrouble() throws Exception { + RpcResult res = sendRPC("return error/OK", "throw error/uh-oh"); + assertTrue(res.successMessages.isEmpty()); + assertErrorsContain(res.errorMessages, Sets.newHashSet("Returned: OK", "Thrown: uh-oh")); + } + + @Test + public void sendSuccessAndFailure() throws Exception { + RpcResult res = sendRPC("hello/Bob", "throw error/the", "hello/Builder", "return error/!"); + assertEquals(res.successMessages, Sets.newHashSet("Hello, Bob!", "Hello, Builder!")); + assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: the", "Returned: !")); + } + + private void assertErrorsContain(Set errors, Set contains) { + assertEquals(contains.size(), errors.size()); + + Set remainingErrors = Sets.newHashSet(errors); + for (String contain : contains) { + Iterator it = remainingErrors.iterator(); + boolean foundMatch = false; + while (it.hasNext()) { + if (it.next().contains(contain)) { + it.remove(); + foundMatch = true; + break; + } + } + assertTrue("Could not find error containing " + contain + "; errors: " + errors, foundMatch); + } + + assertTrue(remainingErrors.isEmpty()); + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java new file mode 100644 index 0000000000000..38113a918f795 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/TestManagedBuffer.java @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; + +import com.google.common.base.Preconditions; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; + +/** + * A ManagedBuffer implementation that contains 0, 1, 2, 3, ..., (len-1). + * + * Used for testing. + */ +public class TestManagedBuffer extends ManagedBuffer { + + private final int len; + private NettyManagedBuffer underlying; + + public TestManagedBuffer(int len) { + Preconditions.checkArgument(len <= Byte.MAX_VALUE); + this.len = len; + byte[] byteArray = new byte[len]; + for (int i = 0; i < len; i ++) { + byteArray[i] = (byte) i; + } + this.underlying = new NettyManagedBuffer(Unpooled.wrappedBuffer(byteArray)); + } + + + @Override + public long size() { + return underlying.size(); + } + + @Override + public ByteBuffer nioByteBuffer() throws IOException { + return underlying.nioByteBuffer(); + } + + @Override + public InputStream createInputStream() throws IOException { + return underlying.createInputStream(); + } + + @Override + public ManagedBuffer retain() { + underlying.retain(); + return this; + } + + @Override + public ManagedBuffer release() { + underlying.release(); + return this; + } + + @Override + public Object convertToNetty() throws IOException { + return underlying.convertToNetty(); + } + + @Override + public boolean equals(Object other) { + if (other instanceof ManagedBuffer) { + try { + ByteBuffer nioBuf = ((ManagedBuffer) other).nioByteBuffer(); + if (nioBuf.remaining() != len) { + return false; + } else { + for (int i = 0; i < len; i ++) { + if (nioBuf.get() != i) { + return false; + } + } + return true; + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + return false; + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/TestUtils.java b/network/common/src/test/java/org/apache/spark/network/TestUtils.java new file mode 100644 index 0000000000000..56a2b805f154c --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/TestUtils.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import java.net.InetAddress; + +public class TestUtils { + public static String getLocalHost() { + try { + return InetAddress.getLocalHost().getHostAddress(); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java new file mode 100644 index 0000000000000..822bef1d81b2a --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network; + +import java.io.IOException; +import java.util.concurrent.TimeoutException; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.NoOpRpcHandler; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class TransportClientFactorySuite { + private TransportConf conf; + private TransportContext context; + private TransportServer server1; + private TransportServer server2; + + @Before + public void setUp() { + conf = new TransportConf(new SystemPropertyConfigProvider()); + RpcHandler rpcHandler = new NoOpRpcHandler(); + context = new TransportContext(conf, rpcHandler); + server1 = context.createServer(); + server2 = context.createServer(); + } + + @After + public void tearDown() { + JavaUtils.closeQuietly(server1); + JavaUtils.closeQuietly(server2); + } + + @Test + public void createAndReuseBlockClients() throws IOException { + TransportClientFactory factory = context.createClientFactory(); + TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + TransportClient c3 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); + assertTrue(c1.isActive()); + assertTrue(c3.isActive()); + assertTrue(c1 == c2); + assertTrue(c1 != c3); + factory.close(); + } + + @Test + public void neverReturnInactiveClients() throws IOException, InterruptedException { + TransportClientFactory factory = context.createClientFactory(); + TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + c1.close(); + + long start = System.currentTimeMillis(); + while (c1.isActive() && (System.currentTimeMillis() - start) < 3000) { + Thread.sleep(10); + } + assertFalse(c1.isActive()); + + TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + assertFalse(c1 == c2); + assertTrue(c2.isActive()); + factory.close(); + } + + @Test + public void closeBlockClientsWithFactory() throws IOException { + TransportClientFactory factory = context.createClientFactory(); + TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort()); + TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort()); + assertTrue(c1.isActive()); + assertTrue(c2.isActive()); + factory.close(); + assertFalse(c1.isActive()); + assertFalse(c2.isActive()); + } +} diff --git a/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java new file mode 100644 index 0000000000000..17a03ebe88a93 --- /dev/null +++ b/network/common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network; + +import io.netty.channel.local.LocalChannel; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportResponseHandler; +import org.apache.spark.network.protocol.ChunkFetchFailure; +import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.RpcFailure; +import org.apache.spark.network.protocol.RpcResponse; +import org.apache.spark.network.protocol.StreamChunkId; + +public class TransportResponseHandlerSuite { + @Test + public void handleSuccessfulFetch() { + StreamChunkId streamChunkId = new StreamChunkId(1, 0); + + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + handler.addFetchRequest(streamChunkId, callback); + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new ChunkFetchSuccess(streamChunkId, new TestManagedBuffer(123))); + verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); + assertEquals(0, handler.numOutstandingRequests()); + } + + @Test + public void handleFailedFetch() { + StreamChunkId streamChunkId = new StreamChunkId(1, 0); + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + handler.addFetchRequest(streamChunkId, callback); + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new ChunkFetchFailure(streamChunkId, "some error msg")); + verify(callback, times(1)).onFailure(eq(0), (Throwable) any()); + assertEquals(0, handler.numOutstandingRequests()); + } + + @Test + public void clearAllOutstandingRequests() { + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + ChunkReceivedCallback callback = mock(ChunkReceivedCallback.class); + handler.addFetchRequest(new StreamChunkId(1, 0), callback); + handler.addFetchRequest(new StreamChunkId(1, 1), callback); + handler.addFetchRequest(new StreamChunkId(1, 2), callback); + assertEquals(3, handler.numOutstandingRequests()); + + handler.handle(new ChunkFetchSuccess(new StreamChunkId(1, 0), new TestManagedBuffer(12))); + handler.exceptionCaught(new Exception("duh duh duhhhh")); + + // should fail both b2 and b3 + verify(callback, times(1)).onSuccess(eq(0), (ManagedBuffer) any()); + verify(callback, times(1)).onFailure(eq(1), (Throwable) any()); + verify(callback, times(1)).onFailure(eq(2), (Throwable) any()); + assertEquals(0, handler.numOutstandingRequests()); + } + + @Test + public void handleSuccessfulRPC() { + 
TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + RpcResponseCallback callback = mock(RpcResponseCallback.class); + handler.addRpcRequest(12345, callback); + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new RpcResponse(54321, new byte[7])); // should be ignored + assertEquals(1, handler.numOutstandingRequests()); + + byte[] arr = new byte[10]; + handler.handle(new RpcResponse(12345, arr)); + verify(callback, times(1)).onSuccess(eq(arr)); + assertEquals(0, handler.numOutstandingRequests()); + } + + @Test + public void handleFailedRPC() { + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + RpcResponseCallback callback = mock(RpcResponseCallback.class); + handler.addRpcRequest(12345, callback); + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new RpcFailure(54321, "uh-oh!")); // should be ignored + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new RpcFailure(12345, "oh no")); + verify(callback, times(1)).onFailure((Throwable) any()); + assertEquals(0, handler.numOutstandingRequests()); + } +} diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml new file mode 100644 index 0000000000000..12468567c3aed --- /dev/null +++ b/network/shuffle/pom.xml @@ -0,0 +1,97 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.3.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-network-shuffle_2.10 + jar + Spark Project Shuffle Streaming Service + http://spark.apache.org/ + + network-shuffle + + + + + + org.apache.spark + spark-network-common_${scala.binary.version} + ${project.version} + + + + + org.slf4j + slf4j-api + provided + + + com.google.guava + guava + provided + + + + + org.apache.spark + spark-network-common_${scala.binary.version} + ${project.version} + test-jar + test + + + junit + junit + test + + + com.novocode + junit-interface + test + + + log4j + log4j + test + + + org.mockito + mockito-all + test + + + org.scalatest + scalatest_${scala.binary.version} + test + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java new file mode 100644 index 0000000000000..7bc91e375371f --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.sasl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.util.TransportConf; + +/** + * Bootstraps a {@link TransportClient} by performing SASL authentication on the connection. The + * server should be setup with a {@link SaslRpcHandler} with matching keys for the given appId. + */ +public class SaslClientBootstrap implements TransportClientBootstrap { + private final Logger logger = LoggerFactory.getLogger(SaslClientBootstrap.class); + + private final TransportConf conf; + private final String appId; + private final SecretKeyHolder secretKeyHolder; + + public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) { + this.conf = conf; + this.appId = appId; + this.secretKeyHolder = secretKeyHolder; + } + + /** + * Performs SASL authentication by sending a token, and then proceeding with the SASL + * challenge-response tokens until we either successfully authenticate or throw an exception + * due to mismatch. + */ + @Override + public void doBootstrap(TransportClient client) { + SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder); + try { + byte[] payload = saslClient.firstToken(); + + while (!saslClient.isComplete()) { + SaslMessage msg = new SaslMessage(appId, payload); + ByteBuf buf = Unpooled.buffer(msg.encodedLength()); + msg.encode(buf); + + byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeout()); + payload = saslClient.response(response); + } + } finally { + try { + // Once authentication is complete, the server will trust all remaining communication. + saslClient.dispose(); + } catch (RuntimeException e) { + logger.error("Error while disposing SASL client", e); + } + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java new file mode 100644 index 0000000000000..cad76ab7aa54e --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslMessage.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.Encoders; + +/** + * Encodes a Sasl-related message which is attempting to authenticate using some credentials tagged + * with the given appId. 
This appId allows a single SaslRpcHandler to multiplex different + * applications which may be using different sets of credentials. + */ +class SaslMessage implements Encodable { + + /** Serialization tag used to catch incorrect payloads. */ + private static final byte TAG_BYTE = (byte) 0xEA; + + public final String appId; + public final byte[] payload; + + public SaslMessage(String appId, byte[] payload) { + this.appId = appId; + this.payload = payload; + } + + @Override + public int encodedLength() { + return 1 + Encoders.Strings.encodedLength(appId) + Encoders.ByteArrays.encodedLength(payload); + } + + @Override + public void encode(ByteBuf buf) { + buf.writeByte(TAG_BYTE); + Encoders.Strings.encode(buf, appId); + Encoders.ByteArrays.encode(buf, payload); + } + + public static SaslMessage decode(ByteBuf buf) { + if (buf.readByte() != TAG_BYTE) { + throw new IllegalStateException("Expected SaslMessage, received something else" + + " (maybe your client does not have SASL enabled?)"); + } + + String appId = Encoders.Strings.decode(buf); + byte[] payload = Encoders.ByteArrays.decode(buf); + return new SaslMessage(appId, payload); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java new file mode 100644 index 0000000000000..3777a18e33f78 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.util.concurrent.ConcurrentMap; + +import com.google.common.base.Charsets; +import com.google.common.collect.Maps; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; + +/** + * RPC Handler which performs SASL authentication before delegating to a child RPC handler. + * The delegate will only receive messages if the given connection has been successfully + * authenticated. A connection may be authenticated at most once. + * + * Note that the authentication process consists of multiple challenge-response pairs, each of + * which are individual RPCs. + */ +public class SaslRpcHandler extends RpcHandler { + private final Logger logger = LoggerFactory.getLogger(SaslRpcHandler.class); + + /** RpcHandler we will delegate to for authenticated connections. 
*/ + private final RpcHandler delegate; + + /** Class which provides secret keys which are shared by server and client on a per-app basis. */ + private final SecretKeyHolder secretKeyHolder; + + /** Maps each channel to its SASL authentication state. */ + private final ConcurrentMap channelAuthenticationMap; + + public SaslRpcHandler(RpcHandler delegate, SecretKeyHolder secretKeyHolder) { + this.delegate = delegate; + this.secretKeyHolder = secretKeyHolder; + this.channelAuthenticationMap = Maps.newConcurrentMap(); + } + + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + SparkSaslServer saslServer = channelAuthenticationMap.get(client); + if (saslServer != null && saslServer.isComplete()) { + // Authentication complete, delegate to base handler. + delegate.receive(client, message, callback); + return; + } + + SaslMessage saslMessage = SaslMessage.decode(Unpooled.wrappedBuffer(message)); + + if (saslServer == null) { + // First message in the handshake, setup the necessary state. + saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder); + channelAuthenticationMap.put(client, saslServer); + } + + byte[] response = saslServer.response(saslMessage.payload); + if (saslServer.isComplete()) { + logger.debug("SASL authentication successful for channel {}", client); + } + callback.onSuccess(response); + } + + @Override + public StreamManager getStreamManager() { + return delegate.getStreamManager(); + } + + @Override + public void connectionTerminated(TransportClient client) { + SparkSaslServer saslServer = channelAuthenticationMap.remove(client); + if (saslServer != null) { + saslServer.dispose(); + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java new file mode 100644 index 0000000000000..81d5766794688 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SecretKeyHolder.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +/** + * Interface for getting a secret key associated with some application. + */ +public interface SecretKeyHolder { + /** + * Gets an appropriate SASL User for the given appId. + * @throws IllegalArgumentException if the given appId is not associated with a SASL user. + */ + String getSaslUser(String appId); + + /** + * Gets an appropriate SASL secret key for the given appId. + * @throws IllegalArgumentException if the given appId is not associated with a SASL secret key. 
+ */ + String getSecretKey(String appId); +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java new file mode 100644 index 0000000000000..351c7930a900f --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.lang.Override; +import java.nio.ByteBuffer; +import java.util.concurrent.ConcurrentHashMap; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.util.JavaUtils; + +/** + * A class that manages shuffle secret used by the external shuffle service. + */ +public class ShuffleSecretManager implements SecretKeyHolder { + private final Logger logger = LoggerFactory.getLogger(ShuffleSecretManager.class); + private final ConcurrentHashMap shuffleSecretMap; + + // Spark user used for authenticating SASL connections + // Note that this must match the value in org.apache.spark.SecurityManager + private static final String SPARK_SASL_USER = "sparkSaslUser"; + + public ShuffleSecretManager() { + shuffleSecretMap = new ConcurrentHashMap(); + } + + /** + * Register an application with its secret. + * Executors need to first authenticate themselves with the same secret before + * fetching shuffle files written by other executors in this application. + */ + public void registerApp(String appId, String shuffleSecret) { + if (!shuffleSecretMap.contains(appId)) { + shuffleSecretMap.put(appId, shuffleSecret); + logger.info("Registered shuffle secret for application {}", appId); + } else { + logger.debug("Application {} already registered", appId); + } + } + + /** + * Register an application with its secret specified as a byte buffer. + */ + public void registerApp(String appId, ByteBuffer shuffleSecret) { + registerApp(appId, JavaUtils.bytesToString(shuffleSecret)); + } + + /** + * Unregister an application along with its secret. + * This is called when the application terminates. + */ + public void unregisterApp(String appId) { + if (shuffleSecretMap.contains(appId)) { + shuffleSecretMap.remove(appId); + logger.info("Unregistered shuffle secret for application {}", appId); + } else { + logger.warn("Attempted to unregister application {} when it is not registered", appId); + } + } + + /** + * Return the Spark user for authenticating SASL connections. + */ + @Override + public String getSaslUser(String appId) { + return SPARK_SASL_USER; + } + + /** + * Return the secret key registered with the given application. 
+ * This key is used to authenticate the executors before they can fetch shuffle files + * written by this application from the external shuffle service. If the specified + * application is not registered, return null. + */ + @Override + public String getSecretKey(String appId) { + return shuffleSecretMap.get(appId); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java new file mode 100644 index 0000000000000..9abad1f30a259 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.RealmCallback; +import javax.security.sasl.RealmChoiceCallback; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslClient; +import javax.security.sasl.SaslException; +import java.io.IOException; + +import com.google.common.base.Throwables; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.apache.spark.network.sasl.SparkSaslServer.*; + +/** + * A SASL Client for Spark which simply keeps track of the state of a single SASL session, from the + * initial state to the "authenticated" state. This client initializes the protocol via a + * firstToken, which is then followed by a set of challenges and responses. + */ +public class SparkSaslClient { + private final Logger logger = LoggerFactory.getLogger(SparkSaslClient.class); + + private final String secretKeyId; + private final SecretKeyHolder secretKeyHolder; + private SaslClient saslClient; + + public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder) { + this.secretKeyId = secretKeyId; + this.secretKeyHolder = secretKeyHolder; + try { + this.saslClient = Sasl.createSaslClient(new String[] { DIGEST }, null, null, DEFAULT_REALM, + SASL_PROPS, new ClientCallbackHandler()); + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** Used to initiate SASL handshake with server. */ + public synchronized byte[] firstToken() { + if (saslClient != null && saslClient.hasInitialResponse()) { + try { + return saslClient.evaluateChallenge(new byte[0]); + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } else { + return new byte[0]; + } + } + + /** Determines whether the authentication exchange has completed. 
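+   * In a typical client-side exchange (a sketch; {@code sendToServer} stands in for whatever RPC
+   * transport carries the SASL messages), the caller sends {@link #firstToken()}, then feeds each
+   * server challenge to {@link #response(byte[])} until this method returns true:
+   * <pre>{@code
+   *   SparkSaslClient client = new SparkSaslClient(appId, secretKeyHolder);
+   *   byte[] token = client.firstToken();
+   *   while (!client.isComplete()) {
+   *     byte[] challenge = sendToServer(token);   // hypothetical transport call
+   *     token = client.response(challenge);
+   *   }
+   * }</pre>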
*/ + public synchronized boolean isComplete() { + return saslClient != null && saslClient.isComplete(); + } + + /** + * Respond to server's SASL token. + * @param token contains server's SASL token + * @return client's response SASL token + */ + public synchronized byte[] response(byte[] token) { + try { + return saslClient != null ? saslClient.evaluateChallenge(token) : new byte[0]; + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslClient might be using. + */ + public synchronized void dispose() { + if (saslClient != null) { + try { + saslClient.dispose(); + } catch (SaslException e) { + // ignore + } finally { + saslClient = null; + } + } + } + + /** + * Implementation of javax.security.auth.callback.CallbackHandler + * that works with share secrets. + */ + private class ClientCallbackHandler implements CallbackHandler { + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) { + logger.trace("SASL client callback: setting username"); + NameCallback nc = (NameCallback) callback; + nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); + } else if (callback instanceof PasswordCallback) { + logger.trace("SASL client callback: setting password"); + PasswordCallback pc = (PasswordCallback) callback; + pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + } else if (callback instanceof RealmCallback) { + logger.trace("SASL client callback: setting realm"); + RealmCallback rc = (RealmCallback) callback; + rc.setText(rc.getDefaultText()); + } else if (callback instanceof RealmChoiceCallback) { + // ignore (?) + } else { + throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback"); + } + } + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java new file mode 100644 index 0000000000000..e87b17ead1e1a --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.sasl; + +import javax.security.auth.callback.Callback; +import javax.security.auth.callback.CallbackHandler; +import javax.security.auth.callback.NameCallback; +import javax.security.auth.callback.PasswordCallback; +import javax.security.auth.callback.UnsupportedCallbackException; +import javax.security.sasl.AuthorizeCallback; +import javax.security.sasl.RealmCallback; +import javax.security.sasl.Sasl; +import javax.security.sasl.SaslException; +import javax.security.sasl.SaslServer; +import java.io.IOException; +import java.util.Map; + +import com.google.common.base.Charsets; +import com.google.common.base.Preconditions; +import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableMap; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.base64.Base64; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A SASL Server for Spark which simply keeps track of the state of a single SASL session, from the + * initial state to the "authenticated" state. (It is not a server in the sense of accepting + * connections on some socket.) + */ +public class SparkSaslServer { + private final Logger logger = LoggerFactory.getLogger(SparkSaslServer.class); + + /** + * This is passed as the server name when creating the sasl client/server. + * This could be changed to be configurable in the future. + */ + static final String DEFAULT_REALM = "default"; + + /** + * The authentication mechanism used here is DIGEST-MD5. This could be changed to be + * configurable in the future. + */ + static final String DIGEST = "DIGEST-MD5"; + + /** + * The quality of protection is just "auth". This means that we are doing + * authentication only, we are not supporting integrity or privacy protection of the + * communication channel after authentication. This could be changed to be configurable + * in the future. + */ + static final Map SASL_PROPS = ImmutableMap.builder() + .put(Sasl.QOP, "auth") + .put(Sasl.SERVER_AUTH, "true") + .build(); + + /** Identifier for a certain secret key within the secretKeyHolder. */ + private final String secretKeyId; + private final SecretKeyHolder secretKeyHolder; + private SaslServer saslServer; + + public SparkSaslServer(String secretKeyId, SecretKeyHolder secretKeyHolder) { + this.secretKeyId = secretKeyId; + this.secretKeyHolder = secretKeyHolder; + try { + this.saslServer = Sasl.createSaslServer(DIGEST, null, DEFAULT_REALM, SASL_PROPS, + new DigestCallbackHandler()); + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** + * Determines whether the authentication exchange has completed successfully. + */ + public synchronized boolean isComplete() { + return saslServer != null && saslServer.isComplete(); + } + + /** + * Used to respond to server SASL tokens. + * @param token Server's SASL token + * @return response to send back to the server. + */ + public synchronized byte[] response(byte[] token) { + try { + return saslServer != null ? saslServer.evaluateResponse(token) : new byte[0]; + } catch (SaslException e) { + throw Throwables.propagate(e); + } + } + + /** + * Disposes of any system resources or security-sensitive information the + * SaslServer might be using. + */ + public synchronized void dispose() { + if (saslServer != null) { + try { + saslServer.dispose(); + } catch (SaslException e) { + // ignore + } finally { + saslServer = null; + } + } + } + + /** + * Implementation of javax.security.auth.callback.CallbackHandler for SASL DIGEST-MD5 mechanism. 
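+   * The handler answers the mechanism's callbacks with the user name and secret obtained from the
+   * {@link SecretKeyHolder} for this server's secretKeyId (both Base64-encoded via
+   * encodeIdentifier/encodePassword below), and authorizes a connection only when its
+   * authentication and authorization ids match.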
+ */ + private class DigestCallbackHandler implements CallbackHandler { + @Override + public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + for (Callback callback : callbacks) { + if (callback instanceof NameCallback) { + logger.trace("SASL server callback: setting username"); + NameCallback nc = (NameCallback) callback; + nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); + } else if (callback instanceof PasswordCallback) { + logger.trace("SASL server callback: setting password"); + PasswordCallback pc = (PasswordCallback) callback; + pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + } else if (callback instanceof RealmCallback) { + logger.trace("SASL server callback: setting realm"); + RealmCallback rc = (RealmCallback) callback; + rc.setText(rc.getDefaultText()); + } else if (callback instanceof AuthorizeCallback) { + AuthorizeCallback ac = (AuthorizeCallback) callback; + String authId = ac.getAuthenticationID(); + String authzId = ac.getAuthorizationID(); + ac.setAuthorized(authId.equals(authzId)); + if (ac.isAuthorized()) { + ac.setAuthorizedID(authzId); + } + logger.debug("SASL Authorization complete, authorized set to {}", ac.isAuthorized()); + } else { + throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback"); + } + } + } + } + + /* Encode a byte[] identifier as a Base64-encoded string. */ + public static String encodeIdentifier(String identifier) { + Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled"); + return Base64.encode(Unpooled.wrappedBuffer(identifier.getBytes(Charsets.UTF_8))) + .toString(Charsets.UTF_8); + } + + /** Encode a password as a base64-encoded char[] array. */ + public static char[] encodePassword(String password) { + Preconditions.checkNotNull(password, "Password cannot be null if SASL is enabled"); + return Base64.encode(Unpooled.wrappedBuffer(password.getBytes(Charsets.UTF_8))) + .toString(Charsets.UTF_8).toCharArray(); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java new file mode 100644 index 0000000000000..138fd5389c20a --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.util.EventListener; + +import org.apache.spark.network.buffer.ManagedBuffer; + +public interface BlockFetchingListener extends EventListener { + /** + * Called once per successfully fetched block. After this call returns, data will be released + * automatically. 
If the data will be passed to another thread, the receiver should retain() + * and release() the buffer on their own, or copy the data to a new buffer. + */ + void onBlockFetchSuccess(String blockId, ManagedBuffer data); + + /** + * Called at least once per block upon failures. + */ + void onBlockFetchFailure(String blockId, Throwable exception); +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java new file mode 100644 index 0000000000000..46ca9708621b9 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.util.List; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.Lists; +import org.apache.spark.network.util.TransportConf; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.RegisterExecutor; +import org.apache.spark.network.shuffle.protocol.StreamHandle; + +/** + * RPC Handler for a server which can serve shuffle blocks from outside of an Executor process. + * + * Handles registering executors and opening shuffle blocks from them. Shuffle blocks are registered + * with the "one-for-one" strategy, meaning each Transport-layer Chunk is equivalent to one Spark- + * level shuffle block. + */ +public class ExternalShuffleBlockHandler extends RpcHandler { + private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class); + + private final ExternalShuffleBlockManager blockManager; + private final OneForOneStreamManager streamManager; + + public ExternalShuffleBlockHandler(TransportConf conf) { + this(new OneForOneStreamManager(), new ExternalShuffleBlockManager(conf)); + } + + /** Enables mocking out the StreamManager and BlockManager. 
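+   * For example, a test might construct the handler with Mockito mocks (a sketch; assumes
+   * Mockito's static {@code mock(...)} import on the test classpath):
+   * <pre>{@code
+   *   ExternalShuffleBlockHandler handler = new ExternalShuffleBlockHandler(
+   *     mock(OneForOneStreamManager.class), mock(ExternalShuffleBlockManager.class));
+   * }</pre>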
*/ + @VisibleForTesting + ExternalShuffleBlockHandler( + OneForOneStreamManager streamManager, + ExternalShuffleBlockManager blockManager) { + this.streamManager = streamManager; + this.blockManager = blockManager; + } + + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteArray(message); + + if (msgObj instanceof OpenBlocks) { + OpenBlocks msg = (OpenBlocks) msgObj; + List blocks = Lists.newArrayList(); + + for (String blockId : msg.blockIds) { + blocks.add(blockManager.getBlockData(msg.appId, msg.execId, blockId)); + } + long streamId = streamManager.registerStream(blocks.iterator()); + logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length); + callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteArray()); + + } else if (msgObj instanceof RegisterExecutor) { + RegisterExecutor msg = (RegisterExecutor) msgObj; + blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo); + callback.onSuccess(new byte[0]); + + } else { + throw new UnsupportedOperationException("Unexpected message: " + msgObj); + } + } + + @Override + public StreamManager getStreamManager() { + return streamManager; + } + + /** + * Removes an application (once it has been terminated), and optionally will clean up any + * local directories associated with the executors of that application in a separate thread. + */ + public void applicationRemoved(String appId, boolean cleanupLocalDirs) { + blockManager.applicationRemoved(appId, cleanupLocalDirs); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java new file mode 100644 index 0000000000000..dfe0ba0595090 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManager.java @@ -0,0 +1,252 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle; + +import java.io.DataInputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.util.Iterator; +import java.util.Map; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Objects; +import com.google.common.collect.Maps; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * Manages converting shuffle BlockIds into physical segments of local files, from a process outside + * of Executors. Each Executor must register its own configuration about where it stores its files + * (local dirs) and how (shuffle manager). The logic for retrieval of individual files is replicated + * from Spark's FileShuffleBlockManager and IndexShuffleBlockManager. + * + * Executors with shuffle file consolidation are not currently supported, as the index is stored in + * the Executor's memory, unlike the IndexShuffleBlockManager. + */ +public class ExternalShuffleBlockManager { + private final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockManager.class); + + // Map containing all registered executors' metadata. + private final ConcurrentMap executors; + + // Single-threaded Java executor used to perform expensive recursive directory deletion. + private final Executor directoryCleaner; + + private final TransportConf conf; + + public ExternalShuffleBlockManager(TransportConf conf) { + // TODO: Give this thread a name. + this(conf, Executors.newSingleThreadExecutor()); + } + + // Allows tests to have more control over when directories are cleaned up. + @VisibleForTesting + ExternalShuffleBlockManager(TransportConf conf, Executor directoryCleaner) { + this.conf = conf; + this.executors = Maps.newConcurrentMap(); + this.directoryCleaner = directoryCleaner; + } + + /** Registers a new Executor with all the configuration we need to find its shuffle files. */ + public void registerExecutor( + String appId, + String execId, + ExecutorShuffleInfo executorInfo) { + AppExecId fullId = new AppExecId(appId, execId); + logger.info("Registered executor {} with {}", fullId, executorInfo); + executors.put(fullId, executorInfo); + } + + /** + * Obtains a FileSegmentManagedBuffer from a shuffle block id. We expect the blockId has the + * format "shuffle_ShuffleId_MapId_ReduceId" (from ShuffleBlockId), and additionally make + * assumptions about how the hash and sort based shuffles store their data. 
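+   * For example, the block id {@code shuffle_2_17_3} resolves to shuffleId 2, mapId 17 and
+   * reduceId 3.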
+ */ + public ManagedBuffer getBlockData(String appId, String execId, String blockId) { + String[] blockIdParts = blockId.split("_"); + if (blockIdParts.length < 4) { + throw new IllegalArgumentException("Unexpected block id format: " + blockId); + } else if (!blockIdParts[0].equals("shuffle")) { + throw new IllegalArgumentException("Expected shuffle block id, got: " + blockId); + } + int shuffleId = Integer.parseInt(blockIdParts[1]); + int mapId = Integer.parseInt(blockIdParts[2]); + int reduceId = Integer.parseInt(blockIdParts[3]); + + ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); + if (executor == null) { + throw new RuntimeException( + String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); + } + + if ("org.apache.spark.shuffle.hash.HashShuffleManager".equals(executor.shuffleManager)) { + return getHashBasedShuffleBlockData(executor, blockId); + } else if ("org.apache.spark.shuffle.sort.SortShuffleManager".equals(executor.shuffleManager)) { + return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); + } else { + throw new UnsupportedOperationException( + "Unsupported shuffle manager: " + executor.shuffleManager); + } + } + + /** + * Removes our metadata of all executors registered for the given application, and optionally + * also deletes the local directories associated with the executors of that application in a + * separate thread. + * + * It is not valid to call registerExecutor() for an executor with this appId after invoking + * this method. + */ + public void applicationRemoved(String appId, boolean cleanupLocalDirs) { + logger.info("Application {} removed, cleanupLocalDirs = {}", appId, cleanupLocalDirs); + Iterator> it = executors.entrySet().iterator(); + while (it.hasNext()) { + Map.Entry entry = it.next(); + AppExecId fullId = entry.getKey(); + final ExecutorShuffleInfo executor = entry.getValue(); + + // Only touch executors associated with the appId that was removed. + if (appId.equals(fullId.appId)) { + it.remove(); + + if (cleanupLocalDirs) { + logger.info("Cleaning up executor {}'s {} local dirs", fullId, executor.localDirs.length); + + // Execute the actual deletion in a different thread, as it may take some time. + directoryCleaner.execute(new Runnable() { + @Override + public void run() { + deleteExecutorDirs(executor.localDirs); + } + }); + } + } + } + } + + /** + * Synchronously deletes each directory one at a time. + * Should be executed in its own thread, as this may take a long time. + */ + private void deleteExecutorDirs(String[] dirs) { + for (String localDir : dirs) { + try { + JavaUtils.deleteRecursively(new File(localDir)); + logger.debug("Successfully cleaned up directory: " + localDir); + } catch (Exception e) { + logger.error("Failed to delete directory: " + localDir, e); + } + } + } + + /** + * Hash-based shuffle data is simply stored as one file per block. + * This logic is from FileShuffleBlockManager. + */ + // TODO: Support consolidated hash shuffle files + private ManagedBuffer getHashBasedShuffleBlockData(ExecutorShuffleInfo executor, String blockId) { + File shuffleFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, blockId); + return new FileSegmentManagedBuffer(conf, shuffleFile, 0, shuffleFile.length()); + } + + /** + * Sort-based shuffle data uses an index called "shuffle_ShuffleId_MapId_0.index" into a data file + * called "shuffle_ShuffleId_MapId_0.data". 
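+   * The index file is a sequence of 8-byte longs: entry reduceId holds the start offset of that
+   * reducer's partition within the data file and entry (reduceId + 1) holds its end offset, so the
+   * segment served below is [offset, nextOffset).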
This logic is from IndexShuffleBlockManager, + * and the block id format is from ShuffleDataBlockId and ShuffleIndexBlockId. + */ + private ManagedBuffer getSortBasedShuffleBlockData( + ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId) { + File indexFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, + "shuffle_" + shuffleId + "_" + mapId + "_0.index"); + + DataInputStream in = null; + try { + in = new DataInputStream(new FileInputStream(indexFile)); + in.skipBytes(reduceId * 8); + long offset = in.readLong(); + long nextOffset = in.readLong(); + return new FileSegmentManagedBuffer( + conf, + getFile(executor.localDirs, executor.subDirsPerLocalDir, + "shuffle_" + shuffleId + "_" + mapId + "_0.data"), + offset, + nextOffset - offset); + } catch (IOException e) { + throw new RuntimeException("Failed to open file: " + indexFile, e); + } finally { + if (in != null) { + JavaUtils.closeQuietly(in); + } + } + } + + /** + * Hashes a filename into the corresponding local directory, in a manner consistent with + * Spark's DiskBlockManager.getFile(). + */ + @VisibleForTesting + static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) { + int hash = JavaUtils.nonNegativeHash(filename); + String localDir = localDirs[hash % localDirs.length]; + int subDirId = (hash / localDirs.length) % subDirsPerLocalDir; + return new File(new File(localDir, String.format("%02x", subDirId)), filename); + } + + /** Simply encodes an executor's full ID, which is appId + execId. */ + private static class AppExecId { + final String appId; + final String execId; + + private AppExecId(String appId, String execId) { + this.appId = appId; + this.execId = execId; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + AppExecId appExecId = (AppExecId) o; + return Objects.equal(appId, appExecId.appId) && Objects.equal(execId, appExecId.execId); + } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .toString(); + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java new file mode 100644 index 0000000000000..6e8018b723dc6 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; +import java.util.List; + +import com.google.common.collect.Lists; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.sasl.SaslClientBootstrap; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.server.NoOpRpcHandler; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.shuffle.protocol.RegisterExecutor; +import org.apache.spark.network.util.TransportConf; + +/** + * Client for reading shuffle blocks which points to an external (outside of executor) server. + * This is instead of reading shuffle blocks directly from other executors (via + * BlockTransferService), which has the downside of losing the shuffle data if we lose the + * executors. + */ +public class ExternalShuffleClient extends ShuffleClient { + private final Logger logger = LoggerFactory.getLogger(ExternalShuffleClient.class); + + private final TransportConf conf; + private final boolean saslEnabled; + private final SecretKeyHolder secretKeyHolder; + + private TransportClientFactory clientFactory; + private String appId; + + /** + * Creates an external shuffle client, with SASL optionally enabled. If SASL is not enabled, + * then secretKeyHolder may be null. + */ + public ExternalShuffleClient( + TransportConf conf, + SecretKeyHolder secretKeyHolder, + boolean saslEnabled) { + this.conf = conf; + this.secretKeyHolder = secretKeyHolder; + this.saslEnabled = saslEnabled; + } + + @Override + public void init(String appId) { + this.appId = appId; + TransportContext context = new TransportContext(conf, new NoOpRpcHandler()); + List bootstraps = Lists.newArrayList(); + if (saslEnabled) { + bootstraps.add(new SaslClientBootstrap(conf, appId, secretKeyHolder)); + } + clientFactory = context.createClientFactory(bootstraps); + } + + @Override + public void fetchBlocks( + final String host, + final int port, + final String execId, + String[] blockIds, + BlockFetchingListener listener) { + assert appId != null : "Called before init()"; + logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); + try { + RetryingBlockFetcher.BlockFetchStarter blockFetchStarter = + new RetryingBlockFetcher.BlockFetchStarter() { + @Override + public void createAndStart(String[] blockIds, BlockFetchingListener listener) + throws IOException { + TransportClient client = clientFactory.createClient(host, port); + new OneForOneBlockFetcher(client, appId, execId, blockIds, listener).start(); + } + }; + + int maxRetries = conf.maxIORetries(); + if (maxRetries > 0) { + // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's + // a bug in this code. We should remove the if statement once we're sure of the stability. + new RetryingBlockFetcher(conf, blockFetchStarter, blockIds, listener).start(); + } else { + blockFetchStarter.createAndStart(blockIds, listener); + } + } catch (Exception e) { + logger.error("Exception while beginning fetchBlocks", e); + for (String blockId : blockIds) { + listener.onBlockFetchFailure(blockId, e); + } + } + } + + /** + * Registers this executor with an external shuffle server. 
This registration is required to + * inform the shuffle server about where and how we store our shuffle files. + * + * @param host Host of shuffle server. + * @param port Port of shuffle server. + * @param execId This Executor's id. + * @param executorInfo Contains all info necessary for the service to find our shuffle files. + */ + public void registerWithShuffleServer( + String host, + int port, + String execId, + ExecutorShuffleInfo executorInfo) throws IOException { + assert appId != null : "Called before init()"; + TransportClient client = clientFactory.createClient(host, port); + byte[] registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteArray(); + client.sendRpcSync(registerMessage, 5000 /* timeoutMs */); + } + + @Override + public void close() { + clientFactory.close(); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java new file mode 100644 index 0000000000000..8ed2e0b39ad23 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.util.Arrays; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.util.JavaUtils; + +/** + * Simple wrapper on top of a TransportClient which interprets each chunk as a whole block, and + * invokes the BlockFetchingListener appropriately. This class is agnostic to the actual RPC + * handler, as long as there is a single "open blocks" message which returns a ShuffleStreamHandle, + * and Java serialization is used. + * + * Note that this typically corresponds to a + * {@link org.apache.spark.network.server.OneForOneStreamManager} on the server side. 
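+ * Concretely, the exchange is: send an OpenBlocks message, receive a {@link StreamHandle}, then
+ * issue fetchChunk(streamId, i) for every i in [0, numChunks), delivering chunk i to the listener
+ * under blockIds[i].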
+ */ +public class OneForOneBlockFetcher { + private final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class); + + private final TransportClient client; + private final OpenBlocks openMessage; + private final String[] blockIds; + private final BlockFetchingListener listener; + private final ChunkReceivedCallback chunkCallback; + + private StreamHandle streamHandle = null; + + public OneForOneBlockFetcher( + TransportClient client, + String appId, + String execId, + String[] blockIds, + BlockFetchingListener listener) { + this.client = client; + this.openMessage = new OpenBlocks(appId, execId, blockIds); + this.blockIds = blockIds; + this.listener = listener; + this.chunkCallback = new ChunkCallback(); + } + + /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */ + private class ChunkCallback implements ChunkReceivedCallback { + @Override + public void onSuccess(int chunkIndex, ManagedBuffer buffer) { + // On receipt of a chunk, pass it upwards as a block. + listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer); + } + + @Override + public void onFailure(int chunkIndex, Throwable e) { + // On receipt of a failure, fail every block from chunkIndex onwards. + String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length); + failRemainingBlocks(remainingBlockIds, e); + } + } + + /** + * Begins the fetching process, calling the listener with every block fetched. + * The given message will be serialized with the Java serializer, and the RPC must return a + * {@link StreamHandle}. We will send all fetch requests immediately, without throttling. + */ + public void start() { + if (blockIds.length == 0) { + throw new IllegalArgumentException("Zero-sized blockIds array"); + } + + client.sendRpc(openMessage.toByteArray(), new RpcResponseCallback() { + @Override + public void onSuccess(byte[] response) { + try { + streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response); + logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle); + + // Immediately request all chunks -- we expect that the total size of the request is + // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. + for (int i = 0; i < streamHandle.numChunks; i++) { + client.fetchChunk(streamHandle.streamId, i, chunkCallback); + } + } catch (Exception e) { + logger.error("Failed while starting block fetches after success", e); + failRemainingBlocks(blockIds, e); + } + } + + @Override + public void onFailure(Throwable e) { + logger.error("Failed while starting block fetches", e); + failRemainingBlocks(blockIds, e); + } + }); + } + + /** Invokes the "onBlockFetchFailure" callback for every listed block id. */ + private void failRemainingBlocks(String[] failedBlockIds, Throwable e) { + for (String blockId : failedBlockIds) { + try { + listener.onBlockFetchFailure(blockId, e); + } catch (Exception e2) { + logger.error("Error in block fetch failure callback", e2); + } + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java new file mode 100644 index 0000000000000..f8a1a266863bb --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import com.google.common.collect.Sets; +import com.google.common.util.concurrent.Uninterruptibles; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * Wraps another BlockFetcher with the ability to automatically retry fetches which fail due to + * IOExceptions, which we hope are due to transient network conditions. + * + * This fetcher provides stronger guarantees regarding the parent BlockFetchingListener. In + * particular, the listener will be invoked exactly once per blockId, with a success or failure. + */ +public class RetryingBlockFetcher { + + /** + * Used to initiate the first fetch for all blocks, and subsequently for retrying the fetch on any + * remaining blocks. + */ + public static interface BlockFetchStarter { + /** + * Creates a new BlockFetcher to fetch the given block ids which may do some synchronous + * bootstrapping followed by fully asynchronous block fetching. + * The BlockFetcher must eventually invoke the Listener on every input blockId, or else this + * method must throw an exception. + * + * This method should always attempt to get a new TransportClient from the + * {@link org.apache.spark.network.client.TransportClientFactory} in order to fix connection + * issues. + */ + void createAndStart(String[] blockIds, BlockFetchingListener listener) throws IOException; + } + + /** Shared executor service used for waiting and retrying. */ + private static final ExecutorService executorService = Executors.newCachedThreadPool( + NettyUtils.createThreadFactory("Block Fetch Retry")); + + private final Logger logger = LoggerFactory.getLogger(RetryingBlockFetcher.class); + + /** Used to initiate new Block Fetches on our remaining blocks. */ + private final BlockFetchStarter fetchStarter; + + /** Parent listener which we delegate all successful or permanently failed block fetches to. */ + private final BlockFetchingListener listener; + + /** Max number of times we are allowed to retry. */ + private final int maxRetries; + + /** Milliseconds to wait before each retry. */ + private final int retryWaitTime; + + // NOTE: + // All of our non-final fields are synchronized under 'this' and should only be accessed/mutated + // while inside a synchronized block. + /** Number of times we've attempted to retry so far. */ + private int retryCount = 0; + + /** + * Set of all block ids which have not been fetched successfully or with a non-IO Exception. 
+ * A retry involves requesting every outstanding block. Note that since this is a LinkedHashSet, + * input ordering is preserved, so we always request blocks in the same order the user provided. + */ + private final LinkedHashSet outstandingBlocksIds; + + /** + * The BlockFetchingListener that is active with our current BlockFetcher. + * When we start a retry, we immediately replace this with a new Listener, which causes all any + * old Listeners to ignore all further responses. + */ + private RetryingBlockFetchListener currentListener; + + public RetryingBlockFetcher( + TransportConf conf, + BlockFetchStarter fetchStarter, + String[] blockIds, + BlockFetchingListener listener) { + this.fetchStarter = fetchStarter; + this.listener = listener; + this.maxRetries = conf.maxIORetries(); + this.retryWaitTime = conf.ioRetryWaitTime(); + this.outstandingBlocksIds = Sets.newLinkedHashSet(); + Collections.addAll(outstandingBlocksIds, blockIds); + this.currentListener = new RetryingBlockFetchListener(); + } + + /** + * Initiates the fetch of all blocks provided in the constructor, with possible retries in the + * event of transient IOExceptions. + */ + public void start() { + fetchAllOutstanding(); + } + + /** + * Fires off a request to fetch all blocks that have not been fetched successfully or permanently + * failed (i.e., by a non-IOException). + */ + private void fetchAllOutstanding() { + // Start by retrieving our shared state within a synchronized block. + String[] blockIdsToFetch; + int numRetries; + RetryingBlockFetchListener myListener; + synchronized (this) { + blockIdsToFetch = outstandingBlocksIds.toArray(new String[outstandingBlocksIds.size()]); + numRetries = retryCount; + myListener = currentListener; + } + + // Now initiate the fetch on all outstanding blocks, possibly initiating a retry if that fails. + try { + fetchStarter.createAndStart(blockIdsToFetch, myListener); + } catch (Exception e) { + logger.error(String.format("Exception while beginning fetch of %s outstanding blocks %s", + blockIdsToFetch.length, numRetries > 0 ? "(after " + numRetries + " retries)" : ""), e); + + if (shouldRetry(e)) { + initiateRetry(); + } else { + for (String bid : blockIdsToFetch) { + listener.onBlockFetchFailure(bid, e); + } + } + } + } + + /** + * Lightweight method which initiates a retry in a different thread. The retry will involve + * calling fetchAllOutstanding() after a configured wait time. + */ + private synchronized void initiateRetry() { + retryCount += 1; + currentListener = new RetryingBlockFetchListener(); + + logger.info("Retrying fetch ({}/{}) for {} outstanding blocks after {} ms", + retryCount, maxRetries, outstandingBlocksIds.size(), retryWaitTime); + + executorService.submit(new Runnable() { + @Override + public void run() { + Uninterruptibles.sleepUninterruptibly(retryWaitTime, TimeUnit.MILLISECONDS); + fetchAllOutstanding(); + } + }); + } + + /** + * Returns true if we should retry due a block fetch failure. We will retry if and only if + * the exception was an IOException and we haven't retried 'maxRetries' times already. + */ + private synchronized boolean shouldRetry(Throwable e) { + boolean isIOException = e instanceof IOException + || (e.getCause() != null && e.getCause() instanceof IOException); + boolean hasRemainingRetries = retryCount < maxRetries; + return isIOException && hasRemainingRetries; + } + + /** + * Our RetryListener intercepts block fetch responses and forwards them to our parent listener. 
+ * Note that in the event of a retry, we will immediately replace the 'currentListener' field, + * indicating that any responses from non-current Listeners should be ignored. + */ + private class RetryingBlockFetchListener implements BlockFetchingListener { + @Override + public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { + // We will only forward this success message to our parent listener if this block request is + // outstanding and we are still the active listener. + boolean shouldForwardSuccess = false; + synchronized (RetryingBlockFetcher.this) { + if (this == currentListener && outstandingBlocksIds.contains(blockId)) { + outstandingBlocksIds.remove(blockId); + shouldForwardSuccess = true; + } + } + + // Now actually invoke the parent listener, outside of the synchronized block. + if (shouldForwardSuccess) { + listener.onBlockFetchSuccess(blockId, data); + } + } + + @Override + public void onBlockFetchFailure(String blockId, Throwable exception) { + // We will only forward this failure to our parent listener if this block request is + // outstanding, we are still the active listener, AND we cannot retry the fetch. + boolean shouldForwardFailure = false; + synchronized (RetryingBlockFetcher.this) { + if (this == currentListener && outstandingBlocksIds.contains(blockId)) { + if (shouldRetry(exception)) { + initiateRetry(); + } else { + logger.error(String.format("Failed to fetch block %s, and will not retry (%s retries)", + blockId, retryCount), exception); + outstandingBlocksIds.remove(blockId); + shouldForwardFailure = true; + } + } + } + + // Now actually invoke the parent listener, outside of the synchronized block. + if (shouldForwardFailure) { + listener.onBlockFetchFailure(blockId, exception); + } + } + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java new file mode 100644 index 0000000000000..f72ab40690d0d --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.Closeable; + +/** Provides an interface for reading shuffle files, either from an Executor or external service. */ +public abstract class ShuffleClient implements Closeable { + + /** + * Initializes the ShuffleClient, specifying this Executor's appId. + * Must be called before any other method on the ShuffleClient. 
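+   *
+   * A typical lifecycle, sketched:
+   * <pre>{@code
+   *   ShuffleClient client = ...;   // e.g. an ExternalShuffleClient
+   *   client.init(appId);
+   *   client.fetchBlocks(host, port, execId, blockIds, listener);
+   *   // once all fetches are done and the executor shuts down:
+   *   client.close();
+   * }</pre>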
+ */ + public void init(String appId) { } + + /** + * Fetch a sequence of blocks from a remote node asynchronously, + * + * Note that this API takes a sequence so the implementation can batch requests, and does not + * return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as + * the data of a block is fetched, rather than waiting for all blocks to be fetched. + */ + public abstract void fetchBlocks( + String host, + int port, + String execId, + String[] blockIds, + BlockFetchingListener listener); +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java new file mode 100644 index 0000000000000..b4b13b8a6ef5d --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +import org.apache.spark.network.protocol.Encodable; + +/** + * Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or + * by Spark's NettyBlockTransferService. + * + * At a high level: + * - OpenBlock is handled by both services, but only services shuffle files for the external + * shuffle service. It returns a StreamHandle. + * - UploadBlock is only handled by the NettyBlockTransferService. + * - RegisterExecutor is only handled by the external shuffle service. + */ +public abstract class BlockTransferMessage implements Encodable { + protected abstract Type type(); + + /** Preceding every serialized message is its type, which allows us to deserialize it. */ + public static enum Type { + OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3); + + private final byte id; + + private Type(int id) { + assert id < 128 : "Cannot have more than 128 message types"; + this.id = (byte) id; + } + + public byte id() { return id; } + } + + // NB: Java does not support static methods in interfaces, so we must put this in a static class. + public static class Decoder { + /** Deserializes the 'type' byte followed by the message itself. 
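+     * The wire format is the one produced by toByteArray(): a single type id byte (0 for
+     * OPEN_BLOCKS, 1 for UPLOAD_BLOCK, 2 for REGISTER_EXECUTOR, 3 for STREAM_HANDLE) followed by
+     * the message's own encoding, so {@code Decoder.fromByteArray(msg.toByteArray())} round-trips
+     * to an equal message.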
*/ + public static BlockTransferMessage fromByteArray(byte[] msg) { + ByteBuf buf = Unpooled.wrappedBuffer(msg); + byte type = buf.readByte(); + switch (type) { + case 0: return OpenBlocks.decode(buf); + case 1: return UploadBlock.decode(buf); + case 2: return RegisterExecutor.decode(buf); + case 3: return StreamHandle.decode(buf); + default: throw new IllegalArgumentException("Unknown message type: " + type); + } + } + } + + /** Serializes the 'type' byte followed by the message itself. */ + public byte[] toByteArray() { + ByteBuf buf = Unpooled.buffer(encodedLength()); + buf.writeByte(type().id); + encode(buf); + assert buf.writableBytes() == 0 : "Writable bytes remain: " + buf.writableBytes(); + return buf.array(); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java new file mode 100644 index 0000000000000..cadc8e8369c6a --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encodable; +import org.apache.spark.network.protocol.Encoders; + +/** Contains all configuration necessary for locating the shuffle files of an executor. */ +public class ExecutorShuffleInfo implements Encodable { + /** The base set of local directories that the executor stores its shuffle files in. */ + public final String[] localDirs; + /** Number of subdirectories created within each localDir. */ + public final int subDirsPerLocalDir; + /** Shuffle manager (SortShuffleManager or HashShuffleManager) that the executor is using. 
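+   * The external shuffle service matches on the fully qualified class name, i.e.
+   * {@code org.apache.spark.shuffle.sort.SortShuffleManager} or
+   * {@code org.apache.spark.shuffle.hash.HashShuffleManager}.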
*/ + public final String shuffleManager; + + public ExecutorShuffleInfo(String[] localDirs, int subDirsPerLocalDir, String shuffleManager) { + this.localDirs = localDirs; + this.subDirsPerLocalDir = subDirsPerLocalDir; + this.shuffleManager = shuffleManager; + } + + @Override + public int hashCode() { + return Objects.hashCode(subDirsPerLocalDir, shuffleManager) * 41 + Arrays.hashCode(localDirs); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("localDirs", Arrays.toString(localDirs)) + .add("subDirsPerLocalDir", subDirsPerLocalDir) + .add("shuffleManager", shuffleManager) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof ExecutorShuffleInfo) { + ExecutorShuffleInfo o = (ExecutorShuffleInfo) other; + return Arrays.equals(localDirs, o.localDirs) + && Objects.equal(subDirsPerLocalDir, o.subDirsPerLocalDir) + && Objects.equal(shuffleManager, o.shuffleManager); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.StringArrays.encodedLength(localDirs) + + 4 // int + + Encoders.Strings.encodedLength(shuffleManager); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.StringArrays.encode(buf, localDirs); + buf.writeInt(subDirsPerLocalDir); + Encoders.Strings.encode(buf, shuffleManager); + } + + public static ExecutorShuffleInfo decode(ByteBuf buf) { + String[] localDirs = Encoders.StringArrays.decode(buf); + int subDirsPerLocalDir = buf.readInt(); + String shuffleManager = Encoders.Strings.decode(buf); + return new ExecutorShuffleInfo(localDirs, subDirsPerLocalDir, shuffleManager); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java new file mode 100644 index 0000000000000..62fce9b0d16cd --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +/** Request to read a set of blocks. Returns {@link StreamHandle}. 
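+ * For example (a sketch; the ids are illustrative, and {@code client}/{@code callback} are an
+ * existing TransportClient and RpcResponseCallback):
+ * <pre>{@code
+ *   OpenBlocks open = new OpenBlocks("app-1", "exec-7",
+ *     new String[] { "shuffle_0_1_0", "shuffle_0_2_0" });
+ *   client.sendRpc(open.toByteArray(), callback);   // the server replies with a StreamHandle
+ * }</pre>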
*/ +public class OpenBlocks extends BlockTransferMessage { + public final String appId; + public final String execId; + public final String[] blockIds; + + public OpenBlocks(String appId, String execId, String[] blockIds) { + this.appId = appId; + this.execId = execId; + this.blockIds = blockIds; + } + + @Override + protected Type type() { return Type.OPEN_BLOCKS; } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("blockIds", Arrays.toString(blockIds)) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof OpenBlocks) { + OpenBlocks o = (OpenBlocks) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Arrays.equals(blockIds, o.blockIds); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + Encoders.StringArrays.encodedLength(blockIds); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + Encoders.StringArrays.encode(buf, blockIds); + } + + public static OpenBlocks decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + String[] blockIds = Encoders.StringArrays.decode(buf); + return new OpenBlocks(appId, execId, blockIds); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java new file mode 100644 index 0000000000000..7eb4385044077 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutor.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +/** + * Initial registration message between an executor and its local shuffle server. + * Returns nothing (empty byte array).
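+ * The included ExecutorShuffleInfo is what the shuffle server later uses to locate this executor's shuffle files.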
+ */ +public class RegisterExecutor extends BlockTransferMessage { + public final String appId; + public final String execId; + public final ExecutorShuffleInfo executorInfo; + + public RegisterExecutor( + String appId, + String execId, + ExecutorShuffleInfo executorInfo) { + this.appId = appId; + this.execId = execId; + this.executorInfo = executorInfo; + } + + @Override + protected Type type() { return Type.REGISTER_EXECUTOR; } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId, executorInfo); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("executorInfo", executorInfo) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof RegisterExecutor) { + RegisterExecutor o = (RegisterExecutor) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Objects.equal(executorInfo, o.executorInfo); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + executorInfo.encodedLength(); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + executorInfo.encode(buf); + } + + public static RegisterExecutor decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + ExecutorShuffleInfo executorShuffleInfo = ExecutorShuffleInfo.decode(buf); + return new RegisterExecutor(appId, execId, executorShuffleInfo); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java new file mode 100644 index 0000000000000..bc9daa6158ba3 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/StreamHandle.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +/** + * Identifier for a fixed number of chunks to read from a stream created by an "open blocks" + * message. This is used by {@link org.apache.spark.network.shuffle.OneForOneBlockFetcher}. 
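+ * A client fetches chunks 0 through numChunks - 1 of the given streamId to read the blocks.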
+ */ +public class StreamHandle extends BlockTransferMessage { + public final long streamId; + public final int numChunks; + + public StreamHandle(long streamId, int numChunks) { + this.streamId = streamId; + this.numChunks = numChunks; + } + + @Override + protected Type type() { return Type.STREAM_HANDLE; } + + @Override + public int hashCode() { + return Objects.hashCode(streamId, numChunks); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("streamId", streamId) + .add("numChunks", numChunks) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof StreamHandle) { + StreamHandle o = (StreamHandle) other; + return Objects.equal(streamId, o.streamId) + && Objects.equal(numChunks, o.numChunks); + } + return false; + } + + @Override + public int encodedLength() { + return 8 + 4; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeLong(streamId); + buf.writeInt(numChunks); + } + + public static StreamHandle decode(ByteBuf buf) { + long streamId = buf.readLong(); + int numChunks = buf.readInt(); + return new StreamHandle(streamId, numChunks); + } +} diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java new file mode 100644 index 0000000000000..0b23e112bd512 --- /dev/null +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadBlock.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + + +/** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */ +public class UploadBlock extends BlockTransferMessage { + public final String appId; + public final String execId; + public final String blockId; + // TODO: StorageLevel is serialized separately in here because StorageLevel is not available in + // this package. We should avoid this hack. + public final byte[] metadata; + public final byte[] blockData; + + /** + * @param metadata Meta-information about block, typically StorageLevel. + * @param blockData The actual block's bytes. 
+ */ + public UploadBlock( + String appId, + String execId, + String blockId, + byte[] metadata, + byte[] blockData) { + this.appId = appId; + this.execId = execId; + this.blockId = blockId; + this.metadata = metadata; + this.blockData = blockData; + } + + @Override + protected Type type() { return Type.UPLOAD_BLOCK; } + + @Override + public int hashCode() { + int objectsHashCode = Objects.hashCode(appId, execId, blockId); + return (objectsHashCode * 41 + Arrays.hashCode(metadata)) * 41 + Arrays.hashCode(blockData); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("blockId", blockId) + .add("metadata size", metadata.length) + .add("block size", blockData.length) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof UploadBlock) { + UploadBlock o = (UploadBlock) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Objects.equal(blockId, o.blockId) + && Arrays.equals(metadata, o.metadata) + && Arrays.equals(blockData, o.blockData); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + Encoders.Strings.encodedLength(blockId) + + Encoders.ByteArrays.encodedLength(metadata) + + Encoders.ByteArrays.encodedLength(blockData); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + Encoders.Strings.encode(buf, blockId); + Encoders.ByteArrays.encode(buf, metadata); + Encoders.ByteArrays.encode(buf, blockData); + } + + public static UploadBlock decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + String blockId = Encoders.Strings.decode(buf); + byte[] metadata = Encoders.ByteArrays.decode(buf); + byte[] blockData = Encoders.ByteArrays.decode(buf); + return new UploadBlock(appId, execId, blockId, metadata, blockData); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java new file mode 100644 index 0000000000000..d25283e46ef96 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.sasl; + +import java.io.IOException; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class SaslIntegrationSuite { + static ExternalShuffleBlockHandler handler; + static TransportServer server; + static TransportConf conf; + static TransportContext context; + + TransportClientFactory clientFactory; + + /** Provides a secret key holder which always returns the given secret key. */ + static class TestSecretKeyHolder implements SecretKeyHolder { + + private final String secretKey; + + TestSecretKeyHolder(String secretKey) { + this.secretKey = secretKey; + } + + @Override + public String getSaslUser(String appId) { + return "user"; + } + @Override + public String getSecretKey(String appId) { + return secretKey; + } + } + + + @BeforeClass + public static void beforeAll() throws IOException { + SecretKeyHolder secretKeyHolder = new TestSecretKeyHolder("good-key"); + SaslRpcHandler handler = new SaslRpcHandler(new TestRpcHandler(), secretKeyHolder); + conf = new TransportConf(new SystemPropertyConfigProvider()); + context = new TransportContext(conf, handler); + server = context.createServer(); + } + + + @AfterClass + public static void afterAll() { + server.close(); + } + + @After + public void afterEach() { + if (clientFactory != null) { + clientFactory.close(); + clientFactory = null; + } + } + + @Test + public void testGoodClient() throws IOException { + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key")))); + + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + String msg = "Hello, World!"; + byte[] resp = client.sendRpcSync(msg.getBytes(), 1000); + assertEquals(msg, new String(resp)); // our rpc handler should just return the given msg + } + + @Test + public void testBadClient() { + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("bad-key")))); + + try { + // Bootstrap should fail on startup. 
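+ // With the wrong secret the SASL handshake cannot complete, so no client is ever returned.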
+ clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); + } + } + + @Test + public void testNoSaslClient() throws IOException { + clientFactory = context.createClientFactory( + Lists.newArrayList()); + + TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + try { + client.sendRpcSync(new byte[13], 1000); + fail("Should have failed"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage")); + } + + try { + // Guessing the right tag byte doesn't magically get you in... + client.sendRpcSync(new byte[] { (byte) 0xEA }, 1000); + fail("Should have failed"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException")); + } + } + + @Test + public void testNoSaslServer() { + RpcHandler handler = new TestRpcHandler(); + TransportContext context = new TransportContext(conf, handler); + clientFactory = context.createClientFactory( + Lists.newArrayList( + new SaslClientBootstrap(conf, "app-id", new TestSecretKeyHolder("key")))); + TransportServer server = context.createServer(); + try { + clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation")); + } finally { + server.close(); + } + } + + /** RPC handler which simply responds with the message it received. */ + public static class TestRpcHandler extends RpcHandler { + @Override + public void receive(TransportClient client, byte[] message, RpcResponseCallback callback) { + callback.onSuccess(message); + } + + @Override + public StreamManager getStreamManager() { + return new OneForOneStreamManager(); + } + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java new file mode 100644 index 0000000000000..67a07f38eb5a0 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.sasl; + +import java.util.Map; + +import com.google.common.collect.ImmutableMap; +import org.junit.Test; + +import static org.junit.Assert.*; + +/** + * Jointly tests SparkSaslClient and SparkSaslServer, as both are black boxes. 
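+ * Each test drives a full challenge/response exchange between the client and the server.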
+ */ +public class SparkSaslSuite { + + /** Provides a secret key holder which returns secret key == appId */ + private SecretKeyHolder secretKeyHolder = new SecretKeyHolder() { + @Override + public String getSaslUser(String appId) { + return "user"; + } + + @Override + public String getSecretKey(String appId) { + return appId; + } + }; + + @Test + public void testMatching() { + SparkSaslClient client = new SparkSaslClient("shared-secret", secretKeyHolder); + SparkSaslServer server = new SparkSaslServer("shared-secret", secretKeyHolder); + + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + + byte[] clientMessage = client.firstToken(); + + while (!client.isComplete()) { + clientMessage = client.response(server.response(clientMessage)); + } + assertTrue(server.isComplete()); + + // Disposal should invalidate + server.dispose(); + assertFalse(server.isComplete()); + client.dispose(); + assertFalse(client.isComplete()); + } + + + @Test + public void testNonMatching() { + SparkSaslClient client = new SparkSaslClient("my-secret", secretKeyHolder); + SparkSaslServer server = new SparkSaslServer("your-secret", secretKeyHolder); + + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + + byte[] clientMessage = client.firstToken(); + + try { + while (!client.isComplete()) { + clientMessage = client.response(server.response(clientMessage)); + } + fail("Should not have completed"); + } catch (Exception e) { + assertTrue(e.getMessage().contains("Mismatched response")); + assertFalse(client.isComplete()); + assertFalse(server.isComplete()); + } + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java new file mode 100644 index 0000000000000..d65de9ca550a3 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.shuffle.protocol.*; + +/** Verifies that all BlockTransferMessages can be serialized correctly. 
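+ * Each message is round-tripped through toByteArray() and Decoder.fromByteArray() and compared for equality.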
*/ +public class BlockTransferMessagesSuite { + @Test + public void serializeOpenShuffleBlocks() { + checkSerializeDeserialize(new OpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" })); + checkSerializeDeserialize(new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo( + new String[] { "/local1", "/local2" }, 32, "MyShuffleManager"))); + checkSerializeDeserialize(new UploadBlock("app-1", "exec-2", "block-3", new byte[] { 1, 2 }, + new byte[] { 4, 5, 6, 7} )); + checkSerializeDeserialize(new StreamHandle(12345, 16)); + } + + private void checkSerializeDeserialize(BlockTransferMessage msg) { + BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteArray(msg.toByteArray()); + assertEquals(msg, msg2); + assertEquals(msg.hashCode(), msg2.hashCode()); + assertEquals(msg.toString(), msg2.toString()); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java new file mode 100644 index 0000000000000..3f9fe1681cf27 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle; + +import java.nio.ByteBuffer; +import java.util.Iterator; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +import static org.junit.Assert.*; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.RegisterExecutor; +import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.network.shuffle.protocol.UploadBlock; + +public class ExternalShuffleBlockHandlerSuite { + TransportClient client = mock(TransportClient.class); + + OneForOneStreamManager streamManager; + ExternalShuffleBlockManager blockManager; + RpcHandler handler; + + @Before + public void beforeEach() { + streamManager = mock(OneForOneStreamManager.class); + blockManager = mock(ExternalShuffleBlockManager.class); + handler = new ExternalShuffleBlockHandler(streamManager, blockManager); + } + + @Test + public void testRegisterExecutor() { + RpcResponseCallback callback = mock(RpcResponseCallback.class); + + ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort"); + byte[] registerMessage = new RegisterExecutor("app0", "exec1", config).toByteArray(); + handler.receive(client, registerMessage, callback); + verify(blockManager, times(1)).registerExecutor("app0", "exec1", config); + + verify(callback, times(1)).onSuccess((byte[]) any()); + verify(callback, never()).onFailure((Throwable) any()); + } + + @SuppressWarnings("unchecked") + @Test + public void testOpenShuffleBlocks() { + RpcResponseCallback callback = mock(RpcResponseCallback.class); + + ManagedBuffer block0Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[3])); + ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); + when(blockManager.getBlockData("app0", "exec1", "b0")).thenReturn(block0Marker); + when(blockManager.getBlockData("app0", "exec1", "b1")).thenReturn(block1Marker); + byte[] openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }).toByteArray(); + handler.receive(client, openBlocks, callback); + verify(blockManager, times(1)).getBlockData("app0", "exec1", "b0"); + verify(blockManager, times(1)).getBlockData("app0", "exec1", "b1"); + + ArgumentCaptor response = ArgumentCaptor.forClass(byte[].class); + verify(callback, times(1)).onSuccess(response.capture()); + verify(callback, never()).onFailure((Throwable) any()); + + StreamHandle handle = + (StreamHandle) BlockTransferMessage.Decoder.fromByteArray(response.getValue()); + assertEquals(2, handle.numChunks); + + ArgumentCaptor stream = ArgumentCaptor.forClass(Iterator.class); + verify(streamManager, times(1)).registerStream(stream.capture()); + Iterator buffers = (Iterator) stream.getValue(); + assertEquals(block0Marker, buffers.next()); + assertEquals(block1Marker, buffers.next()); + assertFalse(buffers.hasNext()); + } + + @Test + public void testBadMessages() { + RpcResponseCallback callback = 
mock(RpcResponseCallback.class); + + byte[] unserializableMsg = new byte[] { 0x12, 0x34, 0x56 }; + try { + handler.receive(client, unserializableMsg, callback); + fail("Should have thrown"); + } catch (Exception e) { + // pass + } + + byte[] unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1], new byte[2]).toByteArray(); + try { + handler.receive(client, unexpectedMsg, callback); + fail("Should have thrown"); + } catch (UnsupportedOperationException e) { + // pass + } + + verify(callback, never()).onSuccess((byte[]) any()); + verify(callback, never()).onFailure((Throwable) any()); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java new file mode 100644 index 0000000000000..dad6428a836fc --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockManagerSuite.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; + +import com.google.common.io.CharStreams; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +public class ExternalShuffleBlockManagerSuite { + static String sortBlock0 = "Hello!"; + static String sortBlock1 = "World!"; + + static String hashBlock0 = "Elementary"; + static String hashBlock1 = "Tabular"; + + static TestShuffleDataContext dataContext; + + static TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + + @BeforeClass + public static void beforeAll() throws IOException { + dataContext = new TestShuffleDataContext(2, 5); + + dataContext.create(); + // Write some sort and hash data. 
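+ // Shuffle 0 uses the sort-based layout; shuffle 1 uses the hash-based layout.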
+ dataContext.insertSortShuffleData(0, 0, + new byte[][] { sortBlock0.getBytes(), sortBlock1.getBytes() } ); + dataContext.insertHashShuffleData(1, 0, + new byte[][] { hashBlock0.getBytes(), hashBlock1.getBytes() } ); + } + + @AfterClass + public static void afterAll() { + dataContext.cleanup(); + } + + @Test + public void testBadRequests() { + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf); + // Unregistered executor + try { + manager.getBlockData("app0", "exec1", "shuffle_1_1_0"); + fail("Should have failed"); + } catch (RuntimeException e) { + assertTrue("Bad error message: " + e, e.getMessage().contains("not registered")); + } + + // Invalid shuffle manager + manager.registerExecutor("app0", "exec2", dataContext.createExecutorInfo("foobar")); + try { + manager.getBlockData("app0", "exec2", "shuffle_1_1_0"); + fail("Should have failed"); + } catch (UnsupportedOperationException e) { + // pass + } + + // Nonexistent shuffle block + manager.registerExecutor("app0", "exec3", + dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); + try { + manager.getBlockData("app0", "exec3", "shuffle_1_1_0"); + fail("Should have failed"); + } catch (Exception e) { + // pass + } + } + + @Test + public void testSortShuffleBlocks() throws IOException { + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf); + manager.registerExecutor("app0", "exec0", + dataContext.createExecutorInfo("org.apache.spark.shuffle.sort.SortShuffleManager")); + + InputStream block0Stream = + manager.getBlockData("app0", "exec0", "shuffle_0_0_0").createInputStream(); + String block0 = CharStreams.toString(new InputStreamReader(block0Stream)); + block0Stream.close(); + assertEquals(sortBlock0, block0); + + InputStream block1Stream = + manager.getBlockData("app0", "exec0", "shuffle_0_0_1").createInputStream(); + String block1 = CharStreams.toString(new InputStreamReader(block1Stream)); + block1Stream.close(); + assertEquals(sortBlock1, block1); + } + + @Test + public void testHashShuffleBlocks() throws IOException { + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf); + manager.registerExecutor("app0", "exec0", + dataContext.createExecutorInfo("org.apache.spark.shuffle.hash.HashShuffleManager")); + + InputStream block0Stream = + manager.getBlockData("app0", "exec0", "shuffle_1_0_0").createInputStream(); + String block0 = CharStreams.toString(new InputStreamReader(block0Stream)); + block0Stream.close(); + assertEquals(hashBlock0, block0); + + InputStream block1Stream = + manager.getBlockData("app0", "exec0", "shuffle_1_0_1").createInputStream(); + String block1 = CharStreams.toString(new InputStreamReader(block1Stream)); + block1Stream.close(); + assertEquals(hashBlock1, block1); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java new file mode 100644 index 0000000000000..254e3a7a32b98 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.File; +import java.io.IOException; +import java.util.Random; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; + +import com.google.common.util.concurrent.MoreExecutors; +import org.junit.Test; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class ExternalShuffleCleanupSuite { + + // Same-thread Executor used to ensure cleanup happens synchronously in test thread. + Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + + @Test + public void noCleanupAndCleanup() throws IOException { + TestShuffleDataContext dataContext = createSomeData(); + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf, sameThreadExecutor); + manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", false /* cleanup */); + + assertStillThere(dataContext); + + manager.registerExecutor("app", "exec1", dataContext.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", true /* cleanup */); + + assertCleanedUp(dataContext); + } + + @Test + public void cleanupUsesExecutor() throws IOException { + TestShuffleDataContext dataContext = createSomeData(); + + final AtomicBoolean cleanupCalled = new AtomicBoolean(false); + + // Executor which does nothing to ensure we're actually using it. 
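+ // If cleanup ran inline instead, the local dirs would already be gone when we assert below.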
+ Executor noThreadExecutor = new Executor() { + @Override public void execute(Runnable runnable) { cleanupCalled.set(true); } + }; + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf, noThreadExecutor); + + manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", true); + + assertTrue(cleanupCalled.get()); + assertStillThere(dataContext); + + dataContext.cleanup(); + assertCleanedUp(dataContext); + } + + @Test + public void cleanupMultipleExecutors() throws IOException { + TestShuffleDataContext dataContext0 = createSomeData(); + TestShuffleDataContext dataContext1 = createSomeData(); + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf, sameThreadExecutor); + + manager.registerExecutor("app", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); + manager.registerExecutor("app", "exec1", dataContext1.createExecutorInfo("shuffleMgr")); + manager.applicationRemoved("app", true); + + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + } + + @Test + public void cleanupOnlyRemovedApp() throws IOException { + TestShuffleDataContext dataContext0 = createSomeData(); + TestShuffleDataContext dataContext1 = createSomeData(); + + ExternalShuffleBlockManager manager = new ExternalShuffleBlockManager(conf, sameThreadExecutor); + + manager.registerExecutor("app-0", "exec0", dataContext0.createExecutorInfo("shuffleMgr")); + manager.registerExecutor("app-1", "exec0", dataContext1.createExecutorInfo("shuffleMgr")); + + manager.applicationRemoved("app-nonexistent", true); + assertStillThere(dataContext0); + assertStillThere(dataContext1); + + manager.applicationRemoved("app-0", true); + assertCleanedUp(dataContext0); + assertStillThere(dataContext1); + + manager.applicationRemoved("app-1", true); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + + // Make sure it's not an error to cleanup multiple times + manager.applicationRemoved("app-1", true); + assertCleanedUp(dataContext0); + assertCleanedUp(dataContext1); + } + + private void assertStillThere(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + assertTrue(localDir + " was cleaned up prematurely", new File(localDir).exists()); + } + } + + private void assertCleanedUp(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + assertFalse(localDir + " wasn't cleaned up", new File(localDir).exists()); + } + } + + private TestShuffleDataContext createSomeData() throws IOException { + Random rand = new Random(123); + TestShuffleDataContext dataContext = new TestShuffleDataContext(10, 5); + + dataContext.create(); + dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), + new byte[][] { "ABC".getBytes(), "DEF".getBytes() } ); + dataContext.insertHashShuffleData(rand.nextInt(1000), rand.nextInt(1000) + 1000, + new byte[][] { "GHI".getBytes(), "JKLMNOPQRSTUVWXYZ".getBytes() } ); + return dataContext; + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java new file mode 100644 index 0000000000000..02c10bcb7b261 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -0,0 +1,301 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class ExternalShuffleIntegrationSuite { + + static String APP_ID = "app-id"; + static String SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager"; + static String HASH_MANAGER = "org.apache.spark.shuffle.hash.HashShuffleManager"; + + // Executor 0 is sort-based + static TestShuffleDataContext dataContext0; + // Executor 1 is hash-based + static TestShuffleDataContext dataContext1; + + static ExternalShuffleBlockHandler handler; + static TransportServer server; + static TransportConf conf; + + static byte[][] exec0Blocks = new byte[][] { + new byte[123], + new byte[12345], + new byte[1234567], + }; + + static byte[][] exec1Blocks = new byte[][] { + new byte[321], + new byte[54321], + }; + + @BeforeClass + public static void beforeAll() throws IOException { + Random rand = new Random(); + + for (byte[] block : exec0Blocks) { + rand.nextBytes(block); + } + for (byte[] block: exec1Blocks) { + rand.nextBytes(block); + } + + dataContext0 = new TestShuffleDataContext(2, 5); + dataContext0.create(); + dataContext0.insertSortShuffleData(0, 0, exec0Blocks); + + dataContext1 = new TestShuffleDataContext(6, 2); + dataContext1.create(); + dataContext1.insertHashShuffleData(1, 0, exec1Blocks); + + conf = new TransportConf(new SystemPropertyConfigProvider()); + handler = new ExternalShuffleBlockHandler(conf); + TransportContext transportContext = new TransportContext(conf, handler); + server = transportContext.createServer(); + } + + @AfterClass + public static void afterAll() { + dataContext0.cleanup(); + dataContext1.cleanup(); + server.close(); + } + + @After + public void afterEach() { + handler.applicationRemoved(APP_ID, false /* cleanupLocalDirs */); + } + + class FetchResult { + public Set successBlocks; + public Set failedBlocks; + public List buffers; + + public void 
releaseBuffers() { + for (ManagedBuffer buffer : buffers) { + buffer.release(); + } + } + } + + // Fetch a set of blocks from a pre-registered executor. + private FetchResult fetchBlocks(String execId, String[] blockIds) throws Exception { + return fetchBlocks(execId, blockIds, server.getPort()); + } + + // Fetch a set of blocks from a pre-registered executor. Connects to the server on the given port, + // to allow connecting to invalid servers. + private FetchResult fetchBlocks(String execId, String[] blockIds, int port) throws Exception { + final FetchResult res = new FetchResult(); + res.successBlocks = Collections.synchronizedSet(new HashSet()); + res.failedBlocks = Collections.synchronizedSet(new HashSet()); + res.buffers = Collections.synchronizedList(new LinkedList()); + + final Semaphore requestsRemaining = new Semaphore(0); + + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); + client.init(APP_ID); + client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds, + new BlockFetchingListener() { + @Override + public void onBlockFetchSuccess(String blockId, ManagedBuffer data) { + synchronized (this) { + if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { + data.retain(); + res.successBlocks.add(blockId); + res.buffers.add(data); + requestsRemaining.release(); + } + } + } + + @Override + public void onBlockFetchFailure(String blockId, Throwable exception) { + synchronized (this) { + if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) { + res.failedBlocks.add(blockId); + requestsRemaining.release(); + } + } + } + }); + + if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) { + fail("Timeout getting response from the server"); + } + client.close(); + return res; + } + + @Test + public void testFetchOneSort() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult exec0Fetch = fetchBlocks("exec-0", new String[] { "shuffle_0_0_0" }); + assertEquals(Sets.newHashSet("shuffle_0_0_0"), exec0Fetch.successBlocks); + assertTrue(exec0Fetch.failedBlocks.isEmpty()); + assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks[0])); + exec0Fetch.releaseBuffers(); + } + + @Test + public void testFetchThreeSort() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult exec0Fetch = fetchBlocks("exec-0", + new String[] { "shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2" }); + assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2"), + exec0Fetch.successBlocks); + assertTrue(exec0Fetch.failedBlocks.isEmpty()); + assertBufferListsEqual(exec0Fetch.buffers, Lists.newArrayList(exec0Blocks)); + exec0Fetch.releaseBuffers(); + } + + @Test + public void testFetchHash() throws Exception { + registerExecutor("exec-1", dataContext1.createExecutorInfo(HASH_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-1", + new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }); + assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.successBlocks); + assertTrue(execFetch.failedBlocks.isEmpty()); + assertBufferListsEqual(execFetch.buffers, Lists.newArrayList(exec1Blocks)); + execFetch.releaseBuffers(); + } + + @Test + public void testFetchWrongShuffle() throws Exception { + registerExecutor("exec-1", dataContext1.createExecutorInfo(SORT_MANAGER /* wrong manager */)); + FetchResult execFetch = fetchBlocks("exec-1", + new String[] { "shuffle_1_0_0", 
"shuffle_1_0_1" }); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); + } + + @Test + public void testFetchInvalidShuffle() throws Exception { + registerExecutor("exec-1", dataContext1.createExecutorInfo("unknown sort manager")); + FetchResult execFetch = fetchBlocks("exec-1", + new String[] { "shuffle_1_0_0" }); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_1_0_0"), execFetch.failedBlocks); + } + + @Test + public void testFetchWrongBlockId() throws Exception { + registerExecutor("exec-1", dataContext1.createExecutorInfo(SORT_MANAGER /* wrong manager */)); + FetchResult execFetch = fetchBlocks("exec-1", + new String[] { "rdd_1_0_0" }); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("rdd_1_0_0"), execFetch.failedBlocks); + } + + @Test + public void testFetchNonexistent() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-0", + new String[] { "shuffle_2_0_0" }); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_2_0_0"), execFetch.failedBlocks); + } + + @Test + public void testFetchWrongExecutor() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-0", + new String[] { "shuffle_0_0_0" /* right */, "shuffle_1_0_0" /* wrong */ }); + // Both still fail, as we start by checking for all block. + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks); + } + + @Test + public void testFetchUnregisteredExecutor() throws Exception { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-2", + new String[] { "shuffle_0_0_0", "shuffle_1_0_0" }); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks); + } + + @Test + public void testFetchNoServer() throws Exception { + System.setProperty("spark.shuffle.io.maxRetries", "0"); + try { + registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-0", + new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, 1 /* port */); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); + } finally { + System.clearProperty("spark.shuffle.io.maxRetries"); + } + } + + private void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo) + throws IOException { + ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false); + client.init(APP_ID); + client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), + executorId, executorInfo); + } + + private void assertBufferListsEqual(List list0, List list1) + throws Exception { + assertEquals(list0.size(), list1.size()); + for (int i = 0; i < list0.size(); i ++) { + assertBuffersEqual(list0.get(i), new NioManagedBuffer(ByteBuffer.wrap(list1.get(i)))); + } + } + + private void assertBuffersEqual(ManagedBuffer buffer0, ManagedBuffer buffer1) throws Exception { + ByteBuffer nio0 = buffer0.nioByteBuffer(); + ByteBuffer nio1 = buffer1.nioByteBuffer(); + + int len = nio0.remaining(); + assertEquals(nio0.remaining(), nio1.remaining()); + for (int i = 0; i 
< len; i ++) { + assertEquals(nio0.get(), nio1.get()); + } + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java new file mode 100644 index 0000000000000..759a12910c94d --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.IOException; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.*; + +import org.apache.spark.network.TestUtils; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class ExternalShuffleSecuritySuite { + + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + TransportServer server; + + @Before + public void beforeEach() { + RpcHandler handler = new SaslRpcHandler(new ExternalShuffleBlockHandler(conf), + new TestSecretKeyHolder("my-app-id", "secret")); + TransportContext context = new TransportContext(conf, handler); + this.server = context.createServer(); + } + + @After + public void afterEach() { + if (server != null) { + server.close(); + server = null; + } + } + + @Test + public void testValid() throws IOException { + validate("my-app-id", "secret"); + } + + @Test + public void testBadAppId() { + try { + validate("wrong-app-id", "secret"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Wrong appId!")); + } + } + + @Test + public void testBadSecret() { + try { + validate("my-app-id", "bad-secret"); + } catch (Exception e) { + assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response")); + } + } + + /** Creates an ExternalShuffleClient and attempts to register with the server. */ + private void validate(String appId, String secretKey) throws IOException { + ExternalShuffleClient client = + new ExternalShuffleClient(conf, new TestSecretKeyHolder(appId, secretKey), true); + client.init(appId); + // Registration either succeeds or throws an exception. 
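+ // With SASL enabled, authentication happens while the connection is established, before this RPC is sent.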
+ client.registerWithShuffleServer(TestUtils.getLocalHost(), server.getPort(), "exec0", + new ExecutorShuffleInfo(new String[0], 0, "")); + client.close(); + } + + /** Provides a secret key holder which always returns the given secret key, for a single appId. */ + static class TestSecretKeyHolder implements SecretKeyHolder { + private final String appId; + private final String secretKey; + + TestSecretKeyHolder(String appId, String secretKey) { + this.appId = appId; + this.secretKey = secretKey; + } + + @Override + public String getSaslUser(String appId) { + return "user"; + } + + @Override + public String getSecretKey(String appId) { + if (!appId.equals(this.appId)) { + throw new IllegalArgumentException("Wrong appId!"); + } + return secretKey; + } + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java new file mode 100644 index 0000000000000..842741e3d354f --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle; + +import java.nio.ByteBuffer; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.concurrent.atomic.AtomicInteger; + +import com.google.common.collect.Maps; +import io.netty.buffer.Unpooled; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NettyManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.OpenBlocks; +import org.apache.spark.network.shuffle.protocol.StreamHandle; + +public class OneForOneBlockFetcherSuite { + @Test + public void testFetchOne() { + LinkedHashMap blocks = Maps.newLinkedHashMap(); + blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0]))); + + BlockFetchingListener listener = fetchBlocks(blocks); + + verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0")); + } + + @Test + public void testFetchThree() { + LinkedHashMap blocks = Maps.newLinkedHashMap(); + blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("b1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23]))); + blocks.put("b2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23]))); + + BlockFetchingListener listener = fetchBlocks(blocks); + + for (int i = 0; i < 3; i ++) { + verify(listener, times(1)).onBlockFetchSuccess("b" + i, blocks.get("b" + i)); + } + } + + @Test + public void testFailure() { + LinkedHashMap blocks = Maps.newLinkedHashMap(); + blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("b1", null); + blocks.put("b2", null); + + BlockFetchingListener listener = fetchBlocks(blocks); + + // Each failure will cause a failure to be invoked in all remaining block fetches. + verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0")); + verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verify(listener, times(2)).onBlockFetchFailure(eq("b2"), (Throwable) any()); + } + + @Test + public void testFailureAndSuccess() { + LinkedHashMap blocks = Maps.newLinkedHashMap(); + blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("b1", null); + blocks.put("b2", new NioManagedBuffer(ByteBuffer.wrap(new byte[21]))); + + BlockFetchingListener listener = fetchBlocks(blocks); + + // We may call both success and failure for the same block. 
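+ // b1's failed chunk also marks b2 as failed, while b2's own chunk still succeeds, so each callback fires once for b2.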
+ verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0"));
+ verify(listener, times(1)).onBlockFetchFailure(eq("b1"), (Throwable) any());
+ verify(listener, times(1)).onBlockFetchSuccess("b2", blocks.get("b2"));
+ verify(listener, times(1)).onBlockFetchFailure(eq("b2"), (Throwable) any());
+ }
+
+ @Test
+ public void testEmptyBlockFetch() {
+ try {
+ fetchBlocks(Maps.newLinkedHashMap());
+ fail();
+ } catch (IllegalArgumentException e) {
+ assertEquals("Zero-sized blockIds array", e.getMessage());
+ }
+ }
+
+ /**
+ * Begins a fetch on the given set of blocks by mocking out the server side of the RPC which
+ * simply returns the given (BlockId, Block) pairs.
+ * As "blocks" is a LinkedHashMap, the blocks are guaranteed to be returned in the same order
+ * that they were inserted in.
+ *
+ * If a block's buffer is "null", an exception will be thrown instead.
+ */
+ private BlockFetchingListener fetchBlocks(final LinkedHashMap blocks) {
+ TransportClient client = mock(TransportClient.class);
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+ final String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
+ OneForOneBlockFetcher fetcher =
+ new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener);
+
+ // Respond to the "OpenBlocks" message with an appropriate StreamHandle with streamId 123
+ doAnswer(new Answer() {
+ @Override
+ public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
+ BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteArray(
+ (byte[]) invocationOnMock.getArguments()[0]);
+ RpcResponseCallback callback = (RpcResponseCallback) invocationOnMock.getArguments()[1];
+ callback.onSuccess(new StreamHandle(123, blocks.size()).toByteArray());
+ assertEquals(new OpenBlocks("app-id", "exec-id", blockIds), message);
+ return null;
+ }
+ }).when(client).sendRpc((byte[]) any(), (RpcResponseCallback) any());
+
+ // Respond to each chunk request with a single buffer from our blocks array.
+ final AtomicInteger expectedChunkIndex = new AtomicInteger(0);
+ final Iterator blockIterator = blocks.values().iterator();
+ doAnswer(new Answer() {
+ @Override
+ public Void answer(InvocationOnMock invocation) throws Throwable {
+ try {
+ long streamId = (Long) invocation.getArguments()[0];
+ int myChunkIndex = (Integer) invocation.getArguments()[1];
+ assertEquals(123, streamId);
+ assertEquals(expectedChunkIndex.getAndIncrement(), myChunkIndex);
+
+ ChunkReceivedCallback callback = (ChunkReceivedCallback) invocation.getArguments()[2];
+ ManagedBuffer result = blockIterator.next();
+ if (result != null) {
+ callback.onSuccess(myChunkIndex, result);
+ } else {
+ callback.onFailure(myChunkIndex, new RuntimeException("Failed " + myChunkIndex));
+ }
+ } catch (Exception e) {
+ e.printStackTrace();
+ fail("Unexpected failure");
+ }
+ return null;
+ }
+ }).when(client).fetchChunk(anyLong(), anyInt(), (ChunkReceivedCallback) any());
+
+ fetcher.start();
+ return listener;
+ }
+}
diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
new file mode 100644
index 0000000000000..0191fe529e1be
--- /dev/null
+++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
@@ -0,0 +1,310 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.LinkedHashSet; +import java.util.Map; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Sets; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.mockito.stubbing.Stubber; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.util.SystemPropertyConfigProvider; +import org.apache.spark.network.util.TransportConf; +import static org.apache.spark.network.shuffle.RetryingBlockFetcher.BlockFetchStarter; + +/** + * Tests retry logic by throwing IOExceptions and ensuring that subsequent attempts are made to + * fetch the lost blocks. + */ +public class RetryingBlockFetcherSuite { + + ManagedBuffer block0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13])); + ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); + ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19])); + + @Before + public void beforeEach() { + System.setProperty("spark.shuffle.io.maxRetries", "2"); + System.setProperty("spark.shuffle.io.retryWaitMs", "0"); + } + + @After + public void afterEach() { + System.clearProperty("spark.shuffle.io.maxRetries"); + System.clearProperty("spark.shuffle.io.retryWaitMs"); + } + + @Test + public void testNoFailures() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // Immediately return both blocks successfully. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener).onBlockFetchSuccess("b0", block0); + verify(listener).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testUnrecoverableFailure() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0 throws a non-IOException error, so it will be failed without retry. 
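+ // (Only IOExceptions are treated as retriable by RetryingBlockFetcher; any other exception
+ // is reported straight to the listener.)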
+ ImmutableMap.builder() + .put("b0", new RuntimeException("Ouch!")) + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener).onBlockFetchFailure(eq("b0"), (Throwable) any()); + verify(listener).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testSingleIOExceptionOnFirst() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // IOException will cause a retry. Since b0 fails, we will retry both. + ImmutableMap.builder() + .put("b0", new IOException("Connection failed or something")) + .put("b1", block1) + .build(), + ImmutableMap.builder() + .put("b0", block0) + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testSingleIOExceptionOnSecond() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // IOException will cause a retry. Since b1 fails, we will not retry b0. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new IOException("Connection failed or something")) + .build(), + ImmutableMap.builder() + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testTwoIOExceptions() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0's IOException will trigger retry, b1's will be ignored. + ImmutableMap.builder() + .put("b0", new IOException()) + .put("b1", new IOException()) + .build(), + // Next, b0 is successful and b1 errors again, so we just request that one. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new IOException()) + .build(), + // b1 returns successfully within 2 retries. + ImmutableMap.builder() + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1); + verifyNoMoreInteractions(listener); + } + + @Test + public void testThreeIOExceptions() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0's IOException will trigger retry, b1's will be ignored. + ImmutableMap.builder() + .put("b0", new IOException()) + .put("b1", new IOException()) + .build(), + // Next, b0 is successful and b1 errors again, so we just request that one. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new IOException()) + .build(), + // b1 errors again, but this was the last retry + ImmutableMap.builder() + .put("b1", new IOException()) + .build(), + // This is not reached -- b1 has failed. 
+ ImmutableMap.builder() + .put("b1", block1) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verifyNoMoreInteractions(listener); + } + + @Test + public void testRetryAndUnrecoverable() throws IOException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + Map[] interactions = new Map[] { + // b0's IOException will trigger retry, subsequent messages will be ignored. + ImmutableMap.builder() + .put("b0", new IOException()) + .put("b1", new RuntimeException()) + .put("b2", block2) + .build(), + // Next, b0 is successful, b1 errors unrecoverably, and b2 triggers a retry. + ImmutableMap.builder() + .put("b0", block0) + .put("b1", new RuntimeException()) + .put("b2", new IOException()) + .build(), + // b2 succeeds in its last retry. + ImmutableMap.builder() + .put("b2", block2) + .build(), + }; + + performInteractions(interactions, listener); + + verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); + verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), (Throwable) any()); + verify(listener, timeout(5000)).onBlockFetchSuccess("b2", block2); + verifyNoMoreInteractions(listener); + } + + /** + * Performs a set of interactions in response to block requests from a RetryingBlockFetcher. + * Each interaction is a Map from BlockId to either ManagedBuffer or Exception. This interaction + * means "respond to the next block fetch request with these Successful buffers and these Failure + * exceptions". We verify that the expected block ids are exactly the ones requested. + * + * If multiple interactions are supplied, they will be used in order. This is useful for encoding + * retries -- the first interaction may include an IOException, which causes a retry of some + * subset of the original blocks in a second interaction. + */ + @SuppressWarnings("unchecked") + private void performInteractions(final Map[] interactions, BlockFetchingListener listener) + throws IOException { + + TransportConf conf = new TransportConf(new SystemPropertyConfigProvider()); + BlockFetchStarter fetchStarter = mock(BlockFetchStarter.class); + + Stubber stub = null; + + // Contains all blockIds that are referenced across all interactions. + final LinkedHashSet blockIds = Sets.newLinkedHashSet(); + + for (final Map interaction : interactions) { + blockIds.addAll(interaction.keySet()); + + Answer answer = new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + try { + // Verify that the RetryingBlockFetcher requested the expected blocks. + String[] requestedBlockIds = (String[]) invocationOnMock.getArguments()[0]; + String[] desiredBlockIds = interaction.keySet().toArray(new String[interaction.size()]); + assertArrayEquals(desiredBlockIds, requestedBlockIds); + + // Now actually invoke the success/failure callbacks on each block. 
+ BlockFetchingListener retryListener = + (BlockFetchingListener) invocationOnMock.getArguments()[1]; + for (Map.Entry block : interaction.entrySet()) { + String blockId = block.getKey(); + Object blockValue = block.getValue(); + + if (blockValue instanceof ManagedBuffer) { + retryListener.onBlockFetchSuccess(blockId, (ManagedBuffer) blockValue); + } else if (blockValue instanceof Exception) { + retryListener.onBlockFetchFailure(blockId, (Exception) blockValue); + } else { + fail("Can only handle ManagedBuffers and Exceptions, got " + blockValue); + } + } + return null; + } catch (Throwable e) { + e.printStackTrace(); + throw e; + } + } + }; + + // This is either the first stub, or should be chained behind the prior ones. + if (stub == null) { + stub = doAnswer(answer); + } else { + stub.doAnswer(answer); + } + } + + assert stub != null; + stub.when(fetchStarter).createAndStart((String[]) any(), (BlockFetchingListener) anyObject()); + String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]); + new RetryingBlockFetcher(conf, fetchStarter, blockIdArray, listener).start(); + } +} diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java new file mode 100644 index 0000000000000..76639114df5d9 --- /dev/null +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.DataOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +import com.google.common.io.Files; + +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; + +/** + * Manages some sort- and hash-based shuffle data, including the creation + * and cleanup of directories that can be read by the {@link ExternalShuffleBlockManager}. + */ +public class TestShuffleDataContext { + public final String[] localDirs; + public final int subDirsPerLocalDir; + + public TestShuffleDataContext(int numLocalDirs, int subDirsPerLocalDir) { + this.localDirs = new String[numLocalDirs]; + this.subDirsPerLocalDir = subDirsPerLocalDir; + } + + public void create() { + for (int i = 0; i < localDirs.length; i ++) { + localDirs[i] = Files.createTempDir().getAbsolutePath(); + + for (int p = 0; p < subDirsPerLocalDir; p ++) { + new File(localDirs[i], String.format("%02x", p)).mkdirs(); + } + } + } + + public void cleanup() { + for (String localDir : localDirs) { + deleteRecursively(new File(localDir)); + } + } + + /** Creates reducer blocks in a sort-based data format within our local dirs. 
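+ * (Writes a "shuffle_<shuffleId>_<mapId>_0.data" file of concatenated block bytes and a matching
+ * ".index" file of cumulative offsets, as read back by the external shuffle service.)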
*/ + public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException { + String blockId = "shuffle_" + shuffleId + "_" + mapId + "_0"; + + OutputStream dataStream = new FileOutputStream( + ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId + ".data")); + DataOutputStream indexStream = new DataOutputStream(new FileOutputStream( + ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId + ".index"))); + + long offset = 0; + indexStream.writeLong(offset); + for (byte[] block : blocks) { + offset += block.length; + dataStream.write(block); + indexStream.writeLong(offset); + } + + dataStream.close(); + indexStream.close(); + } + + /** Creates reducer blocks in a hash-based data format within our local dirs. */ + public void insertHashShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException { + for (int i = 0; i < blocks.length; i ++) { + String blockId = "shuffle_" + shuffleId + "_" + mapId + "_" + i; + Files.write(blocks[i], + ExternalShuffleBlockManager.getFile(localDirs, subDirsPerLocalDir, blockId)); + } + } + + /** + * Creates an ExecutorShuffleInfo object based on the given shuffle manager which targets this + * context's directories. + */ + public ExecutorShuffleInfo createExecutorInfo(String shuffleManager) { + return new ExecutorShuffleInfo(localDirs, subDirsPerLocalDir, shuffleManager); + } + + private static void deleteRecursively(File f) { + assert f != null; + if (f.isDirectory()) { + File[] children = f.listFiles(); + if (children != null) { + for (File child : children) { + deleteRecursively(child); + } + } + } + f.delete(); + } +} diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml new file mode 100644 index 0000000000000..acec8f18f2b5c --- /dev/null +++ b/network/yarn/pom.xml @@ -0,0 +1,91 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.3.0-SNAPSHOT + ../../pom.xml + + + org.apache.spark + spark-network-yarn_2.10 + jar + Spark Project YARN Shuffle Service + http://spark.apache.org/ + + network-yarn + + + + + + org.apache.spark + spark-network-shuffle_${scala.binary.version} + ${project.version} + + + + + org.apache.hadoop + hadoop-client + provided + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + ${project.build.directory}/scala-${scala.binary.version}/spark-${project.version}-yarn-shuffle.jar + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + + package + + shade + + + + + + + diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java new file mode 100644 index 0000000000000..a34aabe9e78a6 --- /dev/null +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.yarn; + +import java.lang.Override; +import java.nio.ByteBuffer; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.yarn.api.records.ApplicationId; +import org.apache.hadoop.yarn.api.records.ContainerId; +import org.apache.hadoop.yarn.server.api.AuxiliaryService; +import org.apache.hadoop.yarn.server.api.ApplicationInitializationContext; +import org.apache.hadoop.yarn.server.api.ApplicationTerminationContext; +import org.apache.hadoop.yarn.server.api.ContainerInitializationContext; +import org.apache.hadoop.yarn.server.api.ContainerTerminationContext; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.sasl.SaslRpcHandler; +import org.apache.spark.network.sasl.ShuffleSecretManager; +import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.TransportServer; +import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler; +import org.apache.spark.network.util.TransportConf; +import org.apache.spark.network.yarn.util.HadoopConfigProvider; + +/** + * An external shuffle service used by Spark on Yarn. + * + * This is intended to be a long-running auxiliary service that runs in the NodeManager process. + * A Spark application may connect to this service by setting `spark.shuffle.service.enabled`. + * The application also automatically derives the service port through `spark.shuffle.service.port` + * specified in the Yarn configuration. This is so that both the clients and the server agree on + * the same port to communicate on. + * + * The service also optionally supports authentication. This ensures that executors from one + * application cannot read the shuffle files written by those from another. This feature can be + * enabled by setting `spark.authenticate` in the Yarn configuration before starting the NM. + * Note that the Spark application must also set `spark.authenticate` manually and, unlike in + * the case of the service port, will not inherit this setting from the Yarn configuration. This + * is because an application running on the same Yarn cluster may choose to not use the external + * shuffle service, in which case its setting of `spark.authenticate` should be independent of + * the service's. 
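+ *
+ * For illustration only (not prescriptive): a typical deployment registers this class with the
+ * NodeManager as the auxiliary service named "spark_shuffle", e.g. in yarn-site.xml:
+ *
+ *   yarn.nodemanager.aux-services = spark_shuffle
+ *   yarn.nodemanager.aux-services.spark_shuffle.class =
+ *     org.apache.spark.network.yarn.YarnShuffleService
+ *
+ * while applications opt in by setting `spark.shuffle.service.enabled` (and `spark.authenticate`
+ * if authentication is required) in their own configuration.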
+ */ +public class YarnShuffleService extends AuxiliaryService { + private final Logger logger = LoggerFactory.getLogger(YarnShuffleService.class); + + // Port on which the shuffle server listens for fetch requests + private static final String SPARK_SHUFFLE_SERVICE_PORT_KEY = "spark.shuffle.service.port"; + private static final int DEFAULT_SPARK_SHUFFLE_SERVICE_PORT = 7337; + + // Whether the shuffle server should authenticate fetch requests + private static final String SPARK_AUTHENTICATE_KEY = "spark.authenticate"; + private static final boolean DEFAULT_SPARK_AUTHENTICATE = false; + + // An entity that manages the shuffle secret per application + // This is used only if authentication is enabled + private ShuffleSecretManager secretManager; + + // The actual server that serves shuffle files + private TransportServer shuffleServer = null; + + public YarnShuffleService() { + super("spark_shuffle"); + logger.info("Initializing YARN shuffle service for Spark"); + } + + /** + * Return whether authentication is enabled as specified by the configuration. + * If so, fetch requests will fail unless the appropriate authentication secret + * for the application is provided. + */ + private boolean isAuthenticationEnabled() { + return secretManager != null; + } + + /** + * Start the shuffle server with the given configuration. + */ + @Override + protected void serviceInit(Configuration conf) { + TransportConf transportConf = new TransportConf(new HadoopConfigProvider(conf)); + // If authentication is enabled, set up the shuffle server to use a + // special RPC handler that filters out unauthenticated fetch requests + boolean authEnabled = conf.getBoolean(SPARK_AUTHENTICATE_KEY, DEFAULT_SPARK_AUTHENTICATE); + RpcHandler rpcHandler = new ExternalShuffleBlockHandler(transportConf); + if (authEnabled) { + secretManager = new ShuffleSecretManager(); + rpcHandler = new SaslRpcHandler(rpcHandler, secretManager); + } + + int port = conf.getInt( + SPARK_SHUFFLE_SERVICE_PORT_KEY, DEFAULT_SPARK_SHUFFLE_SERVICE_PORT); + TransportContext transportContext = new TransportContext(transportConf, rpcHandler); + shuffleServer = transportContext.createServer(port); + String authEnabledString = authEnabled ? "enabled" : "not enabled"; + logger.info("Started YARN shuffle service for Spark on port {}. 
" + + "Authentication is {}.", port, authEnabledString); + } + + @Override + public void initializeApplication(ApplicationInitializationContext context) { + String appId = context.getApplicationId().toString(); + try { + ByteBuffer shuffleSecret = context.getApplicationDataForService(); + logger.info("Initializing application {}", appId); + if (isAuthenticationEnabled()) { + secretManager.registerApp(appId, shuffleSecret); + } + } catch (Exception e) { + logger.error("Exception when initializing application {}", appId, e); + } + } + + @Override + public void stopApplication(ApplicationTerminationContext context) { + String appId = context.getApplicationId().toString(); + try { + logger.info("Stopping application {}", appId); + if (isAuthenticationEnabled()) { + secretManager.unregisterApp(appId); + } + } catch (Exception e) { + logger.error("Exception when stopping application {}", appId, e); + } + } + + @Override + public void initializeContainer(ContainerInitializationContext context) { + ContainerId containerId = context.getContainerId(); + logger.info("Initializing container {}", containerId); + } + + @Override + public void stopContainer(ContainerTerminationContext context) { + ContainerId containerId = context.getContainerId(); + logger.info("Stopping container {}", containerId); + } + + /** + * Close the shuffle server to clean up any associated state. + */ + @Override + protected void serviceStop() { + try { + if (shuffleServer != null) { + shuffleServer.close(); + } + } catch (Exception e) { + logger.error("Exception when stopping service", e); + } + } + + // Not currently used + @Override + public ByteBuffer getMetaData() { + return ByteBuffer.allocate(0); + } + +} diff --git a/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java b/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java new file mode 100644 index 0000000000000..884861752e80d --- /dev/null +++ b/network/yarn/src/main/java/org/apache/spark/network/yarn/util/HadoopConfigProvider.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.yarn.util; + +import java.util.NoSuchElementException; + +import org.apache.hadoop.conf.Configuration; + +import org.apache.spark.network.util.ConfigProvider; + +/** Use the Hadoop configuration to obtain config values. 
*/ +public class HadoopConfigProvider extends ConfigProvider { + private final Configuration conf; + + public HadoopConfigProvider(Configuration conf) { + this.conf = conf; + } + + @Override + public String get(String name) { + String value = conf.get(name); + if (value == null) { + throw new NoSuchElementException(name); + } + return value; + } +} diff --git a/pom.xml b/pom.xml index 7756c89b00cad..b7df53d3e5eb1 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ @@ -91,34 +91,32 @@ graphx mllib tools + network/common + network/shuffle streaming sql/catalyst sql/core sql/hive - repl assembly external/twitter - external/kafka external/flume external/flume-sink - external/zeromq external/mqtt + external/zeromq examples + repl UTF-8 UTF-8 - + org.spark-project.akka + 2.3.4-spark 1.6 spark - 2.10.4 - 2.10 2.0.1 0.18.1 shaded-protobuf - org.spark-project.akka - 2.2.3-shaded-protobuf 1.7.5 1.2.17 1.0.4 @@ -127,11 +125,15 @@ 0.94.6 1.4.0 3.4.5 - 0.12.0 - 1.4.3 + + 0.13.1a + + 0.13.1 + 10.10.1.1 + 1.6.0rc3 1.2.3 8.1.14.v20131031 - 0.3.6 + 0.5.0 3.0.0 1.7.6 @@ -139,9 +141,14 @@ 1.8.3 1.1.0 4.2.6 - + 3.1.1 + ${project.build.directory}/spark-test-classpath.txt 64m 512m + 2.10.4 + 2.10 + ${scala.version} + org.scala-lang @@ -223,6 +230,18 @@ false + + + spark-staging-1038 + Spark 1.2.0 Staging (1038) + https://repository.apache.org/content/repositories/orgapachespark-1038/ + + true + + + false + + @@ -237,8 +256,65 @@ + + + + org.spark-project.spark + unused + 1.0.0 + + + + org.codehaus.groovy + groovy-all + 2.3.7 + provided + + + + ${jline.groupid} + jline + ${jline.version} + + + com.twitter + chill_${scala.binary.version} + ${chill.version} + + + org.ow2.asm + asm + + + org.ow2.asm + asm-commons + + + + + com.twitter + chill-java + ${chill.version} + + + org.ow2.asm + asm + + + org.ow2.asm + asm-commons + + + org.eclipse.jetty jetty-util @@ -278,14 +354,19 @@ org.apache.commons commons-math3 - 3.3 - test + ${commons.math3.version} com.google.code.findbugs jsr305 1.3.9 + + org.seleniumhq.selenium + selenium-java + 2.42.2 + test + org.slf4j slf4j-api @@ -320,7 +401,7 @@ org.xerial.snappy snappy-java - 1.1.1.3 + 1.1.1.6 net.jpountz.lz4 @@ -349,36 +430,6 @@ protobuf-java ${protobuf.version} - - com.twitter - chill_${scala.binary.version} - ${chill.version} - - - org.ow2.asm - asm - - - org.ow2.asm - asm-commons - - - - - com.twitter - chill-java - ${chill.version} - - - org.ow2.asm - asm - - - org.ow2.asm - asm-commons - - - ${akka.group} akka-actor_${scala.binary.version} @@ -399,11 +450,6 @@ akka-testkit_${scala.binary.version} ${akka.version} - - colt - colt - 1.2.0 - org.apache.mesos mesos @@ -416,6 +462,11 @@ + + org.roaringbitmap + RoaringBitmap + 0.4.5 + commons-net commons-net @@ -429,7 +480,7 @@ org.apache.derby derby - 10.4.2.0 + ${derby.version} com.codahale.metrics @@ -466,11 +517,6 @@ scala-reflect ${scala.version} - - org.scala-lang - jline - ${scala.version} - org.scala-lang scala-library @@ -489,7 +535,7 @@ org.scalatest scalatest_${scala.binary.version} - 2.1.5 + 2.2.1 test @@ -876,9 +922,9 @@ by Spark SQL for code generation. 
--> - org.scalamacros - paradise_${scala.version} - ${scala.macros.version} + org.scalamacros + paradise_${scala.version} + ${scala.macros.version} @@ -919,6 +965,9 @@ ${session.executionRootDirectory} 1 false + false + ${test_classpath} + true @@ -976,10 +1025,77 @@ + + org.apache.maven.plugins + maven-javadoc-plugin + 2.10.1 + + + + org.apache.maven.plugins + maven-dependency-plugin + 2.9 + + + test-compile + + build-classpath + + + test + ${test_classpath_file} + + + + + + + + org.codehaus.gmavenplus + gmavenplus-plugin + 1.2 + + + process-test-classes + + execute + + + + + + + + + + + + org.apache.maven.plugins + maven-shade-plugin + + false + + + org.spark-project.spark:unused + + + + + + package + + shade + + + + org.apache.maven.plugins maven-enforcer-plugin @@ -1107,8 +1223,31 @@ + + doclint-java8-disable + + [1.8,) + + + + + + org.apache.maven.plugins + maven-javadoc-plugin + + -Xdoclint:all -Xdoclint:-missing + + + + + + + + hadoop-0.23 @@ -1138,6 +1277,7 @@ 2.3.0 2.5.0 0.9.0 + 3.1.1 hadoop2 @@ -1148,6 +1288,7 @@ 2.4.0 2.5.0 0.9.0 + 3.1.1 hadoop2 @@ -1163,14 +1304,12 @@ yarn yarn + network/yarn mapr3 - - false - 1.0.3-mapr-3.0.3 2.3.0-mapr-4.0.0-FCS @@ -1181,9 +1320,6 @@ mapr4 - - false - 2.3.0-mapr-4.0.0-FCS 2.3.0-mapr-4.0.0-FCS @@ -1213,9 +1349,6 @@ hadoop-provided - - false - org.apache.hadoop @@ -1260,16 +1393,57 @@ + + hive-thriftserver + + sql/hive-thriftserver + + + + hive-0.12.0 + + 0.12.0-protobuf-2.5 + 0.12.0 + 10.4.2.0 + + + + hive-0.13.1 + + 0.13.1a + 0.13.1 + 10.10.1.1 + + - hive + scala-2.10 - false + !scala-2.11 + + 2.10.4 + 2.10 + ${scala.version} + org.scala-lang + - sql/hive-thriftserver + external/kafka + + scala-2.11 + + scala-2.11 + + + 2.11.2 + 2.11 + 2.12 + jline + + + diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 39f8ba4745737..f0cbf4e57b8c5 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -30,9 +30,9 @@ object MimaBuild { def excludeMember(fullName: String) = Seq( ProblemFilters.exclude[MissingMethodProblem](fullName), - // Sometimes excluded methods have default arguments and + // Sometimes excluded methods have default arguments and // they are translated into public methods/fields($default$) in generated - // bytecode. It is not possible to exhustively list everything. + // bytecode. It is not possible to exhaustively list everything. // But this should be okay. ProblemFilters.exclude[MissingMethodProblem](fullName+"$default$2"), ProblemFilters.exclude[MissingMethodProblem](fullName+"$default$1"), @@ -91,9 +91,9 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "1.1.0" + val previousSparkVersion = "1.2.0" val fullId = "spark-" + projectRef.project + "_2.10" - mimaDefaultSettings ++ + mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), binaryIssueFilters ++= ignoredABIProblems(sparkHome, version.value)) } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index d499302124461..230239aa40500 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -33,6 +33,28 @@ import com.typesafe.tools.mima.core._ object MimaExcludes { def excludes(version: String) = version match { + case v if v.startsWith("1.3") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("graphx"), + // These are needed if checking against the sbt build, since they are part of + // the maven-generated artifacts in the 1.2 build. 
+ MimaBuild.excludeSparkPackage("unused"), + ProblemFilters.exclude[MissingClassProblem]("com.google.common.base.Optional") + ) ++ Seq( + // SPARK-2321 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.SparkStageInfoImpl.this"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.SparkStageInfo.submissionTime") + ) ++ Seq( + // SPARK-4614 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrices.randn"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrices.rand") + ) + case v if v.startsWith("1.2") => Seq( MimaBuild.excludeSparkPackage("deploy"), @@ -50,7 +72,45 @@ object MimaExcludes { "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL2"), // MapStatus should be private[spark] ProblemFilters.exclude[IncompatibleTemplateDefProblem]( - "org.apache.spark.scheduler.MapStatus") + "org.apache.spark.scheduler.MapStatus"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.network.netty.PathResolver"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.network.netty.client.BlockClientListener"), + + // TaskContext was promoted to Abstract class + ProblemFilters.exclude[AbstractClassProblem]( + "org.apache.spark.TaskContext"), + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.util.collection.SortDataFormat") + ) ++ Seq( + // Adding new methods to the JavaRDDLike trait: + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.takeAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.foreachPartitionAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.countAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.foreachAsync"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.collectAsync") + ) ++ Seq( + // SPARK-3822 + ProblemFilters.exclude[IncompatibleResultTypeProblem]( + "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler") + ) ++ Seq( + // SPARK-1209 + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.hadoop.mapred.SparkHadoopMapRedUtil"), + ProblemFilters.exclude[MissingTypesProblem]( + "org.apache.spark.rdd.PairRDDFunctions") + ) ++ Seq( + // SPARK-4062 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.streaming.kafka.KafkaReceiver#MessageHandler.this") ) case v if v.startsWith("1.1") => diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 01a5b20e7c51d..b16ed66aeb3c3 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -22,7 +22,7 @@ import sbt._ import sbt.Classpaths.publishTask import sbt.Keys._ import sbtunidoc.Plugin.genjavadocSettings -import org.scalastyle.sbt.ScalastylePlugin.{Settings => ScalaStyleSettings} +import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys} import net.virtualvoid.sbt.graph.Plugin.graphSettings @@ -31,18 +31,19 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, - sql, streaming, streamingFlumeSink, streamingFlume, streamingKafka, streamingMqtt, - streamingTwitter, streamingZeromq) = + sql, networkCommon, 
networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka, + streamingMqtt, streamingTwitter, streamingZeromq) = Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", - "sql", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", - "streaming-mqtt", "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _)) + "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink", + "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", + "streaming-zeromq").map(ProjectRef(buildLocation, _)) - val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) = - Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl", "kinesis-asl") - .map(ProjectRef(buildLocation, _)) + val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, + sparkGangliaLgpl, sparkKinesisAsl) = Seq("yarn", "yarn-stable", "yarn-alpha", + "java8-tests", "ganglia-lgpl", "kinesis-asl").map(ProjectRef(buildLocation, _)) - val assemblyProjects@Seq(assembly, examples) = Seq("assembly", "examples") - .map(ProjectRef(buildLocation, _)) + val assemblyProjects@Seq(assembly, examples, networkYarn) = + Seq("assembly", "examples", "network-yarn").map(ProjectRef(buildLocation, _)) val tools = ProjectRef(buildLocation, "tools") // Root project. @@ -67,8 +68,8 @@ object SparkBuild extends PomBuild { profiles ++= Seq("spark-ganglia-lgpl") } if (Properties.envOrNone("SPARK_HIVE").isDefined) { - println("NOTE: SPARK_HIVE is deprecated, please use -Phive flag.") - profiles ++= Seq("hive") + println("NOTE: SPARK_HIVE is deprecated, please use -Phive and -Phive-thriftserver flags.") + profiles ++= Seq("hive", "hive-thriftserver") } Properties.envOrNone("SPARK_HADOOP_VERSION") match { case Some(v) => @@ -90,13 +91,23 @@ object SparkBuild extends PomBuild { profiles } - override val profiles = Properties.envOrNone("SBT_MAVEN_PROFILES") match { + override val profiles = { + val profiles = Properties.envOrNone("SBT_MAVEN_PROFILES") match { case None => backwardCompatibility case Some(v) => if (backwardCompatibility.nonEmpty) println("Note: We ignore environment variables, when use of profile is detected in " + "conjunction with environment variable.") v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq + } + + if (System.getProperty("scala-2.11") == "") { + // To activate scala-2.11 profile, replace empty property value to non-empty value + // in the same way as Maven which handles -Dname as -Dname=true before executes build process. 
+ // see: https://github.com/apache/maven/blob/maven-3.0.4/maven-embedder/src/main/java/org/apache/maven/cli/MavenCli.java#L1082 + System.setProperty("scala-2.11", "true") + } + profiles } Properties.envOrNone("SBT_MAVEN_PROPERTIES") match { @@ -110,12 +121,13 @@ object SparkBuild extends PomBuild { lazy val MavenCompile = config("m2r") extend(Compile) lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy") - lazy val sharedSettings = graphSettings ++ ScalaStyleSettings ++ genjavadocSettings ++ Seq ( + lazy val sharedSettings = graphSettings ++ genjavadocSettings ++ Seq ( javaHome := Properties.envOrNone("JAVA_HOME").map(file), incOptions := incOptions.value.withNameHashing(true), retrieveManaged := true, retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", publishMavenStyle := true, + unidocGenjavadocVersion := "0.8", resolvers += Resolver.mavenLocal, otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))), @@ -124,7 +136,12 @@ object SparkBuild extends PomBuild { }, publishMavenStyle in MavenCompile := true, publishLocal in MavenCompile <<= publishTask(publishLocalConfiguration in MavenCompile, deliverLocal), - publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn + publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn, + + javacOptions in (Compile, doc) ++= { + val Array(major, minor, _) = System.getProperty("java.version").split("\\.", 3) + if (major.toInt >= 1 && minor.toInt >= 8) Seq("-Xdoclint:all", "-Xdoclint:-missing") else Seq.empty + } ) def enable(settings: Seq[Setting[_]])(projectRef: ProjectRef) = { @@ -134,14 +151,17 @@ object SparkBuild extends PomBuild { // Note ordering of these settings matter. /* Enable shared settings on all projects */ - (allProjects ++ optionallyEnabledProjects ++ assemblyProjects).foreach(enable(sharedSettings)) + (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ Seq(spark, tools)) + .foreach(enable(sharedSettings ++ ExludedDependencies.settings)) /* Enable tests settings for all projects except examples, assembly and tools */ (allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings)) // TODO: Add Sql to mima checks allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl, - streamingFlumeSink).contains(x)).foreach(x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)) + streamingFlumeSink, networkCommon, networkShuffle, networkYarn).contains(x)).foreach { + x => enable(MimaBuild.mimaSettings(sparkHome, x))(x) + } /* Enable Assembly for all assembly projects */ assemblyProjects.foreach(enable(Assembly.settings)) @@ -174,6 +194,16 @@ object Flume { lazy val settings = sbtavro.SbtAvro.avroSettings } +/** + This excludes library dependencies in sbt, which are specified in maven but are + not needed by sbt build. + */ +object ExludedDependencies { + lazy val settings = Seq( + libraryDependencies ~= { libs => libs.filterNot(_.name == "groovy-all") } + ) +} + /** * Following project only exists to pull previous artifacts of Spark for generating * Mima ignores. 
For more information see: SPARK 2071 @@ -184,12 +214,14 @@ object OldDeps { def versionArtifact(id: String): Option[sbt.ModuleID] = { val fullId = id + "_2.10" - Some("org.apache.spark" % fullId % "1.1.0") + Some("org.apache.spark" % fullId % "1.2.0") } def oldDepsSettings() = Defaults.coreDefaultSettings ++ Seq( name := "old-deps", scalaVersion := "2.10.4", + // TODO: remove this as soon as 1.2.0 is published on Maven central. + resolvers += "spark-staging-1038" at "https://repository.apache.org/content/repositories/orgapachespark-1038/", retrieveManaged := true, retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]", libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", @@ -251,7 +283,11 @@ object Hive { |import org.apache.spark.sql.hive._ |import org.apache.spark.sql.hive.test.TestHive._ |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin, - cleanupCommands in console := "sparkContext.stop()" + cleanupCommands in console := "sparkContext.stop()", + // Some of our log4j jars make it impossible to submit jobs from this JVM to Hive Map/Reduce + // in order to generate golden files. This is only required for developers who are adding new + // new query tests. + fullClasspath in Test := (fullClasspath in Test).value.filterNot { f => f.toString.contains("jcl-over") } ) } @@ -262,8 +298,15 @@ object Assembly { lazy val settings = assemblySettings ++ Seq( test in assembly := {}, - jarName in assembly <<= (version, moduleName) map { (v, mName) => mName + "-"+v + "-hadoop" + - Option(System.getProperty("hadoop.version")).getOrElse("1.0.4") + ".jar" }, + jarName in assembly <<= (version, moduleName) map { (v, mName) => + if (mName.contains("network-yarn")) { + // This must match the same name used in maven (see network/yarn/pom.xml) + "spark-" + v + "-yarn-shuffle.jar" + } else { + mName + "-" + v + "-hadoop" + + Option(System.getProperty("hadoop.version")).getOrElse("1.0.4") + ".jar" + } + }, mergeStrategy in assembly := { case PathList("org", "datanucleus", xs @ _*) => MergeStrategy.discard case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard @@ -294,7 +337,7 @@ object Unidoc { unidocProjectFilter in(ScalaUnidoc, unidoc) := inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, catalyst, streamingFlumeSink, yarn, yarnAlpha), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, bagel, graphx, examples, tools, catalyst, streamingFlumeSink, yarn, yarnAlpha), + inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, catalyst, streamingFlumeSink, yarn, yarnAlpha), // Skip class names containing $ and some internal packages in Javadocs unidocAllSources in (JavaUnidoc, unidoc) := { @@ -340,13 +383,18 @@ object TestSettings { javaOptions in Test += "-Dspark.testing=1", javaOptions in Test += "-Dspark.port.maxRetries=100", javaOptions in Test += "-Dspark.ui.enabled=false", + javaOptions in Test += "-Dspark.ui.showConsoleProgress=false", + javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") .map { case (k,v) => s"-D$k=$v" }.toSeq, javaOptions in Test ++= "-Xmx3g -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g" .split(" ").toSeq, + // This places test scope jars on the classpath of executors during tests. 
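+ // (Without this, executors launched by tests -- e.g. in local-cluster mode -- would not be
+ // able to load classes that exist only on the test classpath.)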
+ javaOptions in Test += + "-Dspark.executor.extraClassPath=" + (fullClasspath in Test).value.files. + map(_.getAbsolutePath).mkString(":").stripSuffix(":"), javaOptions += "-Xmx3g", - // Show full stack trace and duration in test cases. testOptions in Test += Tests.Argument("-oDF"), testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"), diff --git a/project/build.properties b/project/build.properties index c12ef652adfcb..32a3aeefaf9fb 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=0.13.5 +sbt.version=0.13.6 diff --git a/project/plugins.sbt b/project/plugins.sbt index 8096c61414660..ee45b6a51905e 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -4,6 +4,8 @@ resolvers += Resolver.url("artifactory", url("http://scalasbt.artifactoryonline. resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/" +resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/" + addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.2.0") @@ -17,7 +19,7 @@ addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0") addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.4") -addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.4.0") +addSbtPlugin("org.scalastyle" %% "scalastyle-sbt-plugin" % "0.6.0") addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.1.6") diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala index 3ef2d5451da0d..8863f272da415 100644 --- a/project/project/SparkPluginBuild.scala +++ b/project/project/SparkPluginBuild.scala @@ -26,7 +26,7 @@ import sbt.Keys._ object SparkPluginDef extends Build { lazy val root = Project("plugins", file(".")) dependsOn(sparkStyle, sbtPomReader) lazy val sparkStyle = Project("spark-style", file("spark-style"), settings = styleSettings) - lazy val sbtPomReader = uri("https://github.com/ScrapCodes/sbt-pom-reader.git") + lazy val sbtPomReader = uri("https://github.com/ScrapCodes/sbt-pom-reader.git#ignore_artifact_id") // There is actually no need to publish this artifact. def styleSettings = Defaults.defaultSettings ++ Seq ( diff --git a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/SparkSpaceAfterCommentStartChecker.scala b/project/spark-style/src/main/scala/org/apache/spark/scalastyle/SparkSpaceAfterCommentStartChecker.scala deleted file mode 100644 index 80d3faa3fe749..0000000000000 --- a/project/spark-style/src/main/scala/org/apache/spark/scalastyle/SparkSpaceAfterCommentStartChecker.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -package org.apache.spark.scalastyle - -import java.util.regex.Pattern - -import org.scalastyle.{PositionError, ScalariformChecker, ScalastyleError} -import scalariform.lexer.{MultiLineComment, ScalaDocComment, SingleLineComment, Token} -import scalariform.parser.CompilationUnit - -class SparkSpaceAfterCommentStartChecker extends ScalariformChecker { - val errorKey: String = "insert.a.single.space.after.comment.start.and.before.end" - - private def multiLineCommentRegex(comment: Token) = - Pattern.compile( """/\*\S+.*""", Pattern.DOTALL).matcher(comment.text.trim).matches() || - Pattern.compile( """/\*.*\S\*/""", Pattern.DOTALL).matcher(comment.text.trim).matches() - - private def scalaDocPatternRegex(comment: Token) = - Pattern.compile( """/\*\*\S+.*""", Pattern.DOTALL).matcher(comment.text.trim).matches() || - Pattern.compile( """/\*\*.*\S\*/""", Pattern.DOTALL).matcher(comment.text.trim).matches() - - private def singleLineCommentRegex(comment: Token): Boolean = - comment.text.trim.matches( """//\S+.*""") && !comment.text.trim.matches( """///+""") - - override def verify(ast: CompilationUnit): List[ScalastyleError] = { - ast.tokens - .filter(hasComment) - .map { - _.associatedWhitespaceAndComments.comments.map { - case x: SingleLineComment if singleLineCommentRegex(x.token) => Some(x.token.offset) - case x: MultiLineComment if multiLineCommentRegex(x.token) => Some(x.token.offset) - case x: ScalaDocComment if scalaDocPatternRegex(x.token) => Some(x.token.offset) - case _ => None - }.flatten - }.flatten.map(PositionError(_)) - } - - - private def hasComment(x: Token) = - x.associatedWhitespaceAndComments != null && !x.associatedWhitespaceAndComments.comments.isEmpty - -} diff --git a/python/.gitignore b/python/.gitignore index 80b361ffbd51c..52128cf844a79 100644 --- a/python/.gitignore +++ b/python/.gitignore @@ -1,5 +1,5 @@ *.pyc -docs/ +docs/_build/ pyspark.egg-info build/ dist/ diff --git a/python/docs/conf.py b/python/docs/conf.py index c368cf81a003b..e58d97ae6a746 100644 --- a/python/docs/conf.py +++ b/python/docs/conf.py @@ -55,9 +55,9 @@ # built documents. # # The short X.Y version. -version = '1.1' +version = '1.2-SNAPSHOT' # The full version, including alpha/beta/rc tags. -release = '' +release = '1.2-SNAPSHOT' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. @@ -102,7 +102,7 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'default' +html_theme = 'nature' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the @@ -121,7 +121,7 @@ # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +html_logo = "../../docs/img/spark-logo-hd.png" # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 @@ -131,7 +131,7 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". 
-html_static_path = ['_static'] +#html_static_path = ['_static'] # Add any extra paths that contain custom files (such as robots.txt or # .htaccess) here, relative to this directory. These files are copied @@ -154,10 +154,10 @@ #html_additional_pages = {} # If false, no module index is generated. -#html_domain_indices = True +html_domain_indices = False # If false, no index is generated. -#html_use_index = True +html_use_index = False # If true, the index is split into individual pages for each letter. #html_split_index = False diff --git a/python/docs/epytext.py b/python/docs/epytext.py index 61d731bff570d..e884d5e6b19c7 100644 --- a/python/docs/epytext.py +++ b/python/docs/epytext.py @@ -1,11 +1,11 @@ import re RULES = ( - (r"<[\w.]+>", r""), + (r"<(!BLANKLINE)[\w.]+>", r""), (r"L{([\w.()]+)}", r":class:`\1`"), (r"[LC]{(\w+\.\w+)\(\)}", r":func:`\1`"), (r"C{([\w.()]+)}", r":class:`\1`"), - (r"[IBCM]{(.+)}", r"`\1`"), + (r"[IBCM]{([^}]+)}", r"`\1`"), ('pyspark.rdd.RDD', 'RDD'), ) diff --git a/python/docs/index.rst b/python/docs/index.rst index 25b3f9bd93e63..703bef644de28 100644 --- a/python/docs/index.rst +++ b/python/docs/index.rst @@ -3,7 +3,7 @@ You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. -Welcome to PySpark API reference! +Welcome to Spark Python API Docs! =================================== Contents: @@ -13,6 +13,7 @@ Contents: pyspark pyspark.sql + pyspark.streaming pyspark.mllib @@ -24,14 +25,12 @@ Core classes: Main entry point for Spark functionality. :class:`pyspark.RDD` - + A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Indices and tables ================== -* :ref:`genindex` -* :ref:`modindex` * :ref:`search` diff --git a/python/docs/make.bat b/python/docs/make.bat index adad44fd7536a..cc29acdc19686 100644 --- a/python/docs/make.bat +++ b/python/docs/make.bat @@ -1,242 +1,6 @@ -@ECHO OFF - -REM Command file for Sphinx documentation - -if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build -) -set BUILDDIR=_build -set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . -set I18NSPHINXOPTS=%SPHINXOPTS% . -if NOT "%PAPER%" == "" ( - set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% - set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% -) - -if "%1" == "" goto help - -if "%1" == "help" ( - :help - echo.Please use `make ^` where ^ is one of - echo. html to make standalone HTML files - echo. dirhtml to make HTML files named index.html in directories - echo. singlehtml to make a single large HTML file - echo. pickle to make pickle files - echo. json to make JSON files - echo. htmlhelp to make HTML files and a HTML help project - echo. qthelp to make HTML files and a qthelp project - echo. devhelp to make HTML files and a Devhelp project - echo. epub to make an epub - echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter - echo. text to make text files - echo. man to make manual pages - echo. texinfo to make Texinfo files - echo. gettext to make PO message catalogs - echo. changes to make an overview over all changed/added/deprecated items - echo. xml to make Docutils-native XML files - echo. pseudoxml to make pseudoxml-XML files for display purposes - echo. linkcheck to check all external links for integrity - echo. 
doctest to run all doctests embedded in the documentation if enabled - goto end -) - -if "%1" == "clean" ( - for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i - del /q /s %BUILDDIR%\* - goto end -) - - -%SPHINXBUILD% 2> nul -if errorlevel 9009 ( - echo. - echo.The 'sphinx-build' command was not found. Make sure you have Sphinx - echo.installed, then set the SPHINXBUILD environment variable to point - echo.to the full path of the 'sphinx-build' executable. Alternatively you - echo.may add the Sphinx directory to PATH. - echo. - echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ - exit /b 1 -) - -if "%1" == "html" ( - %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/html. - goto end -) - -if "%1" == "dirhtml" ( - %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. - goto end -) - -if "%1" == "singlehtml" ( - %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. - goto end -) - -if "%1" == "pickle" ( - %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can process the pickle files. - goto end -) - -if "%1" == "json" ( - %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can process the JSON files. - goto end -) - -if "%1" == "htmlhelp" ( - %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can run HTML Help Workshop with the ^ -.hhp project file in %BUILDDIR%/htmlhelp. - goto end -) - -if "%1" == "qthelp" ( - %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; now you can run "qcollectiongenerator" with the ^ -.qhcp project file in %BUILDDIR%/qthelp, like this: - echo.^> qcollectiongenerator %BUILDDIR%\qthelp\pyspark.qhcp - echo.To view the help file: - echo.^> assistant -collectionFile %BUILDDIR%\qthelp\pyspark.ghc - goto end -) - -if "%1" == "devhelp" ( - %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. - goto end -) - -if "%1" == "epub" ( - %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The epub file is in %BUILDDIR%/epub. - goto end -) - -if "%1" == "latex" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - if errorlevel 1 exit /b 1 - echo. - echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "latexpdf" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - cd %BUILDDIR%/latex - make all-pdf - cd %BUILDDIR%/.. - echo. - echo.Build finished; the PDF files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "latexpdfja" ( - %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex - cd %BUILDDIR%/latex - make all-pdf-ja - cd %BUILDDIR%/.. - echo. - echo.Build finished; the PDF files are in %BUILDDIR%/latex. - goto end -) - -if "%1" == "text" ( - %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The text files are in %BUILDDIR%/text. 
- goto end -) - -if "%1" == "man" ( - %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The manual pages are in %BUILDDIR%/man. - goto end -) - -if "%1" == "texinfo" ( - %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. - goto end -) - -if "%1" == "gettext" ( - %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The message catalogs are in %BUILDDIR%/locale. - goto end -) - -if "%1" == "changes" ( - %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes - if errorlevel 1 exit /b 1 - echo. - echo.The overview file is in %BUILDDIR%/changes. - goto end -) - -if "%1" == "linkcheck" ( - %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck - if errorlevel 1 exit /b 1 - echo. - echo.Link check complete; look for any errors in the above output ^ -or in %BUILDDIR%/linkcheck/output.txt. - goto end -) - -if "%1" == "doctest" ( - %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest - if errorlevel 1 exit /b 1 - echo. - echo.Testing of doctests in the sources finished, look at the ^ -results in %BUILDDIR%/doctest/output.txt. - goto end -) - -if "%1" == "xml" ( - %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The XML files are in %BUILDDIR%/xml. - goto end -) - -if "%1" == "pseudoxml" ( - %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml - if errorlevel 1 exit /b 1 - echo. - echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. - goto end -) - -:end +@ECHO OFF + +rem This is the entry point for running Sphinx documentation. To avoid polluting the +rem environment, it just launches a new cmd to do the real work. + +cmd /V /E /C %~dp0make2.bat %* diff --git a/python/docs/make2.bat b/python/docs/make2.bat new file mode 100644 index 0000000000000..05d22eb5cdd23 --- /dev/null +++ b/python/docs/make2.bat @@ -0,0 +1,243 @@ +@ECHO OFF + +REM Command file for Sphinx documentation + + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set BUILDDIR=_build +set ALLSPHINXOPTS=-d %BUILDDIR%/doctrees %SPHINXOPTS% . +set I18NSPHINXOPTS=%SPHINXOPTS% . +if NOT "%PAPER%" == "" ( + set ALLSPHINXOPTS=-D latex_paper_size=%PAPER% %ALLSPHINXOPTS% + set I18NSPHINXOPTS=-D latex_paper_size=%PAPER% %I18NSPHINXOPTS% +) + +if "%1" == "" goto help + +if "%1" == "help" ( + :help + echo.Please use `make ^` where ^ is one of + echo. html to make standalone HTML files + echo. dirhtml to make HTML files named index.html in directories + echo. singlehtml to make a single large HTML file + echo. pickle to make pickle files + echo. json to make JSON files + echo. htmlhelp to make HTML files and a HTML help project + echo. qthelp to make HTML files and a qthelp project + echo. devhelp to make HTML files and a Devhelp project + echo. epub to make an epub + echo. latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter + echo. text to make text files + echo. man to make manual pages + echo. texinfo to make Texinfo files + echo. gettext to make PO message catalogs + echo. changes to make an overview over all changed/added/deprecated items + echo. xml to make Docutils-native XML files + echo. pseudoxml to make pseudoxml-XML files for display purposes + echo. linkcheck to check all external links for integrity + echo. 
doctest to run all doctests embedded in the documentation if enabled + goto end +) + +if "%1" == "clean" ( + for /d %%i in (%BUILDDIR%\*) do rmdir /q /s %%i + del /q /s %BUILDDIR%\* + goto end +) + + +%SPHINXBUILD% 2> nul +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "html" ( + %SPHINXBUILD% -b html %ALLSPHINXOPTS% %BUILDDIR%/html + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/html. + goto end +) + +if "%1" == "dirhtml" ( + %SPHINXBUILD% -b dirhtml %ALLSPHINXOPTS% %BUILDDIR%/dirhtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/dirhtml. + goto end +) + +if "%1" == "singlehtml" ( + %SPHINXBUILD% -b singlehtml %ALLSPHINXOPTS% %BUILDDIR%/singlehtml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The HTML pages are in %BUILDDIR%/singlehtml. + goto end +) + +if "%1" == "pickle" ( + %SPHINXBUILD% -b pickle %ALLSPHINXOPTS% %BUILDDIR%/pickle + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the pickle files. + goto end +) + +if "%1" == "json" ( + %SPHINXBUILD% -b json %ALLSPHINXOPTS% %BUILDDIR%/json + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can process the JSON files. + goto end +) + +if "%1" == "htmlhelp" ( + %SPHINXBUILD% -b htmlhelp %ALLSPHINXOPTS% %BUILDDIR%/htmlhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run HTML Help Workshop with the ^ +.hhp project file in %BUILDDIR%/htmlhelp. + goto end +) + +if "%1" == "qthelp" ( + %SPHINXBUILD% -b qthelp %ALLSPHINXOPTS% %BUILDDIR%/qthelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; now you can run "qcollectiongenerator" with the ^ +.qhcp project file in %BUILDDIR%/qthelp, like this: + echo.^> qcollectiongenerator %BUILDDIR%\qthelp\pyspark.qhcp + echo.To view the help file: + echo.^> assistant -collectionFile %BUILDDIR%\qthelp\pyspark.ghc + goto end +) + +if "%1" == "devhelp" ( + %SPHINXBUILD% -b devhelp %ALLSPHINXOPTS% %BUILDDIR%/devhelp + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. + goto end +) + +if "%1" == "epub" ( + %SPHINXBUILD% -b epub %ALLSPHINXOPTS% %BUILDDIR%/epub + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The epub file is in %BUILDDIR%/epub. + goto end +) + +if "%1" == "latex" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + if errorlevel 1 exit /b 1 + echo. + echo.Build finished; the LaTeX files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "latexpdf" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + cd %BUILDDIR%/latex + make all-pdf + cd %BUILDDIR%/.. + echo. + echo.Build finished; the PDF files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "latexpdfja" ( + %SPHINXBUILD% -b latex %ALLSPHINXOPTS% %BUILDDIR%/latex + cd %BUILDDIR%/latex + make all-pdf-ja + cd %BUILDDIR%/.. + echo. + echo.Build finished; the PDF files are in %BUILDDIR%/latex. + goto end +) + +if "%1" == "text" ( + %SPHINXBUILD% -b text %ALLSPHINXOPTS% %BUILDDIR%/text + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The text files are in %BUILDDIR%/text. 
+ goto end +) + +if "%1" == "man" ( + %SPHINXBUILD% -b man %ALLSPHINXOPTS% %BUILDDIR%/man + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The manual pages are in %BUILDDIR%/man. + goto end +) + +if "%1" == "texinfo" ( + %SPHINXBUILD% -b texinfo %ALLSPHINXOPTS% %BUILDDIR%/texinfo + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The Texinfo files are in %BUILDDIR%/texinfo. + goto end +) + +if "%1" == "gettext" ( + %SPHINXBUILD% -b gettext %I18NSPHINXOPTS% %BUILDDIR%/locale + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The message catalogs are in %BUILDDIR%/locale. + goto end +) + +if "%1" == "changes" ( + %SPHINXBUILD% -b changes %ALLSPHINXOPTS% %BUILDDIR%/changes + if errorlevel 1 exit /b 1 + echo. + echo.The overview file is in %BUILDDIR%/changes. + goto end +) + +if "%1" == "linkcheck" ( + %SPHINXBUILD% -b linkcheck %ALLSPHINXOPTS% %BUILDDIR%/linkcheck + if errorlevel 1 exit /b 1 + echo. + echo.Link check complete; look for any errors in the above output ^ +or in %BUILDDIR%/linkcheck/output.txt. + goto end +) + +if "%1" == "doctest" ( + %SPHINXBUILD% -b doctest %ALLSPHINXOPTS% %BUILDDIR%/doctest + if errorlevel 1 exit /b 1 + echo. + echo.Testing of doctests in the sources finished, look at the ^ +results in %BUILDDIR%/doctest/output.txt. + goto end +) + +if "%1" == "xml" ( + %SPHINXBUILD% -b xml %ALLSPHINXOPTS% %BUILDDIR%/xml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The XML files are in %BUILDDIR%/xml. + goto end +) + +if "%1" == "pseudoxml" ( + %SPHINXBUILD% -b pseudoxml %ALLSPHINXOPTS% %BUILDDIR%/pseudoxml + if errorlevel 1 exit /b 1 + echo. + echo.Build finished. The pseudo-XML files are in %BUILDDIR%/pseudoxml. + goto end +) + +:end diff --git a/python/docs/modules.rst b/python/docs/modules.rst deleted file mode 100644 index 183564659fbcf..0000000000000 --- a/python/docs/modules.rst +++ /dev/null @@ -1,7 +0,0 @@ -. -= - -.. toctree:: - :maxdepth: 4 - - pyspark diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst index e95d19e97f151..4548b8739ed91 100644 --- a/python/docs/pyspark.mllib.rst +++ b/python/docs/pyspark.mllib.rst @@ -20,6 +20,14 @@ pyspark.mllib.clustering module :undoc-members: :show-inheritance: +pyspark.mllib.feature module +------------------------------- + +.. automodule:: pyspark.mllib.feature + :members: + :undoc-members: + :show-inheritance: + pyspark.mllib.linalg module --------------------------- diff --git a/python/docs/pyspark.rst b/python/docs/pyspark.rst index a68bd62433085..e81be3b6cb796 100644 --- a/python/docs/pyspark.rst +++ b/python/docs/pyspark.rst @@ -7,8 +7,9 @@ Subpackages .. toctree:: :maxdepth: 1 - pyspark.mllib pyspark.sql + pyspark.streaming + pyspark.mllib Contents -------- diff --git a/python/docs/pyspark.streaming.rst b/python/docs/pyspark.streaming.rst new file mode 100644 index 0000000000000..5024d694b668f --- /dev/null +++ b/python/docs/pyspark.streaming.rst @@ -0,0 +1,10 @@ +pyspark.streaming module +================== + +Module contents +--------------- + +.. automodule:: pyspark.streaming + :members: + :undoc-members: + :show-inheritance: diff --git a/python/epydoc.conf b/python/epydoc.conf deleted file mode 100644 index 8593e08deda19..0000000000000 --- a/python/epydoc.conf +++ /dev/null @@ -1,38 +0,0 @@ -[epydoc] # Epydoc section marker (required by ConfigParser) - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. 
See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Information about the project. -name: Spark 1.0.0 Python API Docs -url: http://spark.apache.org - -# The list of modules to document. Modules can be named using -# dotted names, module filenames, or package directory names. -# This option may be repeated. -modules: pyspark - -# Write html output to the directory "apidocs" -output: html -target: docs/ - -private: no - -exclude: pyspark.cloudpickle pyspark.worker pyspark.join - pyspark.java_gateway pyspark.examples pyspark.shell pyspark.tests - pyspark.rddsampler pyspark.daemon - pyspark.mllib.tests pyspark.shuffle diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 1a2e774738fe7..9556e4718e585 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -20,45 +20,23 @@ Public classes: - - L{SparkContext} + - :class:`SparkContext`: Main entry point for Spark functionality. - - L{RDD} + - L{RDD} A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. - - L{Broadcast} + - L{Broadcast} A broadcast variable that gets reused across tasks. - - L{Accumulator} + - L{Accumulator} An "add-only" shared variable that tasks can only add values to. - - L{SparkConf} + - L{SparkConf} For configuring Spark. - - L{SparkFiles} + - L{SparkFiles} Access files shipped with jobs. - - L{StorageLevel} + - L{StorageLevel} Finer-grained cache persistence levels. -Spark SQL: - - L{SQLContext} - Main entry point for SQL functionality. - - L{SchemaRDD} - A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In - addition to normal RDD operations, SchemaRDDs also support SQL. - - L{Row} - A Row of data returned by a Spark SQL query. - -Hive: - - L{HiveContext} - Main entry point for accessing data stored in Apache Hive.. """ -# The following block allows us to import python's random instead of mllib.random for scripts in -# mllib that depend on top level pyspark packages, which transitively depend on python's random. -# Since Python's import logic looks for modules in the current package first, we eliminate -# mllib.random as a candidate for C{import random} by removing the first search path, the script's -# location, in order to force the loader to look in Python's top-level modules for C{random}. -import sys -s = sys.path.pop(0) -import random -sys.path.insert(0, s) - from pyspark.conf import SparkConf from pyspark.context import SparkContext from pyspark.rdd import RDD diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index f124dc6c07575..6b8a8b256a891 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -15,21 +15,10 @@ # limitations under the License. 
# -""" ->>> from pyspark.context import SparkContext ->>> sc = SparkContext('local', 'test') ->>> b = sc.broadcast([1, 2, 3, 4, 5]) ->>> b.value -[1, 2, 3, 4, 5] ->>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() -[1, 2, 3, 4, 5, 1, 2, 3, 4, 5] ->>> b.unpersist() - ->>> large_broadcast = sc.broadcast(list(range(10000))) -""" import os - -from pyspark.serializers import CompressedSerializer, PickleSerializer +import cPickle +import gc +from tempfile import NamedTemporaryFile __all__ = ['Broadcast'] @@ -49,44 +38,88 @@ def _from_id(bid): class Broadcast(object): """ - A broadcast variable created with - L{SparkContext.broadcast()}. + A broadcast variable created with L{SparkContext.broadcast()}. Access its value through C{.value}. + + Examples: + + >>> from pyspark.context import SparkContext + >>> sc = SparkContext('local', 'test') + >>> b = sc.broadcast([1, 2, 3, 4, 5]) + >>> b.value + [1, 2, 3, 4, 5] + >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect() + [1, 2, 3, 4, 5, 1, 2, 3, 4, 5] + >>> b.unpersist() + + >>> large_broadcast = sc.broadcast(range(10000)) """ - def __init__(self, bid, value, java_broadcast=None, - pickle_registry=None, path=None): + def __init__(self, sc=None, value=None, pickle_registry=None, path=None): """ - Should not be called directly by users -- use - L{SparkContext.broadcast()} + Should not be called directly by users -- use L{SparkContext.broadcast()} instead. """ - self.bid = bid - if path is None: - self._value = value - self._jbroadcast = java_broadcast - self._pickle_registry = pickle_registry - self.path = path + if sc is not None: + f = NamedTemporaryFile(delete=False, dir=sc._temp_dir) + self._path = self.dump(value, f) + self._jbroadcast = sc._jvm.PythonRDD.readBroadcastFromFile(sc._jsc, self._path) + self._pickle_registry = pickle_registry + else: + self._jbroadcast = None + self._path = path + + def dump(self, value, f): + if isinstance(value, basestring): + if isinstance(value, unicode): + f.write('U') + value = value.encode('utf8') + else: + f.write('S') + f.write(value) + else: + f.write('P') + cPickle.dump(value, f, 2) + f.close() + return f.name + + def load(self, path): + with open(path, 'rb', 1 << 20) as f: + flag = f.read(1) + data = f.read() + if flag == 'P': + # cPickle.loads() may create lots of objects, disable GC + # temporary for better performance + gc.disable() + try: + return cPickle.loads(data) + finally: + gc.enable() + else: + return data.decode('utf8') if flag == 'U' else data @property def value(self): """ Return the broadcasted value """ - if not hasattr(self, "_value") and self.path is not None: - ser = CompressedSerializer(PickleSerializer()) - self._value = ser.load_stream(open(self.path)).next() + if not hasattr(self, "_value") and self._path is not None: + self._value = self.load(self._path) return self._value def unpersist(self, blocking=False): """ Delete cached copies of this broadcast on the executors. 
""" + if self._jbroadcast is None: + raise Exception("Broadcast can only be unpersisted in driver") self._jbroadcast.unpersist(blocking) - os.unlink(self.path) + os.unlink(self._path) def __reduce__(self): + if self._jbroadcast is None: + raise Exception("Broadcast can only be serialized in driver") self._pickle_registry.add(self) - return (_from_id, (self.bid, )) + return _from_id, (self._jbroadcast.id(),) if __name__ == "__main__": diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index b64875a3f495a..dc7cd0bce56f3 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -83,11 +83,11 @@ def __init__(self, loadDefaults=True, _jvm=None, _jconf=None): """ Create a new Spark configuration. - @param loadDefaults: whether to load values from Java system + :param loadDefaults: whether to load values from Java system properties (True by default) - @param _jvm: internal parameter used to pass a handle to the + :param _jvm: internal parameter used to pass a handle to the Java VM; does not need to be set by users - @param _jconf: Optionally pass in an existing SparkConf handle + :param _jconf: Optionally pass in an existing SparkConf handle to use its parameters """ if _jconf: @@ -139,7 +139,7 @@ def setAll(self, pairs): """ Set multiple parameters, passed as a list of key-value pairs. - @param pairs: list of key-value pairs to set + :param pairs: list of key-value pairs to set """ for (k, v) in pairs: self._jconf.set(k, v) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index e9418320ff781..ed7351d60cff2 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -29,7 +29,7 @@ from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ - PairDeserializer, CompressedSerializer + PairDeserializer, AutoBatchedSerializer, NoOpSerializer from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD from pyspark.traceback_utils import CallSite, first_spark_call @@ -43,7 +43,6 @@ # These are special default configs for PySpark, they will overwrite # the default ones for Spark if they are not configured by user. DEFAULT_CONFIGS = { - "spark.serializer": "org.apache.spark.serializer.KryoSerializer", "spark.serializer.objectStreamReset": 100, "spark.rdd.compress": True, } @@ -64,30 +63,30 @@ class SparkContext(object): _active_spark_context = None _lock = Lock() _python_includes = None # zip and egg files that need to be added to PYTHONPATH - _default_batch_size_for_serialized_input = 10 def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, - environment=None, batchSize=1024, serializer=PickleSerializer(), conf=None, - gateway=None): + environment=None, batchSize=0, serializer=PickleSerializer(), conf=None, + gateway=None, jsc=None): """ Create a new SparkContext. At least the master and app name should be set, either through the named parameters here or through C{conf}. - @param master: Cluster URL to connect to + :param master: Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). - @param appName: A name for your job, to display on the cluster web UI. - @param sparkHome: Location where Spark is installed on cluster nodes. - @param pyFiles: Collection of .zip or .py files to send to the cluster + :param appName: A name for your job, to display on the cluster web UI. + :param sparkHome: Location where Spark is installed on cluster nodes. 
+ :param pyFiles: Collection of .zip or .py files to send to the cluster and add to PYTHONPATH. These can be paths on the local file system or HDFS, HTTP, HTTPS, or FTP URLs. - @param environment: A dictionary of environment variables to set on + :param environment: A dictionary of environment variables to set on worker nodes. - @param batchSize: The number of Python objects represented as a single - Java object. Set 1 to disable batching or -1 to use an - unlimited batch size. - @param serializer: The serializer for RDDs. - @param conf: A L{SparkConf} object setting Spark properties. - @param gateway: Use an existing gateway and JVM, otherwise a new JVM + :param batchSize: The number of Python objects represented as a single + Java object. Set 1 to disable batching, 0 to automatically choose + the batch size based on object sizes, or -1 to use an unlimited + batch size + :param serializer: The serializer for RDDs. + :param conf: A L{SparkConf} object setting Spark properties. + :param gateway: Use an existing gateway and JVM, otherwise a new JVM will be instantiated. @@ -103,20 +102,20 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, SparkContext._ensure_initialized(self, gateway=gateway) try: self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer, - conf) + conf, jsc) except: # If an error occurs, clean up in order to allow future SparkContext creation: self.stop() raise def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer, - conf): + conf, jsc): self.environment = environment or {} self._conf = conf or SparkConf(_jvm=self._jvm) self._batchSize = batchSize # -1 represents an unlimited batch size self._unbatched_serializer = serializer - if batchSize == 1: - self.serializer = self._unbatched_serializer + if batchSize == 0: + self.serializer = AutoBatchedSerializer(self._unbatched_serializer) else: self.serializer = BatchedSerializer(self._unbatched_serializer, batchSize) @@ -151,7 +150,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self.environment[varName] = v # Create the Java SparkContext through Py4J - self._jsc = self._initialize_context(self._conf._jconf) + self._jsc = jsc or self._initialize_context(self._conf._jconf) # Create a single Accumulator in Java that we'll send all our updates through; # they will be passed back to us through a TCP server @@ -212,8 +211,6 @@ def _ensure_initialized(cls, instance=None, gateway=None): SparkContext._gateway = gateway or launch_gateway() SparkContext._jvm = SparkContext._gateway.jvm SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile - SparkContext._jvm.SerDeUtil.initialize() - SparkContext._jvm.SerDe.initialize() if instance: if (SparkContext._active_spark_context and @@ -292,12 +289,29 @@ def stop(self): def parallelize(self, c, numSlices=None): """ - Distribute a local Python collection to form an RDD. + Distribute a local Python collection to form an RDD. Using xrange + is recommended if the input represents a range for performance. 
- >>> sc.parallelize(range(5), 5).glom().collect() - [[0], [1], [2], [3], [4]] + >>> sc.parallelize([0, 2, 3, 4, 6], 5).glom().collect() + [[0], [2], [3], [4], [6]] + >>> sc.parallelize(xrange(0, 6, 2), 5).glom().collect() + [[], [0], [], [2], [4]] """ - numSlices = numSlices or self.defaultParallelism + numSlices = int(numSlices) if numSlices is not None else self.defaultParallelism + if isinstance(c, xrange): + size = len(c) + if size == 0: + return self.parallelize([], numSlices) + step = c[1] - c[0] if size > 1 else 1 + start0 = c[0] + + def getStart(split): + return start0 + (split * size / numSlices) * step + + def f(split, iterator): + return xrange(getStart(split), getStart(split + 1), step) + + return self.parallelize([], numSlices).mapPartitionsWithIndex(f) # Calling the Java parallelize() method with an ArrayList is too slow, # because it sends O(n) Py4J commands. As an alternative, serialized # objects are written to a file and loaded through textFile(). @@ -305,12 +319,8 @@ def parallelize(self, c, numSlices=None): # Make sure we distribute data evenly if it's smaller than self.batchSize if "__len__" not in dir(c): c = list(c) # Make it a list so we can compute its length - batchSize = min(len(c) // numSlices, self._batchSize) - if batchSize > 1: - serializer = BatchedSerializer(self._unbatched_serializer, - batchSize) - else: - serializer = self._unbatched_serializer + batchSize = max(1, min(len(c) // numSlices, self._batchSize)) + serializer = BatchedSerializer(self._unbatched_serializer, batchSize) serializer.dump_stream(c, tempFile) tempFile.close() readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile @@ -328,8 +338,7 @@ def pickleFile(self, name, minPartitions=None): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ minPartitions = minPartitions or self.defaultMinPartitions - return RDD(self._jsc.objectFile(name, minPartitions), self, - BatchedSerializer(PickleSerializer())) + return RDD(self._jsc.objectFile(name, minPartitions), self) def textFile(self, name, minPartitions=None, use_unicode=True): """ @@ -396,6 +405,36 @@ def wholeTextFiles(self, path, minPartitions=None, use_unicode=True): return RDD(self._jsc.wholeTextFiles(path, minPartitions), self, PairDeserializer(UTF8Deserializer(use_unicode), UTF8Deserializer(use_unicode))) + def binaryFiles(self, path, minPartitions=None): + """ + :: Experimental :: + + Read a directory of binary files from HDFS, a local file system + (available on all nodes), or any Hadoop-supported file system URI + as a byte array. Each file is read as a single record and returned + in a key-value pair, where the key is the path of each file, the + value is the content of each file. + + Note: Small files are preferred, large file is also allowable, but + may cause bad performance. + """ + minPartitions = minPartitions or self.defaultMinPartitions + return RDD(self._jsc.binaryFiles(path, minPartitions), self, + PairDeserializer(UTF8Deserializer(), NoOpSerializer())) + + def binaryRecords(self, path, recordLength): + """ + :: Experimental :: + + Load data from a flat binary file, assuming each record is a set of numbers + with the specified numerical format (see ByteBuffer), and the number of + bytes per record is constant. 
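The xrange fast path added to parallelize() above never materializes the range on the driver: each partition rebuilds its own slice from the split index, so nothing is serialized at all. A standalone Python 2 sketch of that slicing arithmetic (the empty-range guard and the Spark plumbing are omitted; slices is an invented name), reproducing the doctest output:

def slices(c, numSlices):
    # mirrors the getStart()/f() pair inside parallelize(); Python 2 integer
    # division is assumed, as in the patch
    size = len(c)
    step = c[1] - c[0] if size > 1 else 1
    start0 = c[0]

    def getStart(split):
        return start0 + (split * size / numSlices) * step

    return [list(xrange(getStart(i), getStart(i + 1), step))
            for i in range(numSlices)]

print(slices(xrange(0, 6, 2), 5))  # [[], [0], [], [2], [4]], matching the doctest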
+ + :param path: Directory to the input data files + :param recordLength: The length at which to split the records + """ + return RDD(self._jsc.binaryRecords(path, recordLength), self, NoOpSerializer()) + def _dictToJavaMap(self, d): jm = self._jvm.java.util.HashMap() if not d: @@ -405,38 +444,37 @@ def _dictToJavaMap(self, d): return jm def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, - valueConverter=None, minSplits=None, batchSize=None): + valueConverter=None, minSplits=None, batchSize=0): """ Read a Hadoop SequenceFile with arbitrary key and value Writable class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. The mechanism is as follows: + 1. A Java RDD is created from the SequenceFile or other InputFormat, and the key and value Writable classes 2. Serialization is attempted via Pyrolite pickling 3. If this fails, the fallback is to call 'toString' on each key and value 4. C{PickleSerializer} is used to deserialize pickled objects on the Python side - @param path: path to sequncefile - @param keyClass: fully qualified classname of key Writable class + :param path: path to sequncefile + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.LongWritable") - @param keyConverter: - @param valueConverter: - @param minSplits: minimum splits in dataset + :param keyConverter: + :param valueConverter: + :param minSplits: minimum splits in dataset (default min(2, sc.defaultParallelism)) - @param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + :param batchSize: The number of Python objects represented as a single + Java object. (default 0, choose batchSize automatically) """ minSplits = minSplits or min(self.defaultParallelism, 2) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.sequenceFile(self._jsc, path, keyClass, valueClass, keyConverter, valueConverter, minSplits, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read a 'new API' Hadoop InputFormat with arbitrary key and value class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. @@ -445,59 +483,55 @@ def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConv A Hadoop configuration can be passed in as a Python dict. This will be converted into a Configuration in Java - @param path: path to Hadoop file - @param inputFormatClass: fully qualified classname of Hadoop InputFormat + :param path: path to Hadoop file + :param inputFormatClass: fully qualified classname of Hadoop InputFormat (e.g. "org.apache.hadoop.mapreduce.lib.input.TextInputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. 
"org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.LongWritable") - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop configuration, passed in as a dict (None by default) - @param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + :param batchSize: The number of Python objects represented as a single + Java object. (default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopFile(self._jsc, path, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read a 'new API' Hadoop InputFormat with arbitrary key and value class, from an arbitrary Hadoop configuration, which is passed in as a Python dict. This will be converted into a Configuration in Java. The mechanism is the same as for sc.sequenceFile. - @param inputFormatClass: fully qualified classname of Hadoop InputFormat + :param inputFormatClass: fully qualified classname of Hadoop InputFormat (e.g. "org.apache.hadoop.mapreduce.lib.input.TextInputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.LongWritable") - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop configuration, passed in as a dict (None by default) - @param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + :param batchSize: The number of Python objects represented as a single + Java object. (default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.newAPIHadoopRDD(self._jsc, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read an 'old' Hadoop InputFormat with arbitrary key and value class from HDFS, a local file system (available on all nodes), or any Hadoop-supported file system URI. 
@@ -506,56 +540,52 @@ def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter= A Hadoop configuration can be passed in as a Python dict. This will be converted into a Configuration in Java. - @param path: path to Hadoop file - @param inputFormatClass: fully qualified classname of Hadoop InputFormat + :param path: path to Hadoop file + :param inputFormatClass: fully qualified classname of Hadoop InputFormat (e.g. "org.apache.hadoop.mapred.TextInputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.LongWritable") - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop configuration, passed in as a dict (None by default) - @param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + :param batchSize: The number of Python objects represented as a single + Java object. (default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.hadoopFile(self._jsc, path, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, - valueConverter=None, conf=None, batchSize=None): + valueConverter=None, conf=None, batchSize=0): """ Read an 'old' Hadoop InputFormat with arbitrary key and value class, from an arbitrary Hadoop configuration, which is passed in as a Python dict. This will be converted into a Configuration in Java. The mechanism is the same as for sc.sequenceFile. - @param inputFormatClass: fully qualified classname of Hadoop InputFormat + :param inputFormatClass: fully qualified classname of Hadoop InputFormat (e.g. "org.apache.hadoop.mapred.TextInputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.Text") - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.LongWritable") - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop configuration, passed in as a dict (None by default) - @param batchSize: The number of Python objects represented as a single - Java object. (default sc._default_batch_size_for_serialized_input) + :param batchSize: The number of Python objects represented as a single + Java object. 
(default 0, choose batchSize automatically) """ jconf = self._dictToJavaMap(conf) - batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input) - ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer() jrdd = self._jvm.PythonRDD.hadoopRDD(self._jsc, inputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf, batchSize) - return RDD(jrdd, self, ser) + return RDD(jrdd, self) def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) @@ -594,14 +624,7 @@ def broadcast(self, value): object for reading it in distributed functions. The variable will be sent to each cluster only once. """ - ser = CompressedSerializer(PickleSerializer()) - # pass large object by py4j is very slow and need much memory - tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir) - ser.dump_stream([value], tempFile) - tempFile.close() - jbroadcast = self._jvm.PythonRDD.readBroadcastFromFile(self._jsc, tempFile.name) - return Broadcast(jbroadcast.id(), None, jbroadcast, - self._pickled_broadcast_vars, tempFile.name) + return Broadcast(self, value, self._pickled_broadcast_vars) def accumulator(self, value, accum_param=None): """ @@ -835,7 +858,7 @@ def _test(): import doctest import tempfile globs = globals().copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['sc'] = SparkContext('local[4]', 'PythonTest') globs['tempdir'] = tempfile.mkdtemp() atexit.register(lambda: shutil.rmtree(globs['tempdir'])) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 64d6202acb27d..f09587f211708 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -26,7 +26,7 @@ import gc from errno import EINTR, ECHILD, EAGAIN from socket import AF_INET, SOCK_STREAM, SOMAXCONN -from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN +from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT from pyspark.worker import main as worker_main from pyspark.serializers import read_int, write_int @@ -46,6 +46,9 @@ def worker(sock): signal.signal(SIGHUP, SIG_DFL) signal.signal(SIGCHLD, SIG_DFL) signal.signal(SIGTERM, SIG_DFL) + # restore the handler for SIGINT, + # it's useful for debugging (show the stacktrace before exit) + signal.signal(SIGINT, signal.default_int_handler) # Read the socket using fdopen instead of socket.makefile() because the latter # seems to be very slow; note that we need to dup() the file descriptor because @@ -59,8 +62,7 @@ def worker(sock): exit_code = compute_real_exit_code(exc.code) finally: outfile.flush() - if exit_code: - os._exit(exit_code) + return exit_code # Cleanup zombie children @@ -157,10 +159,13 @@ def handle_sigterm(*args): outfile.flush() outfile.close() while True: - worker(sock) - if not reuse: + code = worker(sock) + if not reuse or code: # wait for closing - while sock.recv(1024): + try: + while sock.recv(1024): + pass + except Exception: pass break gc.collect() diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 9c70fa5c16d0c..a975dc19cb78e 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -45,7 +45,9 @@ def launch_gateway(): # Don't send ctrl-c / SIGINT to the Java gateway: def preexec_func(): signal.signal(signal.SIGINT, signal.SIG_IGN) - proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func) + env = dict(os.environ) + env["IS_SUBPROCESS"] = "1" # tell JVM 
to exit after python exits + proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func, env=env) else: # preexec_fn not supported on Windows proc = Popen(command, stdout=PIPE, stdin=PIPE) diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index 4149f54931d1f..5030a655fcbba 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -24,3 +24,37 @@ import numpy if numpy.version.version < '1.4': raise Exception("MLlib requires NumPy 1.4+") + +__all__ = ['classification', 'clustering', 'feature', 'linalg', 'random', + 'recommendation', 'regression', 'stat', 'tree', 'util'] + +import sys +import rand as random +random.__name__ = 'random' +random.RandomRDDs.__module__ = __name__ + '.random' + + +class RandomModuleHook(object): + """ + Hook to import pyspark.mllib.random + """ + fullname = __name__ + '.random' + + def find_module(self, name, path=None): + # skip all other modules + if not name.startswith(self.fullname): + return + return self + + def load_module(self, name): + if name == self.fullname: + return random + + cname = name.rsplit('.', 1)[-1] + try: + return getattr(random, cname) + except AttributeError: + raise ImportError + + +sys.meta_path.append(RandomModuleHook()) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index ac142fb49a90c..f14d0ed11cbbb 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -20,96 +20,200 @@ import numpy from numpy import array -from pyspark import SparkContext, PickleSerializer +from pyspark import RDD +from pyspark.mllib.common import callMLlibFunc from pyspark.mllib.linalg import SparseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper -__all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'SVMModel', - 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes'] +__all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegressionWithLBFGS', + 'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes'] -class LogisticRegressionModel(LinearModel): +class LinearBinaryClassificationModel(LinearModel): + """ + Represents a linear binary classification model that predicts to whether an + example is positive (1.0) or negative (0.0). + """ + def __init__(self, weights, intercept): + super(LinearBinaryClassificationModel, self).__init__(weights, intercept) + self._threshold = None + + def setThreshold(self, value): + """ + :: Experimental :: + + Sets the threshold that separates positive predictions from negative + predictions. An example with prediction score greater than or equal + to this threshold is identified as an positive, and negative otherwise. + """ + self._threshold = value + + def clearThreshold(self): + """ + :: Experimental :: + + Clears the threshold so that `predict` will output raw prediction scores. + """ + self._threshold = None + + def predict(self, test): + """ + Predict values for a single data point or an RDD of points using + the model trained. + """ + raise NotImplementedError + + +class LogisticRegressionModel(LinearBinaryClassificationModel): """A linear binary classification model derived from logistic regression. >>> data = [ - ... LabeledPoint(0.0, [0.0]), - ... LabeledPoint(1.0, [1.0]), - ... LabeledPoint(1.0, [2.0]), - ... LabeledPoint(1.0, [3.0]) + ... LabeledPoint(0.0, [0.0, 1.0]), + ... LabeledPoint(1.0, [1.0, 0.0]), ... 
] >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(data)) - >>> lrm.predict(array([1.0])) > 0 - True - >>> lrm.predict(array([0.0])) <= 0 - True + >>> lrm.predict([1.0, 0.0]) + 1 + >>> lrm.predict([0.0, 1.0]) + 0 + >>> lrm.predict(sc.parallelize([[1.0, 0.0], [0.0, 1.0]])).collect() + [1, 0] + >>> lrm.clearThreshold() + >>> lrm.predict([0.0, 1.0]) + 0.123... + >>> sparse_data = [ ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})), - ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), + ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})), ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) ... ] >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(sparse_data)) - >>> lrm.predict(array([0.0, 1.0])) > 0 - True - >>> lrm.predict(array([0.0, 0.0])) <= 0 - True - >>> lrm.predict(SparseVector(2, {1: 1.0})) > 0 - True - >>> lrm.predict(SparseVector(2, {1: 0.0})) <= 0 - True + >>> lrm.predict(array([0.0, 1.0])) + 1 + >>> lrm.predict(array([1.0, 0.0])) + 0 + >>> lrm.predict(SparseVector(2, {1: 1.0})) + 1 + >>> lrm.predict(SparseVector(2, {0: 1.0})) + 0 """ + def __init__(self, weights, intercept): + super(LogisticRegressionModel, self).__init__(weights, intercept) + self._threshold = 0.5 def predict(self, x): + """ + Predict values for a single data point or an RDD of points using + the model trained. + """ + if isinstance(x, RDD): + return x.map(lambda v: self.predict(v)) + + x = _convert_to_vector(x) margin = self.weights.dot(x) + self._intercept if margin > 0: prob = 1 / (1 + exp(-margin)) else: exp_margin = exp(margin) prob = exp_margin / (1 + exp_margin) - return 1 if prob > 0.5 else 0 + if self._threshold is None: + return prob + else: + return 1 if prob > self._threshold else 0 class LogisticRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, - initialWeights=None, regParam=1.0, regType="none", intercept=False): + initialWeights=None, regParam=0.01, regType="l2", intercept=False): """ Train a logistic regression model on the given data. - @param data: The training data. - @param iterations: The number of iterations (default: 100). - @param step: The step parameter used in SGD + :param data: The training data, an RDD of LabeledPoint. + :param iterations: The number of iterations (default: 100). + :param step: The step parameter used in SGD (default: 1.0). - @param miniBatchFraction: Fraction of data to be used for each SGD + :param miniBatchFraction: Fraction of data to be used for each SGD iteration. - @param initialWeights: The initial weights (default: None). - @param regParam: The regularizer parameter (default: 1.0). - @param regType: The type of regularizer used for training + :param initialWeights: The initial weights (default: None). + :param regParam: The regularizer parameter (default: 0.01). + :param regType: The type of regularizer used for training our model. - Allowed values: "l1" for using L1Updater, - "l2" for using - SquaredL2Updater, - "none" for no regularizer. - (default: "none") - @param intercept: Boolean parameter which indicates the use + + :Allowed values: + - "l1" for using L1 regularization + - "l2" for using L2 regularization + - None for no regularization + + (default: "l2") + + :param intercept: Boolean parameter which indicates the use or not of the augmented representation for training data (i.e. whether bias features are activated or not). 
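With the LinearBinaryClassificationModel plumbing above, predict() has two modes: a 0/1 label while a threshold is set (0.5 by default for logistic regression) and the raw probability after clearThreshold(). A small sketch of the logistic case with made-up weights, mirroring the margin/sigmoid/threshold logic of LogisticRegressionModel.predict():

from math import exp

def predict(weights, intercept, x, threshold=0.5):
    # margin = w . x + b, then a numerically stable sigmoid, as in predict()
    margin = sum(w * xi for w, xi in zip(weights, x)) + intercept
    if margin > 0:
        prob = 1.0 / (1 + exp(-margin))
    else:
        prob = exp(margin) / (1 + exp(margin))
    if threshold is None:          # clearThreshold() behaviour
        return prob
    return 1 if prob > threshold else 0

print(predict([2.0, -1.0], 0.0, [1.0, 0.0]))                  # 1
print(predict([2.0, -1.0], 0.0, [1.0, 0.0], threshold=None))  # ~0.88 raw score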
""" - sc = data.context + def train(rdd, i): + return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations), + float(step), float(miniBatchFraction), i, float(regParam), regType, + bool(intercept)) + + return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) + + +class LogisticRegressionWithLBFGS(object): + + @classmethod + def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2", + intercept=False, corrections=10, tolerance=1e-4): + """ + Train a logistic regression model on the given data. - def train(jdata, i): - return sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD( - jdata, iterations, step, miniBatchFraction, i, regParam, regType, intercept) + :param data: The training data, an RDD of LabeledPoint. + :param iterations: The number of iterations (default: 100). + :param initialWeights: The initial weights (default: None). + :param regParam: The regularizer parameter (default: 0.01). + :param regType: The type of regularizer used for training + our model. + + :Allowed values: + - "l1" for using L1 regularization + - "l2" for using L2 regularization + - None for no regularization + + (default: "l2") + + :param intercept: Boolean parameter which indicates the use + or not of the augmented representation for + training data (i.e. whether bias features + are activated or not). + :param corrections: The number of corrections used in the LBFGS + update (default: 10). + :param tolerance: The convergence tolerance of iterations for + L-BFGS (default: 1e-4). + + >>> data = [ + ... LabeledPoint(0.0, [0.0, 1.0]), + ... LabeledPoint(1.0, [1.0, 0.0]), + ... ] + >>> lrm = LogisticRegressionWithLBFGS.train(sc.parallelize(data)) + >>> lrm.predict([1.0, 0.0]) + 1 + >>> lrm.predict([0.0, 1.0]) + 0 + """ + def train(rdd, i): + return callMLlibFunc("trainLogisticRegressionModelWithLBFGS", rdd, int(iterations), i, + float(regParam), str(regType), bool(intercept), int(corrections), + float(tolerance)) - return _regression_train_wrapper(sc, train, LogisticRegressionModel, data, - initialWeights) + return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) -class SVMModel(LinearModel): +class SVMModel(LinearBinaryClassificationModel): """A support vector machine. @@ -120,8 +224,14 @@ class SVMModel(LinearModel): ... LabeledPoint(1.0, [3.0]) ... ] >>> svm = SVMWithSGD.train(sc.parallelize(data)) - >>> svm.predict(array([1.0])) > 0 - True + >>> svm.predict([1.0]) + 1 + >>> svm.predict(sc.parallelize([[1.0]])).collect() + [1] + >>> svm.clearThreshold() + >>> svm.predict(array([1.0])) + 1.25... + >>> sparse_data = [ ... LabeledPoint(0.0, SparseVector(2, {0: -1.0})), ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})), @@ -129,52 +239,68 @@ class SVMModel(LinearModel): ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) ... ] >>> svm = SVMWithSGD.train(sc.parallelize(sparse_data)) - >>> svm.predict(SparseVector(2, {1: 1.0})) > 0 - True - >>> svm.predict(SparseVector(2, {0: -1.0})) <= 0 - True + >>> svm.predict(SparseVector(2, {1: 1.0})) + 1 + >>> svm.predict(SparseVector(2, {0: -1.0})) + 0 """ + def __init__(self, weights, intercept): + super(SVMModel, self).__init__(weights, intercept) + self._threshold = 0.0 def predict(self, x): + """ + Predict values for a single data point or an RDD of points using + the model trained. 
+ """ + if isinstance(x, RDD): + return x.map(lambda v: self.predict(v)) + + x = _convert_to_vector(x) margin = self.weights.dot(x) + self.intercept - return 1 if margin >= 0 else 0 + if self._threshold is None: + return margin + else: + return 1 if margin > self._threshold else 0 class SVMWithSGD(object): @classmethod - def train(cls, data, iterations=100, step=1.0, regParam=1.0, - miniBatchFraction=1.0, initialWeights=None, regType="none", intercept=False): + def train(cls, data, iterations=100, step=1.0, regParam=0.01, + miniBatchFraction=1.0, initialWeights=None, regType="l2", intercept=False): """ Train a support vector machine on the given data. - @param data: The training data. - @param iterations: The number of iterations (default: 100). - @param step: The step parameter used in SGD + :param data: The training data, an RDD of LabeledPoint. + :param iterations: The number of iterations (default: 100). + :param step: The step parameter used in SGD (default: 1.0). - @param regParam: The regularizer parameter (default: 1.0). - @param miniBatchFraction: Fraction of data to be used for each SGD + :param regParam: The regularizer parameter (default: 0.01). + :param miniBatchFraction: Fraction of data to be used for each SGD iteration. - @param initialWeights: The initial weights (default: None). - @param regType: The type of regularizer used for training + :param initialWeights: The initial weights (default: None). + :param regType: The type of regularizer used for training our model. - Allowed values: "l1" for using L1Updater, - "l2" for using - SquaredL2Updater, - "none" for no regularizer. - (default: "none") - @param intercept: Boolean parameter which indicates the use + + :Allowed values: + - "l1" for using L1 regularization + - "l2" for using L2 regularization + - None for no regularization + + (default: "l2") + + :param intercept: Boolean parameter which indicates the use or not of the augmented representation for training data (i.e. whether bias features are activated or not). """ - sc = data.context - - def train(jrdd, i): - return sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD( - jrdd, iterations, step, regParam, miniBatchFraction, i, regType, intercept) + def train(rdd, i): + return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step), + float(regParam), float(miniBatchFraction), i, regType, + bool(intercept)) - return _regression_train_wrapper(sc, train, SVMModel, data, initialWeights) + return _regression_train_wrapper(train, SVMModel, data, initialWeights) class NaiveBayesModel(object): @@ -196,6 +322,8 @@ class NaiveBayesModel(object): 0.0 >>> model.predict(array([1.0, 0.0])) 1.0 + >>> model.predict(sc.parallelize([[1.0, 0.0]])).collect() + [1.0] >>> sparse_data = [ ... LabeledPoint(0.0, SparseVector(2, {1: 0.0})), ... LabeledPoint(0.0, SparseVector(2, {1: 1.0})), @@ -214,7 +342,9 @@ def __init__(self, labels, pi, theta): self.theta = theta def predict(self, x): - """Return the most likely class for a data vector x""" + """Return the most likely class for a data vector or an RDD of vectors""" + if isinstance(x, RDD): + return x.map(lambda v: self.predict(v)) x = _convert_to_vector(x) return self.labels[numpy.argmax(self.pi + x.dot(self.theta.transpose()))] @@ -232,20 +362,21 @@ def train(cls, data, lambda_=1.0): classification. By making every vector a 0-1 vector, it can also be used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}). 
- @param data: RDD of NumPy vectors, one per element, where the first - coordinate is the label and the rest is the feature vector - (e.g. a count vector). - @param lambda_: The smoothing parameter + :param data: RDD of LabeledPoint. + :param lambda_: The smoothing parameter """ - sc = data.context - jlist = sc._jvm.PythonMLLibAPI().trainNaiveBayes(data._to_java_object_rdd(), lambda_) - labels, pi, theta = PickleSerializer().loads(str(sc._jvm.SerDe.dumps(jlist))) + first = data.first() + if not isinstance(first, LabeledPoint): + raise ValueError("`data` should be an RDD of LabeledPoint") + labels, pi, theta = callMLlibFunc("trainNaiveBayes", data, lambda_) return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta)) def _test(): import doctest - globs = globals().copy() + from pyspark import SparkContext + import pyspark.mllib.classification + globs = pyspark.mllib.classification.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 12c56022717a5..e2492eef5bd6a 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -16,7 +16,7 @@ # from pyspark import SparkContext -from pyspark.serializers import PickleSerializer, AutoBatchedSerializer +from pyspark.mllib.common import callMLlibFunc, callJavaFunc from pyspark.mllib.linalg import SparseVector, _convert_to_vector __all__ = ['KMeansModel', 'KMeans'] @@ -80,14 +80,9 @@ class KMeans(object): @classmethod def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"): """Train a k-means clustering model.""" - sc = rdd.context - ser = PickleSerializer() - # cache serialized data to avoid objects over head in JVM - cached = rdd.map(_convert_to_vector)._reserialize(AutoBatchedSerializer(ser)).cache() - model = sc._jvm.PythonMLLibAPI().trainKMeansModel( - cached._to_java_object_rdd(), k, maxIterations, runs, initializationMode) - bytes = sc._jvm.SerDe.dumps(model.clusterCenters()) - centers = ser.loads(str(bytes)) + model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations, + runs, initializationMode) + centers = callJavaFunc(rdd.context, model.clusterCenters) return KMeansModel([c.toArray() for c in centers]) diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py new file mode 100644 index 0000000000000..33c49e2399908 --- /dev/null +++ b/python/pyspark/mllib/common.py @@ -0,0 +1,138 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
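The clustering hunk above now routes training through callMLlibFunc as well; a brief sketch under the same assumptions (an active `sc`, illustrative data):

    from pyspark.mllib.clustering import KMeans

    points = sc.parallelize([[1.0], [1.1], [9.0], [9.1]])
    model = KMeans.train(points, 2, maxIterations=10, initializationMode="k-means||")
    # per the rewritten train() above, the k fitted centers are shipped back to
    # Python and exposed as NumPy arrays on the returned KMeansModel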
+# + +import py4j.protocol +from py4j.protocol import Py4JJavaError +from py4j.java_gateway import JavaObject +from py4j.java_collections import MapConverter, ListConverter, JavaArray, JavaList + +from pyspark import RDD, SparkContext +from pyspark.serializers import PickleSerializer, AutoBatchedSerializer + + +# Hack for support float('inf') in Py4j +_old_smart_decode = py4j.protocol.smart_decode + +_float_str_mapping = { + 'nan': 'NaN', + 'inf': 'Infinity', + '-inf': '-Infinity', +} + + +def _new_smart_decode(obj): + if isinstance(obj, float): + s = unicode(obj) + return _float_str_mapping.get(s, s) + return _old_smart_decode(obj) + +py4j.protocol.smart_decode = _new_smart_decode + + +_picklable_classes = [ + 'LinkedList', + 'SparseVector', + 'DenseVector', + 'DenseMatrix', + 'Rating', + 'LabeledPoint', +] + + +# this will call the MLlib version of pythonToJava() +def _to_java_object_rdd(rdd): + """ Return an JavaRDD of Object by unpickling + + It will convert each Python object into Java object by Pyrolite, whenever the + RDD is serialized in batch or not. + """ + rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer())) + return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True) + + +def _py2java(sc, obj): + """ Convert Python object into Java """ + if isinstance(obj, RDD): + obj = _to_java_object_rdd(obj) + elif isinstance(obj, SparkContext): + obj = obj._jsc + elif isinstance(obj, dict): + obj = MapConverter().convert(obj, sc._gateway._gateway_client) + elif isinstance(obj, (list, tuple)): + obj = ListConverter().convert(obj, sc._gateway._gateway_client) + elif isinstance(obj, JavaObject): + pass + elif isinstance(obj, (int, long, float, bool, basestring)): + pass + else: + bytes = bytearray(PickleSerializer().dumps(obj)) + obj = sc._jvm.SerDe.loads(bytes) + return obj + + +def _java2py(sc, r): + if isinstance(r, JavaObject): + clsName = r.getClass().getSimpleName() + # convert RDD into JavaRDD + if clsName != 'JavaRDD' and clsName.endswith("RDD"): + r = r.toJavaRDD() + clsName = 'JavaRDD' + + if clsName == 'JavaRDD': + jrdd = sc._jvm.SerDe.javaToPython(r) + return RDD(jrdd, sc) + + if clsName in _picklable_classes: + r = sc._jvm.SerDe.dumps(r) + elif isinstance(r, (JavaArray, JavaList)): + try: + r = sc._jvm.SerDe.dumps(r) + except Py4JJavaError: + pass # not pickable + + if isinstance(r, bytearray): + r = PickleSerializer().loads(str(r)) + return r + + +def callJavaFunc(sc, func, *args): + """ Call Java Function """ + args = [_py2java(sc, a) for a in args] + return _java2py(sc, func(*args)) + + +def callMLlibFunc(name, *args): + """ Call API in PythonMLLibAPI """ + sc = SparkContext._active_spark_context + api = getattr(sc._jvm.PythonMLLibAPI(), name) + return callJavaFunc(sc, api, *args) + + +class JavaModelWrapper(object): + """ + Wrapper for the model in JVM + """ + def __init__(self, java_model): + self._sc = SparkContext._active_spark_context + self._java_model = java_model + + def __del__(self): + self._sc._gateway.detach(self._java_model) + + def call(self, name, *a): + """Call method of java_model""" + return callJavaFunc(self._sc, getattr(self._java_model, name), *a) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py new file mode 100644 index 0000000000000..8cb992df2d9c7 --- /dev/null +++ b/python/pyspark/mllib/feature.py @@ -0,0 +1,416 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
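The new common.py above replaces per-module Py4J plumbing. A hedged sketch of how a model class might sit on top of it; `ExampleModel` is hypothetical and only illustrates the pattern the patch uses elsewhere:

    from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc

    class ExampleModel(JavaModelWrapper):            # hypothetical wrapper, for illustration
        def predict(self, x):
            # call() pushes arguments through _py2java and results through _java2py
            return self.call("predict", x)

    # callMLlibFunc("trainKMeansModel", rdd, 2, 10, 1, "k-means||") is how the
    # clustering hunk reaches PythonMLLibAPI on the active SparkContext.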
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Python package for feature in MLlib. +""" +from __future__ import absolute_import + +import sys +import warnings +import random + +from py4j.protocol import Py4JJavaError + +from pyspark import RDD, SparkContext +from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper +from pyspark.mllib.linalg import Vectors, _convert_to_vector + +__all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler', + 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel'] + + +class VectorTransformer(object): + """ + :: DeveloperApi :: + + Base class for transformation of a vector or RDD of vector + """ + def transform(self, vector): + """ + Applies transformation on a vector. + + :param vector: vector to be transformed. + """ + raise NotImplementedError + + +class Normalizer(VectorTransformer): + """ + :: Experimental :: + + Normalizes samples individually to unit L\ :sup:`p`\ norm + + For any 1 <= `p` <= float('inf'), normalizes samples using + sum(abs(vector). :sup:`p`) :sup:`(1/p)` as norm. + + For `p` = float('inf'), max(abs(vector)) will be used as norm for normalization. + + >>> v = Vectors.dense(range(3)) + >>> nor = Normalizer(1) + >>> nor.transform(v) + DenseVector([0.0, 0.3333, 0.6667]) + + >>> rdd = sc.parallelize([v]) + >>> nor.transform(rdd).collect() + [DenseVector([0.0, 0.3333, 0.6667])] + + >>> nor2 = Normalizer(float("inf")) + >>> nor2.transform(v) + DenseVector([0.0, 0.5, 1.0]) + """ + def __init__(self, p=2.0): + """ + :param p: Normalization in L^p^ space, p = 2 by default. + """ + assert p >= 1.0, "p should be greater than 1.0" + self.p = float(p) + + def transform(self, vector): + """ + Applies unit length normalization on a vector. + + :param vector: vector or RDD of vector to be normalized. + :return: normalized vector. If the norm of the input is zero, it + will return the input vector. + """ + sc = SparkContext._active_spark_context + assert sc is not None, "SparkContext should be initialized first" + if isinstance(vector, RDD): + vector = vector.map(_convert_to_vector) + else: + vector = _convert_to_vector(vector) + return callMLlibFunc("normalizeVector", self.p, vector) + + +class JavaVectorTransformer(JavaModelWrapper, VectorTransformer): + """ + Wrapper for the model in JVM + """ + + def transform(self, vector): + if isinstance(vector, RDD): + vector = vector.map(_convert_to_vector) + else: + vector = _convert_to_vector(vector) + return self.call("transform", vector) + + +class StandardScalerModel(JavaVectorTransformer): + """ + :: Experimental :: + + Represents a StandardScaler model that can transform vectors. + """ + def transform(self, vector): + """ + Applies standardization transformation on a vector. + + :param vector: Vector or RDD of Vector to be standardized. + :return: Standardized vector. 
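Backing up to the Normalizer transformer above, a quick sketch mirroring its doctests (the scaler described next follows the same single-vector/RDD transform protocol):

    from pyspark.mllib.feature import Normalizer
    from pyspark.mllib.linalg import Vectors

    v = Vectors.dense([0.0, 1.0, 2.0])
    Normalizer(p=1).transform(v)                   # divide by the L1 norm -> [0.0, 0.333, 0.667]
    Normalizer(float("inf")).transform(v)          # divide by max(abs(v)) -> [0.0, 0.5, 1.0]
    Normalizer().transform(sc.parallelize([v]))    # RDDs are normalized element-wise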
If the variance of a column is zero, + it will return default `0.0` for the column with zero variance. + """ + return JavaVectorTransformer.transform(self, vector) + + +class StandardScaler(object): + """ + :: Experimental :: + + Standardizes features by removing the mean and scaling to unit + variance using column summary statistics on the samples in the + training set. + + >>> vs = [Vectors.dense([-2.0, 2.3, 0]), Vectors.dense([3.8, 0.0, 1.9])] + >>> dataset = sc.parallelize(vs) + >>> standardizer = StandardScaler(True, True) + >>> model = standardizer.fit(dataset) + >>> result = model.transform(dataset) + >>> for r in result.collect(): r + DenseVector([-0.7071, 0.7071, -0.7071]) + DenseVector([0.7071, -0.7071, 0.7071]) + """ + def __init__(self, withMean=False, withStd=True): + """ + :param withMean: False by default. Centers the data with mean + before scaling. It will build a dense output, so this + does not work on sparse input and will raise an exception. + :param withStd: True by default. Scales the data to unit standard + deviation. + """ + if not (withMean or withStd): + warnings.warn("Both withMean and withStd are false. The model does nothing.") + self.withMean = withMean + self.withStd = withStd + + def fit(self, dataset): + """ + Computes the mean and variance and stores as a model to be used for later scaling. + + :param data: The data used to compute the mean and variance to build + the transformation model. + :return: a StandardScalarModel + """ + dataset = dataset.map(_convert_to_vector) + jmodel = callMLlibFunc("fitStandardScaler", self.withMean, self.withStd, dataset) + return StandardScalerModel(jmodel) + + +class HashingTF(object): + """ + :: Experimental :: + + Maps a sequence of terms to their term frequencies using the hashing trick. + + Note: the terms must be hashable (can not be dict/set/list...). + + >>> htf = HashingTF(100) + >>> doc = "a a b b c d".split(" ") + >>> htf.transform(doc) + SparseVector(100, {1: 1.0, 14: 1.0, 31: 2.0, 44: 2.0}) + """ + def __init__(self, numFeatures=1 << 20): + """ + :param numFeatures: number of features (default: 2^20) + """ + self.numFeatures = numFeatures + + def indexOf(self, term): + """ Returns the index of the input term. """ + return hash(term) % self.numFeatures + + def transform(self, document): + """ + Transforms the input document (list of terms) to term frequency vectors, + or transform the RDD of document to RDD of term frequency vectors. + """ + if isinstance(document, RDD): + return document.map(self.transform) + + freq = {} + for term in document: + i = self.indexOf(term) + freq[i] = freq.get(i, 0) + 1.0 + return Vectors.sparse(self.numFeatures, freq.items()) + + +class IDFModel(JavaVectorTransformer): + """ + Represents an IDF model that can transform term frequency vectors. + """ + def transform(self, dataset): + """ + Transforms term frequency (TF) vectors to TF-IDF vectors. + + If `minDocFreq` was set for the IDF calculation, + the terms which occur in fewer than `minDocFreq` + documents will have an entry of 0. + + :param dataset: an RDD of term frequency vectors + :return: an RDD of TF-IDF vectors + """ + if not isinstance(dataset, RDD): + raise TypeError("dataset should be an RDD of term frequency vectors") + return JavaVectorTransformer.transform(self, dataset) + + +class IDF(object): + """ + :: Experimental :: + + Inverse document frequency (IDF). 
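Before the IDF details below, a hedged fit/transform sketch for the StandardScaler estimator added above (`dataset` is an assumed RDD of Vectors):

    from pyspark.mllib.feature import StandardScaler

    scaler = StandardScaler(withMean=True, withStd=True)   # withMean densifies sparse input
    model = scaler.fit(dataset)                             # one pass to collect column statistics
    scaled = model.transform(dataset)                       # also accepts a single Vector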
+ + The standard formulation is used: `idf = log((m + 1) / (d(t) + 1))`, + where `m` is the total number of documents and `d(t)` is the number + of documents that contain term `t`. + + This implementation supports filtering out terms which do not appear + in a minimum number of documents (controlled by the variable `minDocFreq`). + For terms that are not in at least `minDocFreq` documents, the IDF is + found as 0, resulting in TF-IDFs of 0. + + >>> n = 4 + >>> freqs = [Vectors.sparse(n, (1, 3), (1.0, 2.0)), + ... Vectors.dense([0.0, 1.0, 2.0, 3.0]), + ... Vectors.sparse(n, [1], [1.0])] + >>> data = sc.parallelize(freqs) + >>> idf = IDF() + >>> model = idf.fit(data) + >>> tfidf = model.transform(data) + >>> for r in tfidf.collect(): r + SparseVector(4, {1: 0.0, 3: 0.5754}) + DenseVector([0.0, 0.0, 1.3863, 0.863]) + SparseVector(4, {1: 0.0}) + """ + def __init__(self, minDocFreq=0): + """ + :param minDocFreq: minimum of documents in which a term + should appear for filtering + """ + self.minDocFreq = minDocFreq + + def fit(self, dataset): + """ + Computes the inverse document frequency. + + :param dataset: an RDD of term frequency vectors + """ + if not isinstance(dataset, RDD): + raise TypeError("dataset should be an RDD of term frequency vectors") + jmodel = callMLlibFunc("fitIDF", self.minDocFreq, dataset.map(_convert_to_vector)) + return IDFModel(jmodel) + + +class Word2VecModel(JavaVectorTransformer): + """ + class for Word2Vec model + """ + def transform(self, word): + """ + Transforms a word to its vector representation + + Note: local use only + + :param word: a word + :return: vector representation of word(s) + """ + try: + return self.call("transform", word) + except Py4JJavaError: + raise ValueError("%s not found" % word) + + def findSynonyms(self, word, num): + """ + Find synonyms of a word + + :param word: a word or a vector representation of word + :param num: number of synonyms to find + :return: array of (word, cosineSimilarity) + + Note: local use only + """ + if not isinstance(word, basestring): + word = _convert_to_vector(word) + words, similarity = self.call("findSynonyms", word, num) + return zip(words, similarity) + + +class Word2Vec(object): + """ + Word2Vec creates vector representation of words in a text corpus. + The algorithm first constructs a vocabulary from the corpus + and then learns vector representation of words in the vocabulary. + The vector representation can be used as features in + natural language processing and machine learning algorithms. + + We used skip-gram model in our implementation and hierarchical softmax + method to train the model. The variable names in the implementation + matches the original C implementation. + + For original C implementation, see https://code.google.com/p/word2vec/ + For research papers, see + Efficient Estimation of Word Representations in Vector Space + and + Distributed Representations of Words and Phrases and their Compositionality. 
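Returning to the TF-IDF pieces above, a short sketch chaining HashingTF into IDF, with an arithmetic spot-check of the stated formula against the doctest values (toy corpus assumed):

    from math import log
    from pyspark.mllib.feature import HashingTF, IDF

    docs = sc.parallelize([["a", "a", "b"], ["b", "c"]])
    tf = HashingTF(numFeatures=1 << 10).transform(docs)     # RDD of SparseVector term counts
    tfidf = IDF(minDocFreq=0).fit(tf).transform(tf)

    # Spot-check of idf = log((m + 1) / (d(t) + 1)) for the doctest above:
    # term 3 appears in d(t) = 2 of the m = 3 documents, so idf ~= 0.2877,
    # and a term frequency of 2.0 gives 2.0 * idf ~= 0.5754, as shown.
    log((3 + 1.0) / (2 + 1.0)) * 2.0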
+ + >>> sentence = "a b " * 100 + "a c " * 10 + >>> localDoc = [sentence, sentence] + >>> doc = sc.parallelize(localDoc).map(lambda line: line.split(" ")) + >>> model = Word2Vec().setVectorSize(10).setSeed(42L).fit(doc) + + >>> syms = model.findSynonyms("a", 2) + >>> [s[0] for s in syms] + [u'b', u'c'] + >>> vec = model.transform("a") + >>> syms = model.findSynonyms(vec, 2) + >>> [s[0] for s in syms] + [u'b', u'c'] + """ + def __init__(self): + """ + Construct Word2Vec instance + """ + self.vectorSize = 100 + self.learningRate = 0.025 + self.numPartitions = 1 + self.numIterations = 1 + self.seed = random.randint(0, sys.maxint) + + def setVectorSize(self, vectorSize): + """ + Sets vector size (default: 100). + """ + self.vectorSize = vectorSize + return self + + def setLearningRate(self, learningRate): + """ + Sets initial learning rate (default: 0.025). + """ + self.learningRate = learningRate + return self + + def setNumPartitions(self, numPartitions): + """ + Sets number of partitions (default: 1). Use a small number for accuracy. + """ + self.numPartitions = numPartitions + return self + + def setNumIterations(self, numIterations): + """ + Sets number of iterations (default: 1), which should be smaller than or equal to number of + partitions. + """ + self.numIterations = numIterations + return self + + def setSeed(self, seed): + """ + Sets random seed. + """ + self.seed = seed + return self + + def fit(self, data): + """ + Computes the vector representation of each word in vocabulary. + + :param data: training data. RDD of list of string + :return: Word2VecModel instance + """ + if not isinstance(data, RDD): + raise TypeError("data should be an RDD of list of string") + jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize), + float(self.learningRate), int(self.numPartitions), + int(self.numIterations), long(self.seed)) + return Word2VecModel(jmodel) + + +def _test(): + import doctest + from pyspark import SparkContext + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + +if __name__ == "__main__": + sys.path.pop(0) + _test() diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 51014a8ceb785..f7aa2b0cb04b3 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -29,7 +29,11 @@ import numpy as np -__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors'] +from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ + IntegerType, ByteType + + +__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', 'DenseMatrix', 'Matrices'] if sys.version_info[:2] == (2, 7): @@ -98,7 +102,61 @@ def _vector_size(v): raise TypeError("Cannot treat type %s as a vector" % type(v)) +def _format_float(f, digits=4): + s = str(round(f, digits)) + if '.' in s: + s = s[:s.index('.') + 1 + digits] + return s + + +class VectorUDT(UserDefinedType): + """ + SQL user-defined type (UDT) for Vector. 
+ """ + + @classmethod + def sqlType(cls): + return StructType([ + StructField("type", ByteType(), False), + StructField("size", IntegerType(), True), + StructField("indices", ArrayType(IntegerType(), False), True), + StructField("values", ArrayType(DoubleType(), False), True)]) + + @classmethod + def module(cls): + return "pyspark.mllib.linalg" + + @classmethod + def scalaUDT(cls): + return "org.apache.spark.mllib.linalg.VectorUDT" + + def serialize(self, obj): + if isinstance(obj, SparseVector): + indices = [int(i) for i in obj.indices] + values = [float(v) for v in obj.values] + return (0, obj.size, indices, values) + elif isinstance(obj, DenseVector): + values = [float(v) for v in obj] + return (1, None, None, values) + else: + raise ValueError("cannot serialize %r of type %r" % (obj, type(obj))) + + def deserialize(self, datum): + assert len(datum) == 4, \ + "VectorUDT.deserialize given row with length %d but requires 4" % len(datum) + tpe = datum[0] + if tpe == 0: + return SparseVector(datum[1], datum[2], datum[3]) + elif tpe == 1: + return DenseVector(datum[3]) + else: + raise ValueError("do not recognize type %r" % tpe) + + class Vector(object): + + __UDT__ = VectorUDT() + """ Abstract class for DenseVector and SparseVector """ @@ -115,12 +173,16 @@ class DenseVector(Vector): A dense vector represented by a value array. """ def __init__(self, ar): - if not isinstance(ar, array.array): - ar = array.array('d', ar) + if isinstance(ar, basestring): + ar = np.frombuffer(ar, dtype=np.float64) + elif not isinstance(ar, np.ndarray): + ar = np.array(ar, dtype=np.float64) + if ar.dtype != np.float64: + ar.astype(np.float64) self.array = ar def __reduce__(self): - return DenseVector, (self.array,) + return DenseVector, (self.array.tostring(),) def dot(self, other): """ @@ -149,9 +211,10 @@ def dot(self, other): ... AssertionError: dimension mismatch """ - if type(other) == np.ndarray and other.ndim > 1: - assert len(self) == other.shape[0], "dimension mismatch" - return np.dot(self.toArray(), other) + if type(other) == np.ndarray: + if other.ndim > 1: + assert len(self) == other.shape[0], "dimension mismatch" + return np.dot(self.array, other) elif _have_scipy and scipy.sparse.issparse(other): assert len(self) == other.shape[0], "dimension mismatch" return other.transpose().dot(self.toArray()) @@ -203,7 +266,7 @@ def squared_distance(self, other): return np.dot(diff, diff) def toArray(self): - return np.array(self.array) + return self.array def __getitem__(self, item): return self.array[item] @@ -215,10 +278,10 @@ def __str__(self): return "[" + ",".join([str(v) for v in self.array]) + "]" def __repr__(self): - return "DenseVector(%r)" % self.array + return "DenseVector([%s])" % (', '.join(_format_float(i) for i in self.array)) def __eq__(self, other): - return isinstance(other, DenseVector) and self.array == other.array + return isinstance(other, DenseVector) and np.array_equal(self.array, other.array) def __ne__(self, other): return not self == other @@ -238,8 +301,8 @@ def __init__(self, size, *args): (index, value) pairs, or two separate arrays of indices and values (sorted by index). - @param size: Size of the vector. - @param args: Non-zero entries, as a dictionary, list of tupes, + :param size: Size of the vector. + :param args: Non-zero entries, as a dictionary, list of tupes, or two sorted lists containing indices and values. 
>>> print SparseVector(4, {1: 1.0, 3: 5.5}) @@ -256,18 +319,28 @@ def __init__(self, size, *args): if type(pairs) == dict: pairs = pairs.items() pairs = sorted(pairs) - self.indices = array.array('i', [p[0] for p in pairs]) - self.values = array.array('d', [p[1] for p in pairs]) + self.indices = np.array([p[0] for p in pairs], dtype=np.int32) + self.values = np.array([p[1] for p in pairs], dtype=np.float64) else: - assert len(args[0]) == len(args[1]), "index and value arrays not same length" - self.indices = array.array('i', args[0]) - self.values = array.array('d', args[1]) + if isinstance(args[0], basestring): + assert isinstance(args[1], str), "values should be string too" + if args[0]: + self.indices = np.frombuffer(args[0], np.int32) + self.values = np.frombuffer(args[1], np.float64) + else: + # np.frombuffer() doesn't work well with empty string in older version + self.indices = np.array([], dtype=np.int32) + self.values = np.array([], dtype=np.float64) + else: + self.indices = np.array(args[0], dtype=np.int32) + self.values = np.array(args[1], dtype=np.float64) + assert len(self.indices) == len(self.values), "index and value arrays not same length" for i in xrange(len(self.indices) - 1): if self.indices[i] >= self.indices[i + 1]: raise TypeError("indices array must be sorted") def __reduce__(self): - return (SparseVector, (self.size, self.indices, self.values)) + return (SparseVector, (self.size, self.indices.tostring(), self.values.tostring())) def dot(self, other): """ @@ -403,8 +476,7 @@ def toArray(self): Returns a copy of this SparseVector as a 1-dimensional NumPy array. """ arr = np.zeros((self.size,), dtype=np.float64) - for i in xrange(self.indices.size): - arr[self.indices[i]] = self.values[i] + arr[self.indices] = self.values return arr def __len__(self): @@ -418,7 +490,8 @@ def __str__(self): def __repr__(self): inds = self.indices vals = self.values - entries = ", ".join(["{0}: {1}".format(inds[i], vals[i]) for i in xrange(len(inds))]) + entries = ", ".join(["{0}: {1}".format(inds[i], _format_float(vals[i])) + for i in xrange(len(inds))]) return "SparseVector({0}, {{{1}}})".format(self.size, entries) def __eq__(self, other): @@ -434,8 +507,8 @@ def __eq__(self, other): """ return (isinstance(other, self.__class__) and other.size == self.size - and other.indices == self.indices - and other.values == self.values) + and np.array_equal(other.indices, self.indices) + and np.array_equal(other.values, self.values)) def __ne__(self, other): return not self.__eq__(other) @@ -458,8 +531,8 @@ def sparse(size, *args): (index, value) pairs, or two separate arrays of indices and values (sorted by index). - @param size: Size of the vector. - @param args: Non-zero entries, as a dictionary, list of tupes, + :param size: Size of the vector. + :param args: Non-zero entries, as a dictionary, list of tupes, or two sorted lists containing indices and values. >>> print Vectors.sparse(4, {1: 1.0, 3: 5.5}) @@ -478,7 +551,7 @@ def dense(elements): returns a NumPy array. 
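A brief sketch of the NumPy-backed SparseVector introduced above; the pickle round trip relies on the new __reduce__ that ships raw buffers via tostring():

    import cPickle
    from pyspark.mllib.linalg import SparseVector

    sv = SparseVector(4, {1: 1.0, 3: 5.5})
    sv.indices, sv.values                     # numpy int32/float64 arrays, not array.array
    cPickle.loads(cPickle.dumps(sv)) == sv    # equality now uses np.array_equal
    sv.toArray()                              # vectorized scatter: arr[indices] = values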
>>> Vectors.dense([1, 2, 3]) - DenseVector(array('d', [1.0, 2.0, 3.0])) + DenseVector([1.0, 2.0, 3.0]) """ return DenseVector(elements) @@ -518,23 +591,43 @@ class DenseMatrix(Matrix): """ def __init__(self, numRows, numCols, values): Matrix.__init__(self, numRows, numCols) + if isinstance(values, basestring): + values = np.frombuffer(values, dtype=np.float64) + elif not isinstance(values, np.ndarray): + values = np.array(values, dtype=np.float64) assert len(values) == numRows * numCols + if values.dtype != np.float64: + values.astype(np.float64) self.values = values def __reduce__(self): - return DenseMatrix, (self.numRows, self.numCols, self.values) + return DenseMatrix, (self.numRows, self.numCols, self.values.tostring()) def toArray(self): """ Return an numpy.ndarray - >>> arr = array.array('d', [float(i) for i in range(4)]) - >>> m = DenseMatrix(2, 2, arr) + >>> m = DenseMatrix(2, 2, range(4)) >>> m.toArray() array([[ 0., 2.], [ 1., 3.]]) """ - return np.reshape(self.values, (self.numRows, self.numCols), order='F') + return self.values.reshape((self.numRows, self.numCols), order='F') + + def __eq__(self, other): + return (isinstance(other, DenseMatrix) and + self.numRows == other.numRows and + self.numCols == other.numCols and + all(self.values == other.values)) + + +class Matrices(object): + @staticmethod + def dense(numRows, numCols, values): + """ + Create a DenseMatrix + """ + return DenseMatrix(numRows, numCols, values) def _test(): @@ -544,8 +637,4 @@ def _test(): exit(-1) if __name__ == "__main__": - # remove current path from list of search paths to avoid importing mllib.random - # for C{import random}, which is done in an external dependency of pyspark during doctests. - import sys - sys.path.pop(0) _test() diff --git a/python/pyspark/mllib/rand.py b/python/pyspark/mllib/rand.py new file mode 100644 index 0000000000000..cb4304f92152b --- /dev/null +++ b/python/pyspark/mllib/rand.py @@ -0,0 +1,223 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Python package for random data generation. +""" + +from functools import wraps + +from pyspark.mllib.common import callMLlibFunc + + +__all__ = ['RandomRDDs', ] + + +def toArray(f): + @wraps(f) + def func(sc, *a, **kw): + rdd = f(sc, *a, **kw) + return rdd.map(lambda vec: vec.toArray()) + return func + + +class RandomRDDs(object): + """ + Generator methods for creating RDDs comprised of i.i.d samples from + some distribution. + """ + + @staticmethod + def uniformRDD(sc, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d. samples from the + uniform distribution U(0.0, 1.0). 
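Backing up to the DenseMatrix/Matrices hunk above, a sketch of the column-major layout its doctest relies on:

    from pyspark.mllib.linalg import Matrices, DenseMatrix

    m = Matrices.dense(2, 2, range(4))      # values are read in column-major (Fortran) order
    m.toArray()                             # [[0., 2.], [1., 3.]]
    m == DenseMatrix(2, 2, [0, 1, 2, 3])    # the new __eq__ compares shape and values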
+ + To transform the distribution in the generated RDD from U(0.0, 1.0) + to U(a, b), use + C{RandomRDDs.uniformRDD(sc, n, p, seed)\ + .map(lambda v: a + (b - a) * v)} + + :param sc: SparkContext used to create the RDD. + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ `U(0.0, 1.0)`. + + >>> x = RandomRDDs.uniformRDD(sc, 100).collect() + >>> len(x) + 100 + >>> max(x) <= 1.0 and min(x) >= 0.0 + True + >>> RandomRDDs.uniformRDD(sc, 100, 4).getNumPartitions() + 4 + >>> parts = RandomRDDs.uniformRDD(sc, 100, seed=4).getNumPartitions() + >>> parts == sc.defaultParallelism + True + """ + return callMLlibFunc("uniformRDD", sc._jsc, size, numPartitions, seed) + + @staticmethod + def normalRDD(sc, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d. samples from the standard normal + distribution. + + To transform the distribution in the generated RDD from standard normal + to some other normal N(mean, sigma^2), use + C{RandomRDDs.normal(sc, n, p, seed)\ + .map(lambda v: mean + sigma * v)} + + :param sc: SparkContext used to create the RDD. + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ N(0.0, 1.0). + + >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1L) + >>> stats = x.stats() + >>> stats.count() + 1000L + >>> abs(stats.mean() - 0.0) < 0.1 + True + >>> abs(stats.stdev() - 1.0) < 0.1 + True + """ + return callMLlibFunc("normalRDD", sc._jsc, size, numPartitions, seed) + + @staticmethod + def poissonRDD(sc, mean, size, numPartitions=None, seed=None): + """ + Generates an RDD comprised of i.i.d. samples from the Poisson + distribution with the input mean. + + :param sc: SparkContext used to create the RDD. + :param mean: Mean, or lambda, for the Poisson distribution. + :param size: Size of the RDD. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of float comprised of i.i.d. samples ~ Pois(mean). + + >>> mean = 100.0 + >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2L) + >>> stats = x.stats() + >>> stats.count() + 1000L + >>> abs(stats.mean() - mean) < 0.5 + True + >>> from math import sqrt + >>> abs(stats.stdev() - sqrt(mean)) < 0.5 + True + """ + return callMLlibFunc("poissonRDD", sc._jsc, float(mean), size, numPartitions, seed) + + @staticmethod + @toArray + def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d. samples drawn + from the uniform distribution U(0.0, 1.0). + + :param sc: SparkContext used to create the RDD. + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. + :param numPartitions: Number of partitions in the RDD. + :param seed: Seed for the RNG that generates the seed for the generator in each partition. + :return: RDD of Vector with vectors containing i.i.d samples ~ `U(0.0, 1.0)`. 
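A hedged sketch of the scalar generators above, using the shift/scale trick their docstrings suggest (assuming an active `sc`):

    from pyspark.mllib.random import RandomRDDs

    u = RandomRDDs.uniformRDD(sc, 1000, seed=1L).map(lambda v: 5.0 + 10.0 * v)   # ~ U(5, 15)
    n = RandomRDDs.normalRDD(sc, 1000, seed=1L).map(lambda v: 2.0 + 0.5 * v)     # ~ N(2, 0.25)
    p = RandomRDDs.poissonRDD(sc, 30.0, 1000, seed=2L)
    p.stats()                                 # mean ~ 30.0, stdev ~ sqrt(30.0)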
+ + >>> import numpy as np + >>> mat = np.matrix(RandomRDDs.uniformVectorRDD(sc, 10, 10).collect()) + >>> mat.shape + (10, 10) + >>> mat.max() <= 1.0 and mat.min() >= 0.0 + True + >>> RandomRDDs.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions() + 4 + """ + return callMLlibFunc("uniformVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed) + + @staticmethod + @toArray + def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d. samples drawn + from the standard normal distribution. + + :param sc: SparkContext used to create the RDD. + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`). + :param seed: Random seed (default: a random long integer). + :return: RDD of Vector with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`. + + >>> import numpy as np + >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1L).collect()) + >>> mat.shape + (100, 100) + >>> abs(mat.mean() - 0.0) < 0.1 + True + >>> abs(mat.std() - 1.0) < 0.1 + True + """ + return callMLlibFunc("normalVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed) + + @staticmethod + @toArray + def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): + """ + Generates an RDD comprised of vectors containing i.i.d. samples drawn + from the Poisson distribution with the input mean. + + :param sc: SparkContext used to create the RDD. + :param mean: Mean, or lambda, for the Poisson distribution. + :param numRows: Number of Vectors in the RDD. + :param numCols: Number of elements in each Vector. + :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`) + :param seed: Random seed (default: a random long integer). + :return: RDD of Vector with vectors containing i.i.d. samples ~ Pois(mean). + + >>> import numpy as np + >>> mean = 100.0 + >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1L) + >>> mat = np.mat(rdd.collect()) + >>> mat.shape + (100, 100) + >>> abs(mat.mean() - mean) < 0.5 + True + >>> from math import sqrt + >>> abs(mat.std() - sqrt(mean)) < 0.5 + True + """ + return callMLlibFunc("poissonVectorRDD", sc._jsc, float(mean), numRows, numCols, + numPartitions, seed) + + +def _test(): + import doctest + from pyspark.context import SparkContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py deleted file mode 100644 index a787e4dea2c55..0000000000000 --- a/python/pyspark/mllib/random.py +++ /dev/null @@ -1,200 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Python package for random data generation. -""" - -from functools import wraps - -from pyspark.rdd import RDD -from pyspark.serializers import BatchedSerializer, PickleSerializer - - -__all__ = ['RandomRDDs', ] - - -def serialize(f): - @wraps(f) - def func(sc, *a, **kw): - jrdd = f(sc, *a, **kw) - return RDD(sc._jvm.PythonRDD.javaToPython(jrdd), sc, - BatchedSerializer(PickleSerializer(), 1024)) - return func - - -def toArray(f): - @wraps(f) - def func(sc, *a, **kw): - rdd = f(sc, *a, **kw) - return rdd.map(lambda vec: vec.toArray()) - return func - - -class RandomRDDs(object): - """ - Generator methods for creating RDDs comprised of i.i.d samples from - some distribution. - """ - - @staticmethod - @serialize - def uniformRDD(sc, size, numPartitions=None, seed=None): - """ - Generates an RDD comprised of i.i.d. samples from the - uniform distribution U(0.0, 1.0). - - To transform the distribution in the generated RDD from U(0.0, 1.0) - to U(a, b), use - C{RandomRDDs.uniformRDD(sc, n, p, seed)\ - .map(lambda v: a + (b - a) * v)} - - >>> x = RandomRDDs.uniformRDD(sc, 100).collect() - >>> len(x) - 100 - >>> max(x) <= 1.0 and min(x) >= 0.0 - True - >>> RandomRDDs.uniformRDD(sc, 100, 4).getNumPartitions() - 4 - >>> parts = RandomRDDs.uniformRDD(sc, 100, seed=4).getNumPartitions() - >>> parts == sc.defaultParallelism - True - """ - return sc._jvm.PythonMLLibAPI().uniformRDD(sc._jsc, size, numPartitions, seed) - - @staticmethod - @serialize - def normalRDD(sc, size, numPartitions=None, seed=None): - """ - Generates an RDD comprised of i.i.d. samples from the standard normal - distribution. - - To transform the distribution in the generated RDD from standard normal - to some other normal N(mean, sigma^2), use - C{RandomRDDs.normal(sc, n, p, seed)\ - .map(lambda v: mean + sigma * v)} - - >>> x = RandomRDDs.normalRDD(sc, 1000, seed=1L) - >>> stats = x.stats() - >>> stats.count() - 1000L - >>> abs(stats.mean() - 0.0) < 0.1 - True - >>> abs(stats.stdev() - 1.0) < 0.1 - True - """ - return sc._jvm.PythonMLLibAPI().normalRDD(sc._jsc, size, numPartitions, seed) - - @staticmethod - @serialize - def poissonRDD(sc, mean, size, numPartitions=None, seed=None): - """ - Generates an RDD comprised of i.i.d. samples from the Poisson - distribution with the input mean. - - >>> mean = 100.0 - >>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=1L) - >>> stats = x.stats() - >>> stats.count() - 1000L - >>> abs(stats.mean() - mean) < 0.5 - True - >>> from math import sqrt - >>> abs(stats.stdev() - sqrt(mean)) < 0.5 - True - """ - return sc._jvm.PythonMLLibAPI().poissonRDD(sc._jsc, mean, size, numPartitions, seed) - - @staticmethod - @toArray - @serialize - def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): - """ - Generates an RDD comprised of vectors containing i.i.d. samples drawn - from the uniform distribution U(0.0, 1.0). 
- - >>> import numpy as np - >>> mat = np.matrix(RandomRDDs.uniformVectorRDD(sc, 10, 10).collect()) - >>> mat.shape - (10, 10) - >>> mat.max() <= 1.0 and mat.min() >= 0.0 - True - >>> RandomRDDs.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions() - 4 - """ - return sc._jvm.PythonMLLibAPI() \ - .uniformVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed) - - @staticmethod - @toArray - @serialize - def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None): - """ - Generates an RDD comprised of vectors containing i.i.d. samples drawn - from the standard normal distribution. - - >>> import numpy as np - >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1L).collect()) - >>> mat.shape - (100, 100) - >>> abs(mat.mean() - 0.0) < 0.1 - True - >>> abs(mat.std() - 1.0) < 0.1 - True - """ - return sc._jvm.PythonMLLibAPI() \ - .normalVectorRDD(sc._jsc, numRows, numCols, numPartitions, seed) - - @staticmethod - @toArray - @serialize - def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None): - """ - Generates an RDD comprised of vectors containing i.i.d. samples drawn - from the Poisson distribution with the input mean. - - >>> import numpy as np - >>> mean = 100.0 - >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1L) - >>> mat = np.mat(rdd.collect()) - >>> mat.shape - (100, 100) - >>> abs(mat.mean() - mean) < 0.5 - True - >>> from math import sqrt - >>> abs(mat.std() - sqrt(mean)) < 0.5 - True - """ - return sc._jvm.PythonMLLibAPI() \ - .poissonVectorRDD(sc._jsc, mean, numRows, numCols, numPartitions, seed) - - -def _test(): - import doctest - from pyspark.context import SparkContext - globs = globals().copy() - # The small batch size here ensures that we see multiple batches, - # even in these small test examples: - globs['sc'] = SparkContext('local[2]', 'PythonTest', batchSize=2) - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() - if failure_count: - exit(-1) - - -if __name__ == "__main__": - _test() diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 59c1c5ff0ced0..97ec74eda0b71 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -15,27 +15,31 @@ # limitations under the License. # +from collections import namedtuple + from pyspark import SparkContext -from pyspark.serializers import PickleSerializer, AutoBatchedSerializer from pyspark.rdd import RDD +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc -__all__ = ['MatrixFactorizationModel', 'ALS'] +__all__ = ['MatrixFactorizationModel', 'ALS', 'Rating'] -class Rating(object): - def __init__(self, user, product, rating): - self.user = int(user) - self.product = int(product) - self.rating = float(rating) +class Rating(namedtuple("Rating", ["user", "product", "rating"])): + """ + Represents a (user, product, rating) tuple. - def __reduce__(self): - return Rating, (self.user, self.product, self.rating) + >>> r = Rating(1, 2, 5.0) + >>> (r.user, r.product, r.rating) + (1, 2, 5.0) + >>> (r[0], r[1], r[2]) + (1, 2, 5.0) + """ - def __repr__(self): - return "Rating(%d, %d, %d)" % (self.user, self.product, self.rating) + def __reduce__(self): + return Rating, (int(self.user), int(self.product), float(self.rating)) -class MatrixFactorizationModel(object): +class MatrixFactorizationModel(JavaModelWrapper): """A matrix factorisation model trained by regularized alternating least-squares. 
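A usage sketch for the rewritten recommendation API, mirroring the doctests that follow (plain (user, product, rating) tuples are accepted and converted to Rating):

    from pyspark.mllib.recommendation import ALS

    ratings = sc.parallelize([(1, 1, 1.0), (1, 2, 2.0), (2, 1, 2.0)])
    model = ALS.train(ratings, 1, seed=10)
    model.predict(2, 2)                                           # single (user, product) score
    model.predictAll(sc.parallelize([(1, 2), (1, 1)])).collect()  # -> list of Rating
    model.userFeatures().collect()                                # new: latent factor vectors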
@@ -44,43 +48,55 @@ class MatrixFactorizationModel(object): >>> r2 = (1, 2, 2.0) >>> r3 = (2, 1, 2.0) >>> ratings = sc.parallelize([r1, r2, r3]) - >>> model = ALS.trainImplicit(ratings, 1) - >>> model.predict(2,2) is not None - True + >>> model = ALS.trainImplicit(ratings, 1, seed=10) + >>> model.predict(2,2) + 0.4473... >>> testset = sc.parallelize([(1, 2), (1, 1)]) - >>> model = ALS.train(ratings, 1) - >>> model.predictAll(testset).count() == 2 + >>> model = ALS.train(ratings, 1, seed=10) + >>> model.predictAll(testset).collect() + [Rating(user=1, product=1, rating=1.0471...), Rating(user=1, product=2, rating=1.9679...)] + + >>> model = ALS.train(ratings, 4, seed=10) + >>> model.userFeatures().collect() + [(2, array('d', [...])), (1, array('d', [...]))] + + >>> first_user = model.userFeatures().take(1)[0] + >>> latents = first_user[1] + >>> len(latents) == 4 True - """ - def __init__(self, sc, java_model): - self._context = sc - self._java_model = java_model + >>> model.productFeatures().collect() + [(2, array('d', [...])), (1, array('d', [...]))] + + >>> first_product = model.productFeatures().take(1)[0] + >>> latents = first_product[1] + >>> len(latents) == 4 + True - def __del__(self): - self._context._gateway.detach(self._java_model) + >>> model = ALS.train(ratings, 1, nonnegative=True, seed=10) + >>> model.predict(2,2) + 3.735... + >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10) + >>> model.predict(2,2) + 0.4473... + """ def predict(self, user, product): - return self._java_model.predict(user, product) + return self._java_model.predict(int(user), int(product)) def predictAll(self, user_product): assert isinstance(user_product, RDD), "user_product should be RDD of (user, product)" first = user_product.first() - if isinstance(first, list): - user_product = user_product.map(tuple) - first = tuple(first) - assert type(first) is tuple and len(first) == 2, \ - "user_product should be RDD of (user, product)" - if any(isinstance(x, str) for x in first): - user_product = user_product.map(lambda (u, p): (int(x), int(p))) - first = tuple(map(int, first)) - assert all(type(x) is int for x in first), "user and product in user_product shoul be int" - sc = self._context - tuplerdd = sc._jvm.SerDe.asTupleRDD(user_product._to_java_object_rdd().rdd()) - jresult = self._java_model.predict(tuplerdd).toJavaRDD() - return RDD(sc._jvm.PythonRDD.javaToPython(jresult), sc, - AutoBatchedSerializer(PickleSerializer())) + assert len(first) == 2, "user_product should be RDD of (user, product)" + user_product = user_product.map(lambda (u, p): (int(u), int(p))) + return self.call("predict", user_product) + + def userFeatures(self): + return self.call("getUserFeatures") + + def productFeatures(self): + return self.call("getProductFeatures") class ALS(object): @@ -94,32 +110,28 @@ def _prepare(cls, ratings): ratings = ratings.map(lambda x: Rating(*x)) else: raise ValueError("rating should be RDD of Rating or tuple/list") - # serialize them by AutoBatchedSerializer before cache to reduce the - # objects overhead in JVM - cached = ratings._reserialize(AutoBatchedSerializer(PickleSerializer())).cache() - return cached._to_java_object_rdd() + return ratings @classmethod - def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1): - sc = ratings.context - jrating = cls._prepare(ratings) - mod = sc._jvm.PythonMLLibAPI().trainALSModel(jrating, rank, iterations, lambda_, blocks) - return MatrixFactorizationModel(sc, mod) + def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, 
nonnegative=False, + seed=None): + model = callMLlibFunc("trainALSModel", cls._prepare(ratings), rank, iterations, + lambda_, blocks, nonnegative, seed) + return MatrixFactorizationModel(model) @classmethod - def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01): - sc = ratings.context - jrating = cls._prepare(ratings) - mod = sc._jvm.PythonMLLibAPI().trainImplicitALSModel( - jrating, rank, iterations, lambda_, blocks, alpha) - return MatrixFactorizationModel(sc, mod) + def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01, + nonnegative=False, seed=None): + model = callMLlibFunc("trainImplicitALSModel", cls._prepare(ratings), rank, + iterations, lambda_, blocks, alpha, nonnegative, seed) + return MatrixFactorizationModel(model) def _test(): import doctest import pyspark.mllib.recommendation globs = pyspark.mllib.recommendation.__dict__.copy() - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['sc'] = SparkContext('local[4]', 'PythonTest') (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index cbdbc09858013..210060140fd91 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -18,11 +18,10 @@ import numpy as np from numpy import array -from pyspark import SparkContext +from pyspark.mllib.common import callMLlibFunc from pyspark.mllib.linalg import SparseVector, _convert_to_vector -from pyspark.serializers import PickleSerializer, AutoBatchedSerializer -__all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel' +__all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel', 'LinearRegressionWithSGD', 'LassoWithSGD', 'RidgeRegressionWithSGD'] @@ -31,13 +30,13 @@ class LabeledPoint(object): """ The features and labels of a data point. - @param label: Label for this data point. - @param features: Vector of features for this point (NumPy array, list, + :param label: Label for this data point. + :param features: Vector of features for this point (NumPy array, list, pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix) """ def __init__(self, label, features): - self.label = label + self.label = float(label) self.features = _convert_to_vector(features) def __reduce__(self): @@ -47,7 +46,7 @@ def __str__(self): return "(" + ",".join((str(self.label), str(self.features))) + ")" def __repr__(self): - return "LabeledPoint(" + ",".join((repr(self.label), repr(self.features))) + ")" + return "LabeledPoint(%s, %s)" % (self.label, self.features) class LinearModel(object): @@ -56,7 +55,7 @@ class LinearModel(object): def __init__(self, weights, intercept): self._coeff = _convert_to_vector(weights) - self._intercept = intercept + self._intercept = float(intercept) @property def weights(self): @@ -66,6 +65,9 @@ def weights(self): def intercept(self): return self._intercept + def __repr__(self): + return "(weights=%s, intercept=%r)" % (self._coeff, self._intercept) + class LinearRegressionModelBase(LinearModel): @@ -83,6 +85,7 @@ def predict(self, x): Predict the value of the dependent variable given a vector x containing values for the independent variables. 
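Two small behaviour changes from the regression.py hunk above, sketched for illustration:

    from pyspark.mllib.regression import LabeledPoint, LinearModel
    from pyspark.mllib.linalg import SparseVector

    lp = LabeledPoint(1, SparseVector(2, {0: 3.0}))
    lp.label                            # labels are now coerced to float -> 1.0
    lm = LinearModel([1.0, 2.0], 0.5)
    lm                                  # new __repr__: (weights=[1.0,2.0], intercept=0.5)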
""" + x = _convert_to_vector(x) return self.weights.dot(x) + self.intercept @@ -121,54 +124,52 @@ class LinearRegressionModel(LinearRegressionModelBase): # train_func should take two parameters, namely data and initial_weights, and # return the result of a call to the appropriate JVM stub. # _regression_train_wrapper is responsible for setup and error checking. -def _regression_train_wrapper(sc, train_func, modelClass, data, initial_weights): +def _regression_train_wrapper(train_func, modelClass, data, initial_weights): + first = data.first() + if not isinstance(first, LabeledPoint): + raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first) initial_weights = initial_weights or [0.0] * len(data.first().features) - ser = PickleSerializer() - initial_bytes = bytearray(ser.dumps(_convert_to_vector(initial_weights))) - # use AutoBatchedSerializer before cache to reduce the memory - # overhead in JVM - cached = data._reserialize(AutoBatchedSerializer(ser)).cache() - ans = train_func(cached._to_java_object_rdd(), initial_bytes) - assert len(ans) == 2, "JVM call result had unexpected length" - weights = ser.loads(str(ans[0])) - return modelClass(weights, ans[1]) + weights, intercept = train_func(data, _convert_to_vector(initial_weights)) + return modelClass(weights, intercept) class LinearRegressionWithSGD(object): @classmethod def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, - initialWeights=None, regParam=1.0, regType="none", intercept=False): + initialWeights=None, regParam=0.0, regType=None, intercept=False): """ Train a linear regression model on the given data. - @param data: The training data. - @param iterations: The number of iterations (default: 100). - @param step: The step parameter used in SGD + :param data: The training data. + :param iterations: The number of iterations (default: 100). + :param step: The step parameter used in SGD (default: 1.0). - @param miniBatchFraction: Fraction of data to be used for each SGD + :param miniBatchFraction: Fraction of data to be used for each SGD iteration. - @param initialWeights: The initial weights (default: None). - @param regParam: The regularizer parameter (default: 1.0). - @param regType: The type of regularizer used for training + :param initialWeights: The initial weights (default: None). + :param regParam: The regularizer parameter (default: 0.0). + :param regType: The type of regularizer used for training our model. - Allowed values: "l1" for using L1Updater, - "l2" for using - SquaredL2Updater, - "none" for no regularizer. - (default: "none") + + :Allowed values: + - "l1" for using L1 regularization (lasso), + - "l2" for using L2 regularization (ridge), + - None for no regularization + + (default: None) + @param intercept: Boolean parameter which indicates the use or not of the augmented representation for training data (i.e. whether bias features are activated or not). 
""" - sc = data.context - - def train(jrdd, i): - return sc._jvm.PythonMLLibAPI().trainLinearRegressionModelWithSGD( - jrdd, iterations, step, miniBatchFraction, i, regParam, regType, intercept) + def train(rdd, i): + return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), + float(step), float(miniBatchFraction), i, float(regParam), + regType, bool(intercept)) - return _regression_train_wrapper(sc, train, LinearRegressionModel, data, initialWeights) + return _regression_train_wrapper(train, LinearRegressionModel, data, initialWeights) class LassoModel(LinearRegressionModelBase): @@ -207,15 +208,14 @@ class LassoModel(LinearRegressionModelBase): class LassoWithSGD(object): @classmethod - def train(cls, data, iterations=100, step=1.0, regParam=1.0, + def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None): """Train a Lasso regression model on the given data.""" - sc = data.context + def train(rdd, i): + return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), + float(regParam), float(miniBatchFraction), i) - def train(jrdd, i): - return sc._jvm.PythonMLLibAPI().trainLassoModelWithSGD( - jrdd, iterations, step, regParam, miniBatchFraction, i) - return _regression_train_wrapper(sc, train, LassoModel, data, initialWeights) + return _regression_train_wrapper(train, LassoModel, data, initialWeights) class RidgeRegressionModel(LinearRegressionModelBase): @@ -254,21 +254,21 @@ class RidgeRegressionModel(LinearRegressionModelBase): class RidgeRegressionWithSGD(object): @classmethod - def train(cls, data, iterations=100, step=1.0, regParam=1.0, + def train(cls, data, iterations=100, step=1.0, regParam=0.01, miniBatchFraction=1.0, initialWeights=None): """Train a ridge regression model on the given data.""" - sc = data.context - - def train(jrdd, i): - return sc._jvm.PythonMLLibAPI().trainRidgeModelWithSGD( - jrdd, iterations, step, regParam, miniBatchFraction, i) + def train(rdd, i): + return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), + float(regParam), float(miniBatchFraction), i) - return _regression_train_wrapper(sc, train, RidgeRegressionModel, data, initialWeights) + return _regression_train_wrapper(train, RidgeRegressionModel, data, initialWeights) def _test(): import doctest - globs = globals().copy() + from pyspark import SparkContext + import pyspark.mllib.regression + globs = pyspark.mllib.regression.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index b9de0909a6fb1..1980f5b03f430 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -19,65 +19,86 @@ Python package for statistical functions in MLlib. 
""" -from functools import wraps +from pyspark import RDD +from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper +from pyspark.mllib.linalg import Matrix, _convert_to_vector +from pyspark.mllib.regression import LabeledPoint -from pyspark import PickleSerializer +__all__ = ['MultivariateStatisticalSummary', 'ChiSqTestResult', 'Statistics'] -__all__ = ['MultivariateStatisticalSummary', 'Statistics'] - -def serialize(f): - ser = PickleSerializer() - - @wraps(f) - def func(self): - jvec = f(self) - bytes = self._sc._jvm.SerDe.dumps(jvec) - return ser.loads(str(bytes)).toArray() - - return func - - -class MultivariateStatisticalSummary(object): +class MultivariateStatisticalSummary(JavaModelWrapper): """ Trait for multivariate statistical summary of a data matrix. """ - def __init__(self, sc, java_summary): - """ - :param sc: Spark context - :param java_summary: Handle to Java summary object - """ - self._sc = sc - self._java_summary = java_summary - - def __del__(self): - self._sc._gateway.detach(self._java_summary) - - @serialize def mean(self): - return self._java_summary.mean() + return self.call("mean").toArray() - @serialize def variance(self): - return self._java_summary.variance() + return self.call("variance").toArray() def count(self): - return self._java_summary.count() + return self.call("count") - @serialize def numNonzeros(self): - return self._java_summary.numNonzeros() + return self.call("numNonzeros").toArray() - @serialize def max(self): - return self._java_summary.max() + return self.call("max").toArray() - @serialize def min(self): - return self._java_summary.min() + return self.call("min").toArray() + + +class ChiSqTestResult(JavaModelWrapper): + """ + :: Experimental :: + + Object containing the test results for the chi-squared hypothesis test. + """ + @property + def method(self): + """ + Name of the test method + """ + return self._java_model.method() + + @property + def pValue(self): + """ + The probability of obtaining a test statistic result at least as + extreme as the one that was actually observed, assuming that the + null hypothesis is true. + """ + return self._java_model.pValue() + + @property + def degreesOfFreedom(self): + """ + Returns the degree(s) of freedom of the hypothesis test. + Return type should be Number(e.g. Int, Double) or tuples of Numbers. + """ + return self._java_model.degreesOfFreedom() + + @property + def statistic(self): + """ + Test statistic. + """ + return self._java_model.statistic() + + @property + def nullHypothesis(self): + """ + Null hypothesis of the test. + """ + return self._java_model.nullHypothesis() + + def __str__(self): + return self._java_model.toString() class Statistics(object): @@ -87,6 +108,11 @@ def colStats(rdd): """ Computes column-wise summary statistics for the input RDD[Vector]. + :param rdd: an RDD[Vector] for which column-wise summary statistics + are to be computed. + :return: :class:`MultivariateStatisticalSummary` object containing + column-wise summary statistics. + >>> from pyspark.mllib.linalg import Vectors >>> rdd = sc.parallelize([Vectors.dense([2, 0, 0, -2]), ... 
Vectors.dense([4, 5, 0, 3]), @@ -105,10 +131,8 @@ def colStats(rdd): >>> cStats.min() array([ 2., 0., 0., -2.]) """ - sc = rdd.ctx - jrdd = rdd._to_java_object_rdd() - cStats = sc._jvm.PythonMLLibAPI().colStats(jrdd) - return MultivariateStatisticalSummary(sc, cStats) + cStats = callMLlibFunc("colStats", rdd.map(_convert_to_vector)) + return MultivariateStatisticalSummary(cStats) @staticmethod def corr(x, y=None, method=None): @@ -122,6 +146,13 @@ def corr(x, y=None, method=None): to specify the method to be used for single RDD inout. If two RDDs of floats are passed in, a single float is returned. + :param x: an RDD of vector for which the correlation matrix is to be computed, + or an RDD of float of the same cardinality as y when y is specified. + :param y: an RDD of float of the same cardinality as x. + :param method: String specifying the method to use for computing correlation. + Supported: `pearson` (default), `spearman` + :return: Correlation matrix comparing columns in x. + >>> x = sc.parallelize([1.0, 0.0, -2.0], 2) >>> y = sc.parallelize([4.0, 5.0, 3.0], 2) >>> zeros = sc.parallelize([0.0, 0.0, 0.0], 2) @@ -155,22 +186,101 @@ def corr(x, y=None, method=None): ... except TypeError: ... pass """ - sc = x.ctx # Check inputs to determine whether a single value or a matrix is needed for output. # Since it's legal for users to use the method name as the second argument, we need to # check if y is used to specify the method name instead. if type(y) == str: raise TypeError("Use 'method=' to specify method name.") - jx = x._to_java_object_rdd() if not y: - resultMat = sc._jvm.PythonMLLibAPI().corr(jx, method) - bytes = sc._jvm.SerDe.dumps(resultMat) - ser = PickleSerializer() - return ser.loads(str(bytes)).toArray() + return callMLlibFunc("corr", x.map(_convert_to_vector), method).toArray() + else: + return callMLlibFunc("corr", x.map(float), y.map(float), method) + + @staticmethod + def chiSqTest(observed, expected=None): + """ + :: Experimental :: + + If `observed` is Vector, conduct Pearson's chi-squared goodness + of fit test of the observed data against the expected distribution, + or againt the uniform distribution (by default), with each category + having an expected frequency of `1 / len(observed)`. + (Note: `observed` cannot contain negative values) + + If `observed` is matrix, conduct Pearson's independence test on the + input contingency matrix, which cannot contain negative entries or + columns or rows that sum up to 0. + + If `observed` is an RDD of LabeledPoint, conduct Pearson's independence + test for every feature against the label across the input RDD. + For each feature, the (feature, label) pairs are converted into a + contingency matrix for which the chi-squared statistic is computed. + All label and feature values must be categorical. + + :param observed: it could be a vector containing the observed categorical + counts/relative frequencies, or the contingency matrix + (containing either counts or relative frequencies), + or an RDD of LabeledPoint containing the labeled dataset + with categorical features. Real-valued features will be + treated as categorical for each distinct value. + :param expected: Vector containing the expected categorical counts/relative + frequencies. `expected` is rescaled if the `expected` sum + differs from the `observed` sum. + :return: ChiSquaredTest object containing the test statistic, degrees + of freedom, p-value, the method used, and the null hypothesis. 
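For intuition, a small worked instance of the goodness-of-fit case (matching the first doctest below, which tests against the default uniform distribution): with observed counts [4, 6, 5] the expected count is 15/3 = 5 per category, so the statistic is (4-5)^2/5 + (6-5)^2/5 + (5-5)^2/5 = 0.4.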
+ + >>> from pyspark.mllib.linalg import Vectors, Matrices + >>> observed = Vectors.dense([4, 6, 5]) + >>> pearson = Statistics.chiSqTest(observed) + >>> print pearson.statistic + 0.4 + >>> pearson.degreesOfFreedom + 2 + >>> print round(pearson.pValue, 4) + 0.8187 + >>> pearson.method + u'pearson' + >>> pearson.nullHypothesis + u'observed follows the same distribution as expected.' + + >>> observed = Vectors.dense([21, 38, 43, 80]) + >>> expected = Vectors.dense([3, 5, 7, 20]) + >>> pearson = Statistics.chiSqTest(observed, expected) + >>> print round(pearson.pValue, 4) + 0.0027 + + >>> data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0] + >>> chi = Statistics.chiSqTest(Matrices.dense(3, 4, data)) + >>> print round(chi.statistic, 4) + 21.9958 + + >>> data = [LabeledPoint(0.0, Vectors.dense([0.5, 10.0])), + ... LabeledPoint(0.0, Vectors.dense([1.5, 20.0])), + ... LabeledPoint(1.0, Vectors.dense([1.5, 30.0])), + ... LabeledPoint(0.0, Vectors.dense([3.5, 30.0])), + ... LabeledPoint(0.0, Vectors.dense([3.5, 40.0])), + ... LabeledPoint(1.0, Vectors.dense([3.5, 40.0])),] + >>> rdd = sc.parallelize(data, 4) + >>> chi = Statistics.chiSqTest(rdd) + >>> print chi[0].statistic + 0.75 + >>> print chi[1].statistic + 1.5 + """ + if isinstance(observed, RDD): + if not isinstance(observed.first(), LabeledPoint): + raise ValueError("observed should be an RDD of LabeledPoint") + jmodels = callMLlibFunc("chiSqTest", observed) + return [ChiSqTestResult(m) for m in jmodels] + + if isinstance(observed, Matrix): + jmodel = callMLlibFunc("chiSqTest", observed) else: - jy = y._to_java_object_rdd() - return sc._jvm.PythonMLLibAPI().corr(jx, jy, method) + if expected and len(expected) != len(observed): + raise ValueError("`expected` should have same length with `observed`") + jmodel = callMLlibFunc("chiSqTest", _convert_to_vector(observed), expected) + return ChiSqTestResult(jmodel) def _test(): diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index f72e88ba6e2ba..8332f8e061f48 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -25,15 +25,22 @@ from numpy import array, array_equal if sys.version_info[:2] <= (2, 6): - import unittest2 as unittest + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) else: import unittest -from pyspark.serializers import PickleSerializer -from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector +from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ + DenseMatrix from pyspark.mllib.regression import LabeledPoint -from pyspark.tests import PySparkTestCase - +from pyspark.mllib.random import RandomRDDs +from pyspark.mllib.stat import Statistics +from pyspark.serializers import PickleSerializer +from pyspark.sql import SQLContext +from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase _have_scipy = False try: @@ -56,6 +63,7 @@ def _squared_distance(a, b): class VectorTests(PySparkTestCase): def _test_serialize(self, v): + self.assertEqual(v, ser.loads(ser.dumps(v))) jvec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v))) nv = ser.loads(str(self.sc._jvm.SerDe.dumps(jvec))) self.assertEqual(v, nv) @@ -69,6 +77,8 @@ def test_serialize(self): self._test_serialize(DenseVector(array([1., 2., 3., 4.]))) self._test_serialize(DenseVector(pyarray.array('d', range(10)))) self._test_serialize(SparseVector(4, {1: 1, 3: 2})) + 
self._test_serialize(SparseVector(3, {})) + self._test_serialize(DenseMatrix(2, 3, range(6))) def test_dot(self): sv = SparseVector(4, {1: 1, 3: 2}) @@ -198,6 +208,56 @@ def test_regression(self): self.assertTrue(dt_model.predict(features[3]) > 0) +class StatTests(PySparkTestCase): + # SPARK-4023 + def test_col_with_different_rdds(self): + # numpy + data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) + summary = Statistics.colStats(data) + self.assertEqual(1000, summary.count()) + # array + data = self.sc.parallelize([range(10)] * 10) + summary = Statistics.colStats(data) + self.assertEqual(10, summary.count()) + # array + data = self.sc.parallelize([pyarray.array("d", range(10))] * 10) + summary = Statistics.colStats(data) + self.assertEqual(10, summary.count()) + + +class VectorUDTTests(PySparkTestCase): + + dv0 = DenseVector([]) + dv1 = DenseVector([1.0, 2.0]) + sv0 = SparseVector(2, [], []) + sv1 = SparseVector(2, [1], [2.0]) + udt = VectorUDT() + + def test_json_schema(self): + self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for v in [self.dv0, self.dv1, self.sv0, self.sv1]: + self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) + + def test_infer_schema(self): + sqlCtx = SQLContext(self.sc) + rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) + srdd = sqlCtx.inferSchema(rdd) + schema = srdd.schema() + field = [f for f in schema.fields if f.name == "features"][0] + self.assertEqual(field.dataType, self.udt) + vectors = srdd.map(lambda p: p.features).collect() + self.assertEqual(len(vectors), 2) + for v in vectors: + if isinstance(v, SparseVector): + self.assertEqual(v, self.sv1) + elif isinstance(v, DenseVector): + self.assertEqual(v, self.dv1) + else: + raise ValueError("expecting a vector but got %r of type %r" % (v, type(v))) + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index afdcdbdf3ae01..46e253991aa56 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -15,60 +15,38 @@ # limitations under the License. # -from py4j.java_collections import MapConverter +from __future__ import absolute_import + +import random from pyspark import SparkContext, RDD -from pyspark.serializers import BatchedSerializer, PickleSerializer -from pyspark.mllib.linalg import Vector, _convert_to_vector +from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper +from pyspark.mllib.linalg import _convert_to_vector from pyspark.mllib.regression import LabeledPoint -__all__ = ['DecisionTreeModel', 'DecisionTree'] +__all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel', 'RandomForest'] -class DecisionTreeModel(object): +class DecisionTreeModel(JavaModelWrapper): """ A decision tree model for classification or regression. EXPERIMENTAL: This is an experimental API. - It will probably be modified for Spark v1.2. + It will probably be modified in future. """ - - def __init__(self, sc, java_model): - """ - :param sc: Spark context - :param java_model: Handle to Java model object - """ - self._sc = sc - self._java_model = java_model - - def __del__(self): - self._sc._gateway.detach(self._java_model) - def predict(self, x): """ Predict the label of one or more examples. + :param x: Data point (feature vector), or an RDD of data points (feature vectors). 
""" - SerDe = self._sc._jvm.SerDe - ser = PickleSerializer() if isinstance(x, RDD): - # Bulk prediction - first = x.take(1) - if not first: - return self._sc.parallelize([]) - if not isinstance(first[0], Vector): - x = x.map(_convert_to_vector) - jPred = self._java_model.predict(x._to_java_object_rdd()).toJavaRDD() - jpyrdd = self._sc._jvm.PythonRDD.javaToPython(jPred) - return RDD(jpyrdd, self._sc, BatchedSerializer(ser, 1024)) + return self.call("predict", x.map(_convert_to_vector)) else: - # Assume x is a single data point. - bytes = bytearray(ser.dumps(_convert_to_vector(x))) - vec = self._sc._jvm.SerDe.loads(bytes) - return self._java_model.predict(vec) + return self.call("predict", _convert_to_vector(x)) def numNodes(self): return self._java_model.numNodes() @@ -77,42 +55,34 @@ def depth(self): return self._java_model.depth() def __repr__(self): - """ Print summary of model. """ + """ summary of model. """ return self._java_model.toString() def toDebugString(self): - """ Print full model. """ + """ full model. """ return self._java_model.toDebugString() class DecisionTree(object): """ - Learning algorithm for a decision tree model - for classification or regression. + Learning algorithm for a decision tree model for classification or regression. EXPERIMENTAL: This is an experimental API. - It will probably be modified for Spark v1.2. - + It will probably be modified in future. """ - @staticmethod - def _train(data, type, numClasses, categoricalFeaturesInfo, - impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, - minInfoGain=0.0): + @classmethod + def _train(cls, data, type, numClasses, features, impurity="gini", maxDepth=5, maxBins=32, + minInstancesPerNode=1, minInfoGain=0.0): first = data.first() assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint" - sc = data.context - jrdd = data._to_java_object_rdd() - cfiMap = MapConverter().convert(categoricalFeaturesInfo, - sc._gateway._gateway_client) - model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( - jrdd, type, numClasses, cfiMap, - impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) - return DecisionTreeModel(sc, model) - - @staticmethod - def trainClassifier(data, numClasses, categoricalFeaturesInfo, + model = callMLlibFunc("trainDecisionTreeModel", data, type, numClasses, features, + impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) + return DecisionTreeModel(model) + + @classmethod + def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): """ @@ -130,8 +100,8 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo, E.g., depth 0 means 1 leaf node. Depth 1 means 1 internal node + 2 leaf nodes. :param maxBins: Number of bins used for finding splits at each node. 
- :param minInstancesPerNode: Min number of instances required at child nodes to create - the parent split + :param minInstancesPerNode: Min number of instances required at child + nodes to create the parent split :param minInfoGain: Min info gain required to create a split :return: DecisionTreeModel @@ -152,20 +122,23 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo, DecisionTreeModel classifier of depth 1 with 3 nodes >>> print model.toDebugString(), # it already has newline DecisionTreeModel classifier of depth 1 with 3 nodes - If (feature 0 <= 0.5) + If (feature 0 <= 0.0) Predict: 0.0 - Else (feature 0 > 0.5) + Else (feature 0 > 0.0) Predict: 1.0 - >>> model.predict(array([1.0])) > 0 - True - >>> model.predict(array([0.0])) == 0 - True + >>> model.predict(array([1.0])) + 1.0 + >>> model.predict(array([0.0])) + 0.0 + >>> rdd = sc.parallelize([[1.0], [0.0]]) + >>> model.predict(rdd).collect() + [1.0, 0.0] """ - return DecisionTree._train(data, "classification", numClasses, categoricalFeaturesInfo, - impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) + return cls._train(data, "classification", numClasses, categoricalFeaturesInfo, + impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) - @staticmethod - def trainRegressor(data, categoricalFeaturesInfo, + @classmethod + def trainRegressor(cls, data, categoricalFeaturesInfo, impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0): """ @@ -182,14 +155,13 @@ def trainRegressor(data, categoricalFeaturesInfo, E.g., depth 0 means 1 leaf node. Depth 1 means 1 internal node + 2 leaf nodes. :param maxBins: Number of bins used for finding splits at each node. - :param minInstancesPerNode: Min number of instances required at child nodes to create - the parent split + :param minInstancesPerNode: Min number of instances required at child + nodes to create the parent split :param minInfoGain: Min info gain required to create a split :return: DecisionTreeModel Example usage: - >>> from numpy import array >>> from pyspark.mllib.regression import LabeledPoint >>> from pyspark.mllib.tree import DecisionTree >>> from pyspark.mllib.linalg import SparseVector @@ -202,17 +174,213 @@ def trainRegressor(data, categoricalFeaturesInfo, ... ] >>> >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), {}) - >>> model.predict(array([0.0, 1.0])) == 1 - True - >>> model.predict(array([0.0, 0.0])) == 0 - True - >>> model.predict(SparseVector(2, {1: 1.0})) == 1 - True - >>> model.predict(SparseVector(2, {1: 0.0})) == 0 - True - """ - return DecisionTree._train(data, "regression", 0, categoricalFeaturesInfo, - impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) + >>> model.predict(SparseVector(2, {1: 1.0})) + 1.0 + >>> model.predict(SparseVector(2, {1: 0.0})) + 0.0 + >>> rdd = sc.parallelize([[0.0, 1.0], [0.0, 0.0]]) + >>> model.predict(rdd).collect() + [1.0, 0.0] + """ + return cls._train(data, "regression", 0, categoricalFeaturesInfo, + impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) + + +class RandomForestModel(JavaModelWrapper): + """ + Represents a random forest model. + + EXPERIMENTAL: This is an experimental API. + It will probably be modified in future. + """ + def predict(self, x): + """ + Predict values for a single data point or an RDD of points using + the model trained. 
+ """ + if isinstance(x, RDD): + return self.call("predict", x.map(_convert_to_vector)) + + else: + return self.call("predict", _convert_to_vector(x)) + + def numTrees(self): + """ + Get number of trees in forest. + """ + return self.call("numTrees") + + def totalNumNodes(self): + """ + Get total number of nodes, summed over all trees in the forest. + """ + return self.call("totalNumNodes") + + def __repr__(self): + """ Summary of model """ + return self._java_model.toString() + + def toDebugString(self): + """ Full model """ + return self._java_model.toDebugString() + + +class RandomForest(object): + """ + Learning algorithm for a random forest model for classification or regression. + + EXPERIMENTAL: This is an experimental API. + It will probably be modified in future. + """ + + supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird") + + @classmethod + def _train(cls, data, algo, numClasses, categoricalFeaturesInfo, numTrees, + featureSubsetStrategy, impurity, maxDepth, maxBins, seed): + first = data.first() + assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint" + if featureSubsetStrategy not in cls.supportedFeatureSubsetStrategies: + raise ValueError("unsupported featureSubsetStrategy: %s" % featureSubsetStrategy) + if seed is None: + seed = random.randint(0, 1 << 30) + model = callMLlibFunc("trainRandomForestModel", data, algo, numClasses, + categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, + maxDepth, maxBins, seed) + return RandomForestModel(model) + + @classmethod + def trainClassifier(cls, data, numClassesForClassification, categoricalFeaturesInfo, numTrees, + featureSubsetStrategy="auto", impurity="gini", maxDepth=4, maxBins=32, + seed=None): + """ + Method to train a decision tree model for binary or multiclass + classification. + + :param data: Training dataset: RDD of LabeledPoint. Labels should take + values {0, 1, ..., numClasses-1}. + :param numClassesForClassification: number of classes for classification. + :param categoricalFeaturesInfo: Map storing arity of categorical features. + E.g., an entry (n -> k) indicates that feature n is categorical + with k categories indexed from 0: {0, 1, ..., k-1}. + :param numTrees: Number of trees in the random forest. + :param featureSubsetStrategy: Number of features to consider for splits at + each node. + Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "sqrt". + :param impurity: Criterion used for information gain calculation. + Supported values: "gini" (recommended) or "entropy". + :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 leaf node; + depth 1 means 1 internal node + 2 leaf nodes. (default: 4) + :param maxBins: maximum number of bins used for splitting features + (default: 100) + :param seed: Random seed for bootstrapping and choosing feature subsets. + :return: RandomForestModel that can be used for prediction + + Example usage: + + >>> from pyspark.mllib.regression import LabeledPoint + >>> from pyspark.mllib.tree import RandomForest + >>> + >>> data = [ + ... LabeledPoint(0.0, [0.0]), + ... LabeledPoint(0.0, [1.0]), + ... LabeledPoint(1.0, [2.0]), + ... LabeledPoint(1.0, [3.0]) + ... 
] + >>> model = RandomForest.trainClassifier(sc.parallelize(data), 2, {}, 3, seed=42) + >>> model.numTrees() + 3 + >>> model.totalNumNodes() + 7 + >>> print model, + TreeEnsembleModel classifier with 3 trees + >>> print model.toDebugString(), + TreeEnsembleModel classifier with 3 trees + + Tree 0: + Predict: 1.0 + Tree 1: + If (feature 0 <= 1.0) + Predict: 0.0 + Else (feature 0 > 1.0) + Predict: 1.0 + Tree 2: + If (feature 0 <= 1.0) + Predict: 0.0 + Else (feature 0 > 1.0) + Predict: 1.0 + >>> model.predict([2.0]) + 1.0 + >>> model.predict([0.0]) + 0.0 + >>> rdd = sc.parallelize([[3.0], [1.0]]) + >>> model.predict(rdd).collect() + [1.0, 0.0] + """ + return cls._train(data, "classification", numClassesForClassification, + categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, + maxDepth, maxBins, seed) + + @classmethod + def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto", + impurity="variance", maxDepth=4, maxBins=32, seed=None): + """ + Method to train a decision tree model for regression. + + :param data: Training dataset: RDD of LabeledPoint. Labels are + real numbers. + :param categoricalFeaturesInfo: Map storing arity of categorical + features. E.g., an entry (n -> k) indicates that feature + n is categorical with k categories indexed from 0: + {0, 1, ..., k-1}. + :param numTrees: Number of trees in the random forest. + :param featureSubsetStrategy: Number of features to consider for + splits at each node. + Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + If "auto" is set, this parameter is set based on numTrees: + if numTrees == 1, set to "all"; + if numTrees > 1 (forest) set to "onethird" for regression. + :param impurity: Criterion used for information gain calculation. + Supported values: "variance". + :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 + leaf node; depth 1 means 1 internal node + 2 leaf nodes. + (default: 4) + :param maxBins: maximum number of bins used for splitting features + (default: 100) + :param seed: Random seed for bootstrapping and choosing feature subsets. + :return: RandomForestModel that can be used for prediction + + Example usage: + + >>> from pyspark.mllib.regression import LabeledPoint + >>> from pyspark.mllib.tree import RandomForest + >>> from pyspark.mllib.linalg import SparseVector + >>> + >>> sparse_data = [ + ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})), + ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) + ... 
] + >>> + >>> model = RandomForest.trainRegressor(sc.parallelize(sparse_data), {}, 2, seed=42) + >>> model.numTrees() + 2 + >>> model.totalNumNodes() + 4 + >>> model.predict(SparseVector(2, {1: 1.0})) + 1.0 + >>> model.predict(SparseVector(2, {0: 1.0})) + 0.5 + >>> rdd = sc.parallelize([[0.0, 1.0], [1.0, 0.0]]) + >>> model.predict(rdd).collect() + [1.0, 0.5] + """ + return cls._train(data, "regression", 0, categoricalFeaturesInfo, numTrees, + featureSubsetStrategy, impurity, maxDepth, maxBins, seed) def _test(): diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 8233d4e81f1ca..4ed978b45409c 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -18,8 +18,7 @@ import numpy as np import warnings -from pyspark.rdd import RDD -from pyspark.serializers import BatchedSerializer, PickleSerializer +from pyspark.mllib.common import callMLlibFunc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint @@ -77,10 +76,10 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None method parses each line into a LabeledPoint, where the feature indices are converted to zero-based. - @param sc: Spark context - @param path: file or directory path in any Hadoop-supported file + :param sc: Spark context + :param path: file or directory path in any Hadoop-supported file system URI - @param numFeatures: number of features, which will be determined + :param numFeatures: number of features, which will be determined from the input data if a nonpositive value is given. This is useful when the dataset is already split into multiple files and you @@ -88,7 +87,7 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None, multiclass=None features may not present in certain files, which leads to inconsistent feature dimensions. - @param minPartitions: min number of partitions + :param minPartitions: min number of partitions @return: labeled data stored as an RDD of LabeledPoint >>> from tempfile import NamedTemporaryFile @@ -126,8 +125,8 @@ def saveAsLibSVMFile(data, dir): """ Save labeled data in LIBSVM format. - @param data: an RDD of LabeledPoint to be saved - @param dir: directory to save the data + :param data: an RDD of LabeledPoint to be saved + :param dir: directory to save the data >>> from tempfile import NamedTemporaryFile >>> from fileinput import input @@ -149,10 +148,10 @@ def loadLabeledPoints(sc, path, minPartitions=None): """ Load labeled points saved using RDD.saveAsTextFile. 
- @param sc: Spark context - @param path: file or directory path in any Hadoop-supported file + :param sc: Spark context + :param path: file or directory path in any Hadoop-supported file system URI - @param minPartitions: min number of partitions + :param minPartitions: min number of partitions @return: labeled data stored as an RDD of LabeledPoint >>> from tempfile import NamedTemporaryFile @@ -162,20 +161,11 @@ def loadLabeledPoints(sc, path, minPartitions=None): >>> tempFile = NamedTemporaryFile(delete=True) >>> tempFile.close() >>> sc.parallelize(examples, 1).saveAsTextFile(tempFile.name) - >>> loaded = MLUtils.loadLabeledPoints(sc, tempFile.name).collect() - >>> type(loaded[0]) == LabeledPoint - True - >>> print examples[0] - (1.1,(3,[0,2],[-1.23,4.56e-07])) - >>> type(examples[1]) == LabeledPoint - True - >>> print examples[1] - (0.0,[1.01,2.02,3.03]) + >>> MLUtils.loadLabeledPoints(sc, tempFile.name).collect() + [LabeledPoint(1.1, (3,[0,2],[-1.23,4.56e-07])), LabeledPoint(0.0, [1.01,2.02,3.03])] """ minPartitions = minPartitions or min(sc.defaultParallelism, 2) - jrdd = sc._jvm.PythonMLLibAPI().loadLabeledPoints(sc._jsc, path, minPartitions) - jpyrdd = sc._jvm.PythonRDD.javaToPython(jrdd) - return RDD(jpyrdd, sc, BatchedSerializer(PickleSerializer())) + return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions) def _test(): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index dc6497772e502..57754776faaa2 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -28,7 +28,7 @@ import warnings import heapq import bisect -from random import Random +import random from math import sqrt, log, isinf, isnan from pyspark.accumulators import PStatsParam @@ -38,7 +38,7 @@ from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_full_outer_join, python_cogroup from pyspark.statcounter import StatCounter -from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler +from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ @@ -120,7 +120,7 @@ class RDD(object): operated on in parallel. """ - def __init__(self, jrdd, ctx, jrdd_deserializer): + def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(PickleSerializer())): self._jrdd = jrdd self.is_cached = False self.is_checkpointed = False @@ -129,12 +129,8 @@ def __init__(self, jrdd, ctx, jrdd_deserializer): self._id = jrdd.id() self._partitionFunc = None - def _toPickleSerialization(self): - if (self._jrdd_deserializer == PickleSerializer() or - self._jrdd_deserializer == BatchedSerializer(PickleSerializer())): - return self - else: - return self._reserialize(BatchedSerializer(PickleSerializer(), 10)) + def _pickled(self): + return self._reserialize(AutoBatchedSerializer(PickleSerializer())) def id(self): """ @@ -314,20 +310,43 @@ def distinct(self, numPartitions=None): def sample(self, withReplacement, fraction, seed=None): """ - Return a sampled subset of this RDD (relies on numpy and falls back - on default random generator if numpy is unavailable). + Return a sampled subset of this RDD. 
- >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP - [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98] + >>> rdd = sc.parallelize(range(100), 4) + >>> rdd.sample(False, 0.1, 81).count() + 10 """ assert fraction >= 0.0, "Negative fraction value: %s" % fraction return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True) + def randomSplit(self, weights, seed=None): + """ + Randomly splits this RDD with the provided weights. + + :param weights: weights for splits, will be normalized if they don't sum to 1 + :param seed: random seed + :return: split RDDs in a list + + >>> rdd = sc.parallelize(range(5), 1) + >>> rdd1, rdd2 = rdd.randomSplit([2, 3], 17) + >>> rdd1.collect() + [1, 3] + >>> rdd2.collect() + [0, 2, 4] + """ + s = float(sum(weights)) + cweights = [0.0] + for w in weights: + cweights.append(cweights[-1] + w / s) + if seed is None: + seed = random.randint(0, 2 ** 32 - 1) + return [self.mapPartitionsWithIndex(RDDRangeSampler(lb, ub, seed).func, True) + for lb, ub in zip(cweights, cweights[1:])] + # this is ported from scala/spark/RDD.scala def takeSample(self, withReplacement, num, seed=None): """ - Return a fixed-size sampled subset of this RDD (currently requires - numpy). + Return a fixed-size sampled subset of this RDD. >>> rdd = sc.parallelize(range(0, 10)) >>> len(rdd.takeSample(True, 20, 1)) @@ -348,7 +367,7 @@ def takeSample(self, withReplacement, num, seed=None): if initialCount == 0: return [] - rand = Random(seed) + rand = random.Random(seed) if (not withReplacement) and num >= initialCount: # shuffle current RDD and return @@ -449,12 +468,11 @@ def intersection(self, other): def _reserialize(self, serializer=None): serializer = serializer or self.ctx.serializer - if self._jrdd_deserializer == serializer: - return self - else: - converted = self.map(lambda x: x, preservesPartitioning=True) - converted._jrdd_deserializer = serializer - return converted + if self._jrdd_deserializer != serializer: + if not isinstance(self, PipelinedRDD): + self = self.map(lambda x: x, preservesPartitioning=True) + self._jrdd_deserializer = serializer + return self def __add__(self, other): """ @@ -529,6 +547,8 @@ def sortPartition(iterator): # the key-space into bins such that the bins have roughly the same # number of (key, value) pairs falling into them rddSize = self.count() + if not rddSize: + return self # empty RDD maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner fraction = min(maxSampleSize / max(rddSize, 1), 1.0) samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect() @@ -752,7 +772,7 @@ def max(self, key=None): """ Find the maximum item in this RDD. - @param key: A function used to generate key for comparing + :param key: A function used to generate key for comparing >>> rdd = sc.parallelize([1.0, 5.0, 43.0, 10.0]) >>> rdd.max() @@ -768,7 +788,7 @@ def min(self, key=None): """ Find the minimum item in this RDD. - @param key: A function used to generate key for comparing + :param key: A function used to generate key for comparing >>> rdd = sc.parallelize([2.0, 5.0, 43.0, 10.0]) >>> rdd.min() @@ -1070,10 +1090,13 @@ def take(self, num): # If we didn't find any rows after the previous iteration, # quadruple and retry. Otherwise, interpolate the number of # partitions we need to try, but overestimate it by 50%. + # We also cap the estimation in the end. 
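# As a purely illustrative walk-through of the estimate below: with num=100,
# partsScanned=4 and len(items)=10, the interpolation gives
# int(1.5 * 100 * 4 / 10) - 4 = 56, which the cap then reduces to
# min(max(56, 1), 4 * 4) = 16 partitions for the next scan.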
if len(items) == 0: numPartsToTry = partsScanned * 4 else: - numPartsToTry = int(1.5 * num * partsScanned / len(items)) + # the first paramter of max is >=1 whenever partsScanned >= 2 + numPartsToTry = int(1.5 * num * partsScanned / len(items)) - partsScanned + numPartsToTry = min(max(numPartsToTry, 1), partsScanned * 4) left = num - len(items) @@ -1115,14 +1138,13 @@ def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None converted for output using either user specified converters or, by default, L{org.apache.spark.api.python.JavaToWritableConverter}. - @param conf: Hadoop job configuration, passed in as a dict - @param keyConverter: (None by default) - @param valueConverter: (None by default) + :param conf: Hadoop job configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, batched, jconf, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, True, jconf, keyConverter, valueConverter, True) def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None, @@ -1135,21 +1157,20 @@ def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueCl C{conf} is applied on top of the base Hadoop conf associated with the SparkContext of this RDD to create a merged Hadoop MapReduce job configuration for saving the data. - @param path: path to Hadoop file - @param outputFormatClass: fully qualified classname of Hadoop OutputFormat + :param path: path to Hadoop file + :param outputFormatClass: fully qualified classname of Hadoop OutputFormat (e.g. "org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.IntWritable", None by default) - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.Text", None by default) - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: Hadoop job configuration, passed in as a dict (None by default) + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: Hadoop job configuration, passed in as a dict (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(pickledRDD._jrdd, batched, path, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(pickledRDD._jrdd, True, path, outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, jconf) @@ -1161,14 +1182,13 @@ def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None): converted for output using either user specified converters or, by default, L{org.apache.spark.api.python.JavaToWritableConverter}. 
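A minimal usage sketch (illustrative only; it assumes a running SparkContext `sc` and follows the temporary-file pattern used by the other doctests in this file):

>>> from tempfile import NamedTemporaryFile
>>> tmpFile = NamedTemporaryFile(delete=True)
>>> tmpFile.close()
>>> sc.parallelize([(1, "a"), (2, "b")]).saveAsSequenceFile(tmpFile.name)
>>> sc.sequenceFile(tmpFile.name).count()
2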
- @param conf: Hadoop job configuration, passed in as a dict - @param keyConverter: (None by default) - @param valueConverter: (None by default) + :param conf: Hadoop job configuration, passed in as a dict + :param keyConverter: (None by default) + :param valueConverter: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, batched, jconf, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, True, jconf, keyConverter, valueConverter, False) def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None, @@ -1182,22 +1202,21 @@ def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=No C{conf} is applied on top of the base Hadoop conf associated with the SparkContext of this RDD to create a merged Hadoop MapReduce job configuration for saving the data. - @param path: path to Hadoop file - @param outputFormatClass: fully qualified classname of Hadoop OutputFormat + :param path: path to Hadoop file + :param outputFormatClass: fully qualified classname of Hadoop OutputFormat (e.g. "org.apache.hadoop.mapred.SequenceFileOutputFormat") - @param keyClass: fully qualified classname of key Writable class + :param keyClass: fully qualified classname of key Writable class (e.g. "org.apache.hadoop.io.IntWritable", None by default) - @param valueClass: fully qualified classname of value Writable class + :param valueClass: fully qualified classname of value Writable class (e.g. "org.apache.hadoop.io.Text", None by default) - @param keyConverter: (None by default) - @param valueConverter: (None by default) - @param conf: (None by default) - @param compressionCodecClass: (None by default) + :param keyConverter: (None by default) + :param valueConverter: (None by default) + :param conf: (None by default) + :param compressionCodecClass: (None by default) """ jconf = self.ctx._dictToJavaMap(conf) - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickledRDD._jrdd, batched, path, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickledRDD._jrdd, True, path, outputFormatClass, keyClass, valueClass, keyConverter, valueConverter, @@ -1208,15 +1227,15 @@ def saveAsSequenceFile(self, path, compressionCodecClass=None): Output a Python RDD of key-value pairs (of form C{RDD[(K, V)]}) to any Hadoop file system, using the L{org.apache.hadoop.io.Writable} types that we convert from the RDD's key and value types. The mechanism is as follows: + 1. Pyrolite is used to convert pickled Python RDD into RDD of Java objects. 2. Keys and values of this Java RDD are converted to Writables and written out. 
- @param path: path to sequence file - @param compressionCodecClass: (None by default) + :param path: path to sequence file + :param compressionCodecClass: (None by default) """ - pickledRDD = self._toPickleSerialization() - batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer) - self.ctx._jvm.PythonRDD.saveAsSequenceFile(pickledRDD._jrdd, batched, + pickledRDD = self._pickled() + self.ctx._jvm.PythonRDD.saveAsSequenceFile(pickledRDD._jrdd, True, path, compressionCodecClass) def saveAsPickleFile(self, path, batchSize=10): @@ -1231,8 +1250,11 @@ def saveAsPickleFile(self, path, batchSize=10): >>> sorted(sc.pickleFile(tmpFile.name, 5).collect()) [1, 2, 'rdd', 'spark'] """ - self._reserialize(BatchedSerializer(PickleSerializer(), - batchSize))._jrdd.saveAsObjectFile(path) + if batchSize == 0: + ser = AutoBatchedSerializer(PickleSerializer()) + else: + ser = BatchedSerializer(PickleSerializer(), batchSize) + self._reserialize(ser)._jrdd.saveAsObjectFile(path) def saveAsTextFile(self, path): """ @@ -1773,13 +1795,10 @@ def zip(self, other): >>> x.zip(y).collect() [(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)] """ - if self.getNumPartitions() != other.getNumPartitions(): - raise ValueError("Can only zip with RDD which has the same number of partitions") - def get_batch_size(ser): if isinstance(ser, BatchedSerializer): return ser.batchSize - return 0 + return 1 def batch_as(rdd, batchSize): ser = rdd._jrdd_deserializer @@ -1789,12 +1808,16 @@ def batch_as(rdd, batchSize): my_batch = get_batch_size(self._jrdd_deserializer) other_batch = get_batch_size(other._jrdd_deserializer) - if my_batch != other_batch: - # use the greatest batchSize to batch the other one. - if my_batch > other_batch: - other = batch_as(other, my_batch) - else: - self = batch_as(self, other_batch) + # use the smallest batchSize for both of them + batchSize = min(my_batch, other_batch) + if batchSize <= 0: + # auto batched or unlimited + batchSize = 100 + other = batch_as(other, batchSize) + self = batch_as(self, batchSize) + + if self.getNumPartitions() != other.getNumPartitions(): + raise ValueError("Can only zip with RDD which has the same number of partitions") # There will be an Exception in JVM if there are different number # of items in each partitions. @@ -1863,11 +1886,11 @@ def setName(self, name): Assign a name to this RDD. >>> rdd1 = sc.parallelize([1,2]) - >>> rdd1.setName('RDD1') - >>> rdd1.name() + >>> rdd1.setName('RDD1').name() 'RDD1' """ self._jrdd.setName(name) + return self def toDebugString(self): """ @@ -1933,25 +1956,14 @@ def lookup(self, key): return values.collect() - def _is_pickled(self): - """ Return this RDD is serialized by Pickle or not. """ - der = self._jrdd_deserializer - if isinstance(der, PickleSerializer): - return True - if isinstance(der, BatchedSerializer) and isinstance(der.serializer, PickleSerializer): - return True - return False - def _to_java_object_rdd(self): """ Return an JavaRDD of Object by unpickling It will convert each Python object into Java object by Pyrolite, whenever the RDD is serialized in batch or not. 
""" - rdd = self._reserialize(AutoBatchedSerializer(PickleSerializer())) \ - if not self._is_pickled() else self - is_batch = isinstance(rdd._jrdd_deserializer, BatchedSerializer) - return self.ctx._jvm.PythonRDD.pythonToJava(rdd._jrdd, is_batch) + rdd = self._pickled() + return self.ctx._jvm.SerDeUtil.pythonToJava(rdd._jrdd, True) def countApprox(self, timeout, confidence=0.95): """ @@ -2008,7 +2020,7 @@ def countApproxDistinct(self, relativeSD=0.05): of The Art Cardinality Estimation Algorithm", available here. - @param relativeSD Relative accuracy. Smaller values create + :param relativeSD: Relative accuracy. Smaller values create counters that require more space. It must be greater than 0.000017. @@ -2131,7 +2143,7 @@ def _test(): globs = globals().copy() # The small batch size here ensures that we see multiple batches, # even in these small test examples: - globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + globs['sc'] = SparkContext('local[4]', 'PythonTest') (failure_count, test_count) = doctest.testmod( globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py index 55e247da0e4dc..459e1427803cb 100644 --- a/python/pyspark/rddsampler.py +++ b/python/pyspark/rddsampler.py @@ -17,82 +17,48 @@ import sys import random +import math class RDDSamplerBase(object): def __init__(self, withReplacement, seed=None): - try: - import numpy - self._use_numpy = True - except ImportError: - print >> sys.stderr, ( - "NumPy does not appear to be installed. " - "Falling back to default random generator for sampling.") - self._use_numpy = False - self._seed = seed if seed is not None else random.randint(0, sys.maxint) self._withReplacement = withReplacement self._random = None - self._split = None - self._rand_initialized = False def initRandomGenerator(self, split): - if self._use_numpy: - import numpy - self._random = numpy.random.RandomState(self._seed) - else: - self._random = random.Random(self._seed) - - for _ in range(0, split): - # discard the next few values in the sequence to have a - # different seed for the different splits - self._random.randint(0, sys.maxint) - - self._split = split - self._rand_initialized = True - - def getUniformSample(self, split): - if not self._rand_initialized or split != self._split: - self.initRandomGenerator(split) - - if self._use_numpy: - return self._random.random_sample() + self._random = random.Random(self._seed ^ split) + + # mixing because the initial seeds are close to each other + for _ in xrange(10): + self._random.randint(0, 1) + + def getUniformSample(self): + return self._random.random() + + def getPoissonSample(self, mean): + # Using Knuth's algorithm described in + # http://en.wikipedia.org/wiki/Poisson_distribution + if mean < 20.0: + # one exp and k+1 random calls + l = math.exp(-mean) + p = self._random.random() + k = 0 + while p > l: + k += 1 + p *= self._random.random() else: - return self._random.uniform(0.0, 1.0) - - def getPoissonSample(self, split, mean): - if not self._rand_initialized or split != self._split: - self.initRandomGenerator(split) - - if self._use_numpy: - return self._random.poisson(mean) - else: - # here we simulate drawing numbers n_i ~ Poisson(lambda = 1/mean) by - # drawing a sequence of numbers delta_j ~ Exp(mean) - num_arrivals = 1 - cur_time = 0.0 - - cur_time += self._random.expovariate(mean) - - if cur_time > 1.0: - return 0 + # switch to the log domain, k+1 expovariate (random + log) calls + p = 
self._random.expovariate(mean) + k = 0 + while p < 1.0: + k += 1 + p += self._random.expovariate(mean) + return k - while(cur_time <= 1.0): - cur_time += self._random.expovariate(mean) - num_arrivals += 1 - - return (num_arrivals - 1) - - def shuffle(self, vals): - if self._random is None: - self.initRandomGenerator(0) # this should only ever called on the master so - # the split does not matter - - if self._use_numpy: - self._random.shuffle(vals) - else: - self._random.shuffle(vals, self._random.random) + def func(self, split, iterator): + raise NotImplementedError class RDDSampler(RDDSamplerBase): @@ -102,20 +68,35 @@ def __init__(self, withReplacement, fraction, seed=None): self._fraction = fraction def func(self, split, iterator): + self.initRandomGenerator(split) if self._withReplacement: for obj in iterator: # For large datasets, the expected number of occurrences of each element in # a sample with replacement is Poisson(frac). We use that to get a count for # each element. - count = self.getPoissonSample(split, mean=self._fraction) + count = self.getPoissonSample(self._fraction) for _ in range(0, count): yield obj else: for obj in iterator: - if self.getUniformSample(split) <= self._fraction: + if self.getUniformSample() < self._fraction: yield obj +class RDDRangeSampler(RDDSamplerBase): + + def __init__(self, lowerBound, upperBound, seed=None): + RDDSamplerBase.__init__(self, False, seed) + self._lowerBound = lowerBound + self._upperBound = upperBound + + def func(self, split, iterator): + self.initRandomGenerator(split) + for obj in iterator: + if self._lowerBound <= self.getUniformSample() < self._upperBound: + yield obj + + class RDDStratifiedSampler(RDDSamplerBase): def __init__(self, withReplacement, fractions, seed=None): @@ -123,15 +104,16 @@ def __init__(self, withReplacement, fractions, seed=None): self._fractions = fractions def func(self, split, iterator): + self.initRandomGenerator(split) if self._withReplacement: for key, val in iterator: # For large datasets, the expected number of occurrences of each element in # a sample with replacement is Poisson(frac). We use that to get a count for # each element. 
- count = self.getPoissonSample(split, mean=self._fractions[key]) + count = self.getPoissonSample(self._fractions[key]) for _ in range(0, count): yield key, val else: for key, val in iterator: - if self.getUniformSample(split) <= self._fractions[key]: + if self.getUniformSample() < self._fractions[key]: yield key, val diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 2672da36c1f50..33aa55f7f1429 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -33,9 +33,8 @@ [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] >>> sc.stop() -By default, PySpark serialize objects in batches; the batch size can be -controlled through SparkContext's C{batchSize} parameter -(the default size is 1024 objects): +PySpark serialize objects in batches; By default, the batch size is chosen based +on the size of objects, also configurable by SparkContext's C{batchSize} parameter: >>> sc = SparkContext('local', 'test', batchSize=2) >>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) @@ -48,16 +47,6 @@ >>> rdd._jrdd.count() 8L >>> sc.stop() - -A batch size of -1 uses an unlimited batch size, and a size of 1 disables -batching: - ->>> sc = SparkContext('local', 'test', batchSize=1) ->>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) ->>> rdd.glom().collect() -[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] ->>> rdd._jrdd.count() -16L """ import cPickle @@ -73,13 +62,14 @@ from pyspark import cloudpickle -__all__ = ["PickleSerializer", "MarshalSerializer"] +__all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"] class SpecialLengths(object): END_OF_DATA_SECTION = -1 PYTHON_EXCEPTION_THROWN = -2 TIMING_DATA = -3 + END_OF_STREAM = -4 class Serializer(object): @@ -112,7 +102,10 @@ def __ne__(self, other): return not self.__eq__(other) def __repr__(self): - return "<%s object>" % self.__class__.__name__ + return "%s()" % self.__class__.__name__ + + def __hash__(self): + return hash(str(self)) class FramedSerializer(Serializer): @@ -140,6 +133,8 @@ def load_stream(self, stream): def _write_with_length(self, obj, stream): serialized = self.dumps(obj) + if len(serialized) > (1 << 31): + raise ValueError("can not serialize object larger than 2G") write_int(len(serialized), stream) if self._only_write_strings: stream.write(str(serialized)) @@ -177,6 +172,7 @@ class BatchedSerializer(Serializer): """ UNLIMITED_BATCH_SIZE = -1 + UNKNOWN_BATCH_SIZE = 0 def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE): self.serializer = serializer @@ -209,10 +205,10 @@ def _load_stream_without_unbatching(self, stream): def __eq__(self, other): return (isinstance(other, BatchedSerializer) and - other.serializer == self.serializer) + other.serializer == self.serializer and other.batchSize == self.batchSize) - def __str__(self): - return "BatchedSerializer<%s>" % str(self.serializer) + def __repr__(self): + return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize) class AutoBatchedSerializer(BatchedSerializer): @@ -220,8 +216,8 @@ class AutoBatchedSerializer(BatchedSerializer): Choose the size of batch automatically based on the size of object """ - def __init__(self, serializer, bestSize=1 << 20): - BatchedSerializer.__init__(self, serializer, -1) + def __init__(self, serializer, bestSize=1 << 16): + BatchedSerializer.__init__(self, serializer, self.UNKNOWN_BATCH_SIZE) self.bestSize = bestSize def dump_stream(self, iterator, stream): @@ -244,10 +240,10 @@ def dump_stream(self, iterator, stream): def __eq__(self, other): return 
(isinstance(other, AutoBatchedSerializer) and - other.serializer == self.serializer) + other.serializer == self.serializer and other.bestSize == self.bestSize) def __str__(self): - return "BatchedSerializer<%s>" % str(self.serializer) + return "AutoBatchedSerializer(%s)" % str(self.serializer) class CartesianDeserializer(FramedSerializer): @@ -279,8 +275,8 @@ def __eq__(self, other): return (isinstance(other, CartesianDeserializer) and self.key_ser == other.key_ser and self.val_ser == other.val_ser) - def __str__(self): - return "CartesianDeserializer<%s, %s>" % \ + def __repr__(self): + return "CartesianDeserializer(%s, %s)" % \ (str(self.key_ser), str(self.val_ser)) @@ -306,8 +302,8 @@ def __eq__(self, other): return (isinstance(other, PairDeserializer) and self.key_ser == other.key_ser and self.val_ser == other.val_ser) - def __str__(self): - return "PairDeserializer<%s, %s>" % (str(self.key_ser), str(self.val_ser)) + def __repr__(self): + return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser)) class NoOpSerializer(FramedSerializer): @@ -426,7 +422,7 @@ def loads(self, obj): class AutoSerializer(FramedSerializer): """ - Choose marshal or cPickle as serialization protocol autumatically + Choose marshal or cPickle as serialization protocol automatically """ def __init__(self): @@ -456,9 +452,9 @@ class CompressedSerializer(FramedSerializer): """ Compress the serialized data """ - def __init__(self, serializer): FramedSerializer.__init__(self) + assert isinstance(serializer, FramedSerializer), "serializer must be a FramedSerializer" self.serializer = serializer def dumps(self, obj): @@ -523,3 +519,8 @@ def write_int(value, stream): def write_with_length(obj, stream): write_int(len(obj), stream) stream.write(obj) + + +if __name__ == '__main__': + import doctest + doctest.testmod() diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index ce597cbe91e15..10a7ccd502000 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -25,7 +25,7 @@ import random import pyspark.heapq3 as heapq -from pyspark.serializers import BatchedSerializer, PickleSerializer +from pyspark.serializers import AutoBatchedSerializer, PickleSerializer try: import psutil @@ -213,8 +213,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, Merger.__init__(self, aggregator) self.memory_limit = memory_limit # default serializer is only used for tests - self.serializer = serializer or \ - BatchedSerializer(PickleSerializer(), 1024) + self.serializer = serializer or AutoBatchedSerializer(PickleSerializer()) self.localdirs = localdirs or _get_local_dirs(str(id(self))) # number of partitions when spill data into disks self.partitions = partitions @@ -396,7 +395,6 @@ def _external_items(self): for v in self.data.iteritems(): yield v self.data.clear() - gc.collect() # remove the merged partition for j in range(self.spills): @@ -428,7 +426,7 @@ def _recursive_merged_items(self, start): subdirs = [os.path.join(d, "parts", str(i)) for d in self.localdirs] m = ExternalMerger(self.agg, self.memory_limit, self.serializer, - subdirs, self.scale * self.partitions) + subdirs, self.scale * self.partitions, self.partitions) m.pdata = [{} for _ in range(self.partitions)] limit = self._next_limit() @@ -471,7 +469,7 @@ class ExternalSorter(object): def __init__(self, memory_limit, serializer=None): self.memory_limit = memory_limit self.local_dirs = _get_local_dirs("sort") - self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024) + self.serializer = serializer 
or AutoBatchedSerializer(PickleSerializer()) def _get_path(self, n): """ Choose one directory for spill by number n """ @@ -480,13 +478,21 @@ def _get_path(self, n): os.makedirs(d) return os.path.join(d, str(n)) + def _next_limit(self): + """ + Return the next memory limit. If the memory is not released + after spilling, it will dump the data only when the used memory + starts to increase. + """ + return max(self.memory_limit, get_used_memory() * 1.05) + def sorted(self, iterator, key=None, reverse=False): """ Sort the elements in iterator, do external sort when the memory goes above the limit. """ global MemoryBytesSpilled, DiskBytesSpilled - batch = 10 + batch, limit = 100, self._next_limit() chunks, current_chunk = [], [] iterator = iter(iterator) while True: @@ -506,6 +512,7 @@ def sorted(self, iterator, key=None, reverse=False): chunks.append(self.serializer.load_stream(open(path))) current_chunk = [] gc.collect() + limit = self._next_limit() MemoryBytesSpilled += (used_memory - get_used_memory()) << 20 DiskBytesSpilled += os.path.getsize(path) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 974b5e287bc00..ae288471b0e51 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -15,31 +15,43 @@ # limitations under the License. # +""" +public classes of Spark SQL: + + - L{SQLContext} + Main entry point for SQL functionality. + - L{SchemaRDD} + A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In + addition to normal RDD operations, SchemaRDDs also support SQL. + - L{Row} + A Row of data returned by a Spark SQL query. + - L{HiveContext} + Main entry point for accessing data stored in Apache Hive.. +""" -import sys -import types import itertools -import warnings import decimal import datetime import keyword import warnings +import json +import re from array import array from operator import itemgetter +from itertools import imap + +from py4j.protocol import Py4JError +from py4j.java_collections import ListConverter, MapConverter from pyspark.rdd import RDD -from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer +from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ + CloudPickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync -from itertools import chain, ifilter, imap - -from py4j.protocol import Py4JError -from py4j.java_collections import ListConverter, MapConverter - __all__ = [ - "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType", + "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType", "SQLContext", "HiveContext", "SchemaRDD", "Row"] @@ -62,6 +74,18 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + @classmethod + def typeName(cls): + return cls.__name__[:-4].lower() + + def jsonValue(self): + return self.typeName() + + def json(self): + return json.dumps(self.jsonValue(), + separators=(',', ':'), + sort_keys=True) + class PrimitiveTypeSingleton(type): @@ -86,6 +110,15 @@ def __eq__(self, other): return self is other +class NullType(PrimitiveType): + + """Spark SQL NullType + + The data type representing None, used for the types which has not + been inferred. 
+ """ + + class StringType(PrimitiveType): """Spark SQL StringType @@ -110,6 +143,14 @@ class BooleanType(PrimitiveType): """ +class DateType(PrimitiveType): + + """Spark SQL DateType + + The data type representing datetime.date values. + """ + + class TimestampType(PrimitiveType): """Spark SQL TimestampType @@ -118,13 +159,30 @@ class TimestampType(PrimitiveType): """ -class DecimalType(PrimitiveType): +class DecimalType(DataType): """Spark SQL DecimalType The data type representing decimal.Decimal values. """ + def __init__(self, precision=None, scale=None): + self.precision = precision + self.scale = scale + self.hasPrecisionInfo = precision is not None + + def jsonValue(self): + if self.hasPrecisionInfo: + return "decimal(%d,%d)" % (self.precision, self.scale) + else: + return "decimal" + + def __repr__(self): + if self.hasPrecisionInfo: + return "DecimalType(%d,%d)" % (self.precision, self.scale) + else: + return "DecimalType()" + class DoubleType(PrimitiveType): @@ -201,10 +259,20 @@ def __init__(self, elementType, containsNull=True): self.elementType = elementType self.containsNull = containsNull - def __str__(self): + def __repr__(self): return "ArrayType(%s,%s)" % (self.elementType, str(self.containsNull).lower()) + def jsonValue(self): + return {"type": self.typeName(), + "elementType": self.elementType.jsonValue(), + "containsNull": self.containsNull} + + @classmethod + def fromJson(cls, json): + return ArrayType(_parse_datatype_json_value(json["elementType"]), + json["containsNull"]) + class MapType(DataType): @@ -245,6 +313,18 @@ def __repr__(self): return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, str(self.valueContainsNull).lower()) + def jsonValue(self): + return {"type": self.typeName(), + "keyType": self.keyType.jsonValue(), + "valueType": self.valueType.jsonValue(), + "valueContainsNull": self.valueContainsNull} + + @classmethod + def fromJson(cls, json): + return MapType(_parse_datatype_json_value(json["keyType"]), + _parse_datatype_json_value(json["valueType"]), + json["valueContainsNull"]) + class StructField(DataType): @@ -261,12 +341,15 @@ class StructField(DataType): """ - def __init__(self, name, dataType, nullable): + def __init__(self, name, dataType, nullable=True, metadata=None): """Creates a StructField :param name: the name of this field. :param dataType: the data type of this field. :param nullable: indicates whether values of this field can be null. + :param metadata: metadata of this field, which is a map from string + to simple type that can be serialized to JSON + automatically >>> (StructField("f1", StringType, True) ... 
== StructField("f1", StringType, True)) @@ -278,11 +361,25 @@ def __init__(self, name, dataType, nullable): self.name = name self.dataType = dataType self.nullable = nullable + self.metadata = metadata or {} def __repr__(self): return "StructField(%s,%s,%s)" % (self.name, self.dataType, str(self.nullable).lower()) + def jsonValue(self): + return {"name": self.name, + "type": self.dataType.jsonValue(), + "nullable": self.nullable, + "metadata": self.metadata} + + @classmethod + def fromJson(cls, json): + return StructField(json["name"], + _parse_datatype_json_value(json["type"]), + json["nullable"], + json["metadata"]) + class StructType(DataType): @@ -312,42 +409,99 @@ def __repr__(self): return ("StructType(List(%s))" % ",".join(str(field) for field in self.fields)) + def jsonValue(self): + return {"type": self.typeName(), + "fields": [f.jsonValue() for f in self.fields]} -def _parse_datatype_list(datatype_list_string): - """Parses a list of comma separated data types.""" - index = 0 - datatype_list = [] - start = 0 - depth = 0 - while index < len(datatype_list_string): - if depth == 0 and datatype_list_string[index] == ",": - datatype_string = datatype_list_string[start:index].strip() - datatype_list.append(_parse_datatype_string(datatype_string)) - start = index + 1 - elif datatype_list_string[index] == "(": - depth += 1 - elif datatype_list_string[index] == ")": - depth -= 1 + @classmethod + def fromJson(cls, json): + return StructType([StructField.fromJson(f) for f in json["fields"]]) - index += 1 - # Handle the last data type - datatype_string = datatype_list_string[start:index].strip() - datatype_list.append(_parse_datatype_string(datatype_string)) - return datatype_list +class UserDefinedType(DataType): + """ + :: WARN: Spark Internal Use Only :: + SQL User-Defined Type (UDT). + """ + + @classmethod + def typeName(cls): + return cls.__name__.lower() + + @classmethod + def sqlType(cls): + """ + Underlying SQL storage type for this UDT. + """ + raise NotImplementedError("UDT must implement sqlType().") + + @classmethod + def module(cls): + """ + The Python module of the UDT. + """ + raise NotImplementedError("UDT must implement module().") + + @classmethod + def scalaUDT(cls): + """ + The class name of the paired Scala UDT. + """ + raise NotImplementedError("UDT must have a paired Scala UDT.") + + def serialize(self, obj): + """ + Converts the a user-type object into a SQL datum. + """ + raise NotImplementedError("UDT must implement serialize().") + + def deserialize(self, datum): + """ + Converts a SQL datum into a user-type object. 
+ """ + raise NotImplementedError("UDT must implement deserialize().") + + def json(self): + return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True) + + def jsonValue(self): + schema = { + "type": "udt", + "class": self.scalaUDT(), + "pyClass": "%s.%s" % (self.module(), type(self).__name__), + "sqlType": self.sqlType().jsonValue() + } + return schema + + @classmethod + def fromJson(cls, json): + pyUDT = json["pyClass"] + split = pyUDT.rfind(".") + pyModule = pyUDT[:split] + pyClass = pyUDT[split+1:] + m = __import__(pyModule, globals(), locals(), [pyClass], -1) + UDT = getattr(m, pyClass) + return UDT() + + def __eq__(self, other): + return type(self) == type(other) -_all_primitive_types = dict((k, v) for k, v in globals().iteritems() - if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType) +_all_primitive_types = dict((v.typeName(), v) + for v in globals().itervalues() + if type(v) is PrimitiveTypeSingleton and + v.__base__ == PrimitiveType) -def _parse_datatype_string(datatype_string): - """Parses the given data type string. +_all_complex_types = dict((v.typeName(), v) + for v in [ArrayType, MapType, StructType]) + +def _parse_datatype_json_string(json_string): + """Parses the given data type JSON string. >>> def check_datatype(datatype): - ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(str(datatype)) - ... python_datatype = _parse_datatype_string( - ... scala_datatype.toString()) + ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) + ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) ... return datatype == python_datatype >>> all(check_datatype(cls()) for cls in _all_primitive_types.values()) True @@ -372,7 +526,8 @@ def _parse_datatype_string(datatype_string): ... StructField("simpleArray", simple_arraytype, True), ... StructField("simpleMap", simple_maptype, True), ... StructField("simpleStruct", simple_structtype, True), - ... StructField("boolean", BooleanType(), False)]) + ... StructField("boolean", BooleanType(), False), + ... StructField("withMeta", DoubleType(), False, {"name": "age"})]) >>> check_datatype(complex_structtype) True >>> # Complex ArrayType. @@ -384,56 +539,43 @@ def _parse_datatype_string(datatype_string): ... complex_arraytype, False) >>> check_datatype(complex_maptype) True + >>> check_datatype(ExamplePointUDT()) + True + >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> check_datatype(structtype_with_udt) + True """ - index = datatype_string.find("(") - if index == -1: - # It is a primitive type. 
- index = len(datatype_string) - type_or_field = datatype_string[:index] - rest_part = datatype_string[index + 1:len(datatype_string) - 1].strip() - - if type_or_field in _all_primitive_types: - return _all_primitive_types[type_or_field]() - - elif type_or_field == "ArrayType": - last_comma_index = rest_part.rfind(",") - containsNull = True - if rest_part[last_comma_index + 1:].strip().lower() == "false": - containsNull = False - elementType = _parse_datatype_string( - rest_part[:last_comma_index].strip()) - return ArrayType(elementType, containsNull) - - elif type_or_field == "MapType": - last_comma_index = rest_part.rfind(",") - valueContainsNull = True - if rest_part[last_comma_index + 1:].strip().lower() == "false": - valueContainsNull = False - keyType, valueType = _parse_datatype_list( - rest_part[:last_comma_index].strip()) - return MapType(keyType, valueType, valueContainsNull) - - elif type_or_field == "StructField": - first_comma_index = rest_part.find(",") - name = rest_part[:first_comma_index].strip() - last_comma_index = rest_part.rfind(",") - nullable = True - if rest_part[last_comma_index + 1:].strip().lower() == "false": - nullable = False - dataType = _parse_datatype_string( - rest_part[first_comma_index + 1:last_comma_index].strip()) - return StructField(name, dataType, nullable) - - elif type_or_field == "StructType": - # rest_part should be in the format like - # List(StructField(field1,IntegerType,false)). - field_list_string = rest_part[rest_part.find("(") + 1:-1] - fields = _parse_datatype_list(field_list_string) - return StructType(fields) + return _parse_datatype_json_value(json.loads(json_string)) + + +_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)") -# Mapping Python types to Spark SQL DateType +def _parse_datatype_json_value(json_value): + if type(json_value) is unicode: + if json_value in _all_primitive_types.keys(): + return _all_primitive_types[json_value]() + elif json_value == u'decimal': + return DecimalType() + elif _FIXED_DECIMAL.match(json_value): + m = _FIXED_DECIMAL.match(json_value) + return DecimalType(int(m.group(1)), int(m.group(2))) + else: + raise ValueError("Could not parse datatype: %s" % json_value) + else: + tpe = json_value["type"] + if tpe in _all_complex_types: + return _all_complex_types[tpe].fromJson(json_value) + elif tpe == 'udt': + return UserDefinedType.fromJson(json_value) + else: + raise ValueError("not supported type: %s" % tpe) + + +# Mapping Python types to Spark SQL DataType _type_mappings = { + type(None): NullType, bool: BooleanType, int: IntegerType, long: LongType, @@ -442,30 +584,41 @@ def _parse_datatype_string(datatype_string): unicode: StringType, bytearray: BinaryType, decimal.Decimal: DecimalType, + datetime.date: DateType, datetime.datetime: TimestampType, - datetime.date: TimestampType, datetime.time: TimestampType, } def _infer_type(obj): - """Infer the DataType from obj""" + """Infer the DataType from obj + + >>> p = ExamplePoint(1.0, 2.0) + >>> _infer_type(p) + ExamplePointUDT + """ if obj is None: raise ValueError("Can not infer type for None") + if hasattr(obj, '__UDT__'): + return obj.__UDT__ + dataType = _type_mappings.get(type(obj)) if dataType is not None: return dataType() if isinstance(obj, dict): - if not obj: - raise ValueError("Can not infer type for empty dict") - key, value = obj.iteritems().next() - return MapType(_infer_type(key), _infer_type(value), True) + for key, value in obj.iteritems(): + if key is not None and value is not None: + return MapType(_infer_type(key), 
_infer_type(value), True) + else: + return MapType(NullType(), NullType(), True) elif isinstance(obj, (list, array)): - if not obj: - raise ValueError("Can not infer type for empty list/array") - return ArrayType(_infer_type(obj[0]), True) + for v in obj: + if v is not None: + return ArrayType(_infer_type(obj[0]), True) + else: + return ArrayType(NullType(), True) else: try: return _infer_schema(obj) @@ -498,60 +651,180 @@ def _infer_schema(row): return StructType(fields) -def _create_converter(obj, dataType): +def _need_python_to_sql_conversion(dataType): + """ + Checks whether we need python to sql conversion for the given type. + For now, only UDTs need this conversion. + + >>> _need_python_to_sql_conversion(DoubleType()) + False + >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), + ... StructField("values", ArrayType(DoubleType(), False), False)]) + >>> _need_python_to_sql_conversion(schema0) + False + >>> _need_python_to_sql_conversion(ExamplePointUDT()) + True + >>> schema1 = ArrayType(ExamplePointUDT(), False) + >>> _need_python_to_sql_conversion(schema1) + True + >>> schema2 = StructType([StructField("label", DoubleType(), False), + ... StructField("point", ExamplePointUDT(), False)]) + >>> _need_python_to_sql_conversion(schema2) + True + """ + if isinstance(dataType, StructType): + return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) + elif isinstance(dataType, ArrayType): + return _need_python_to_sql_conversion(dataType.elementType) + elif isinstance(dataType, MapType): + return _need_python_to_sql_conversion(dataType.keyType) or \ + _need_python_to_sql_conversion(dataType.valueType) + elif isinstance(dataType, UserDefinedType): + return True + else: + return False + + +def _python_to_sql_converter(dataType): + """ + Returns a converter that converts a Python object into a SQL datum for the given type. + + >>> conv = _python_to_sql_converter(DoubleType()) + >>> conv(1.0) + 1.0 + >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False)) + >>> conv([1.0, 2.0]) + [1.0, 2.0] + >>> conv = _python_to_sql_converter(ExamplePointUDT()) + >>> conv(ExamplePoint(1.0, 2.0)) + [1.0, 2.0] + >>> schema = StructType([StructField("label", DoubleType(), False), + ... 
StructField("point", ExamplePointUDT(), False)]) + >>> conv = _python_to_sql_converter(schema) + >>> conv((1.0, ExamplePoint(1.0, 2.0))) + (1.0, [1.0, 2.0]) + """ + if not _need_python_to_sql_conversion(dataType): + return lambda x: x + + if isinstance(dataType, StructType): + names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) + converters = map(_python_to_sql_converter, types) + + def converter(obj): + if isinstance(obj, dict): + return tuple(c(obj.get(n)) for n, c in zip(names, converters)) + elif isinstance(obj, tuple): + if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"): + return tuple(c(v) for c, v in zip(converters, obj)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs + d = dict(obj) + return tuple(c(d.get(n)) for n, c in zip(names, converters)) + else: + return tuple(c(v) for c, v in zip(converters, obj)) + else: + raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) + return converter + elif isinstance(dataType, ArrayType): + element_converter = _python_to_sql_converter(dataType.elementType) + return lambda a: [element_converter(v) for v in a] + elif isinstance(dataType, MapType): + key_converter = _python_to_sql_converter(dataType.keyType) + value_converter = _python_to_sql_converter(dataType.valueType) + return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) + elif isinstance(dataType, UserDefinedType): + return lambda obj: dataType.serialize(obj) + else: + raise ValueError("Unexpected type %r" % dataType) + + +def _has_nulltype(dt): + """ Return whether there is NullType in `dt` or not """ + if isinstance(dt, StructType): + return any(_has_nulltype(f.dataType) for f in dt.fields) + elif isinstance(dt, ArrayType): + return _has_nulltype((dt.elementType)) + elif isinstance(dt, MapType): + return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType) + else: + return isinstance(dt, NullType) + + +def _merge_type(a, b): + if isinstance(a, NullType): + return b + elif isinstance(b, NullType): + return a + elif type(a) is not type(b): + # TODO: type cast (such as int -> long) + raise TypeError("Can not merge type %s and %s" % (a, b)) + + # same type + if isinstance(a, StructType): + nfs = dict((f.name, f.dataType) for f in b.fields) + fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType()))) + for f in a.fields] + names = set([f.name for f in fields]) + for n in nfs: + if n not in names: + fields.append(StructField(n, nfs[n])) + return StructType(fields) + + elif isinstance(a, ArrayType): + return ArrayType(_merge_type(a.elementType, b.elementType), True) + + elif isinstance(a, MapType): + return MapType(_merge_type(a.keyType, b.keyType), + _merge_type(a.valueType, b.valueType), + True) + else: + return a + + +def _create_converter(dataType): """Create an converter to drop the names of fields in obj """ if isinstance(dataType, ArrayType): - conv = _create_converter(obj[0], dataType.elementType) + conv = _create_converter(dataType.elementType) return lambda row: map(conv, row) elif isinstance(dataType, MapType): - value = obj.values()[0] - conv = _create_converter(value, dataType.valueType) + conv = _create_converter(dataType.valueType) return lambda row: dict((k, conv(v)) for k, v in row.iteritems()) + elif isinstance(dataType, NullType): + return lambda x: None + elif not isinstance(dataType, StructType): return lambda x: x # dataType must be StructType names = [f.name for f in dataType.fields] + converters = [_create_converter(f.dataType) for f in 
dataType.fields] + + def convert_struct(obj): + if obj is None: + return + + if isinstance(obj, tuple): + if hasattr(obj, "fields"): + d = dict(zip(obj.fields, obj)) + if hasattr(obj, "__FIELDS__"): + d = dict(zip(obj.__FIELDS__, obj)) + elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): + d = dict(obj) + else: + raise ValueError("unexpected tuple: %s" % obj) - if isinstance(obj, dict): - conv = lambda o: tuple(o.get(n) for n in names) - - elif isinstance(obj, tuple): - if hasattr(obj, "_fields"): # namedtuple - conv = tuple - elif hasattr(obj, "__FIELDS__"): - conv = tuple - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): - conv = lambda o: tuple(v for k, v in o) + elif isinstance(obj, dict): + d = obj + elif hasattr(obj, "__dict__"): # object + d = obj.__dict__ else: - raise ValueError("unexpected tuple") + raise ValueError("Unexpected obj: %s" % obj) - elif hasattr(obj, "__dict__"): # object - conv = lambda o: [o.__dict__.get(n, None) for n in names] + return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) - if all(isinstance(f.dataType, PrimitiveType) for f in dataType.fields): - return conv - - row = conv(obj) - convs = [_create_converter(v, f.dataType) - for v, f in zip(row, dataType.fields)] - - def nested_conv(row): - return tuple(f(v) for f, v in zip(convs, conv(row))) - - return nested_conv - - -def _drop_schema(rows, schema): - """ all the names of fields, becoming tuples""" - iterator = iter(rows) - row = iterator.next() - converter = _create_converter(row, schema) - yield converter(row) - for i in iterator: - yield converter(i) + return convert_struct _BRACKETS = {'(': ')', '[': ']', '{': '}'} @@ -650,10 +923,10 @@ def _infer_schema_type(obj, dataType): """ Fill the dataType with types infered from obj - >>> schema = _parse_schema_abstract("a b c") - >>> row = (1, 1.0, "str") + >>> schema = _parse_schema_abstract("a b c d") + >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10)) >>> _infer_schema_type(row, schema) - StructType...IntegerType...DoubleType...StringType... + StructType...IntegerType...DoubleType...StringType...DateType... >>> row = [[1], {"key": (1, 2.0)}] >>> schema = _parse_schema_abstract("a[] b{c d}") >>> _infer_schema_type(row, schema) @@ -663,7 +936,7 @@ def _infer_schema_type(obj, dataType): return _infer_type(obj) if not obj: - raise ValueError("Can not infer type from empty value") + return NullType() if isinstance(dataType, ArrayType): eType = _infer_schema_type(obj[0], dataType.elementType) @@ -697,6 +970,7 @@ def _infer_schema_type(obj, dataType): DecimalType: (decimal.Decimal,), StringType: (str, unicode), BinaryType: (bytearray,), + DateType: (datetime.date,), TimestampType: (datetime.datetime,), ArrayType: (list, tuple, array), MapType: (dict,), @@ -724,17 +998,28 @@ def _verify_type(obj, dataType): Traceback (most recent call last): ... ValueError:... + >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT()) + >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... 
""" # all objects are nullable if obj is None: return + if isinstance(dataType, UserDefinedType): + if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): + raise ValueError("%r is not an instance of type %r" % (obj, dataType)) + _verify_type(dataType.serialize(obj), dataType.sqlType()) + return + _type = type(dataType) assert _type in _acceptable_types, "unkown datatype: %s" % dataType # subclass of them can not be deserialized in JVM if type(obj) not in _acceptable_types[_type]: - raise TypeError("%s can not accept abject in type %s" + raise TypeError("%s can not accept object in type %s" % (dataType, type(obj))) if isinstance(dataType, ArrayType): @@ -761,7 +1046,7 @@ def _restore_object(dataType, obj): """ Restore object during unpickling. """ # use id(dataType) as key to speed up lookup in dict # Because of batched pickling, dataType will be the - # same object in mose cases. + # same object in most cases. k = id(dataType) cls = _cached_cls.get(k) if cls is None: @@ -776,6 +1061,10 @@ def _restore_object(dataType, obj): def _create_object(cls, v): """ Create an customized object with class `cls`. """ + # datetime.date would be deserialized as datetime.datetime + # from java type, so we need to set it back. + if cls is datetime.date and isinstance(v, datetime.datetime): + return v.date() return cls(v) if v is not None else v @@ -789,14 +1078,18 @@ def getter(self): return getter -def _has_struct(dt): - """Return whether `dt` is or has StructType in it""" +def _has_struct_or_date(dt): + """Return whether `dt` is or has StructType/DateType in it""" if isinstance(dt, StructType): return True elif isinstance(dt, ArrayType): - return _has_struct(dt.elementType) + return _has_struct_or_date(dt.elementType) elif isinstance(dt, MapType): - return _has_struct(dt.valueType) + return _has_struct_or_date(dt.valueType) + elif isinstance(dt, DateType): + return True + elif isinstance(dt, UserDefinedType): + return True return False @@ -809,7 +1102,7 @@ def _create_properties(fields): or keyword.iskeyword(name)): warnings.warn("field name %s can not be accessed in Python," "use position to access it instead" % name) - if _has_struct(f.dataType): + if _has_struct_or_date(f.dataType): # delay creating object until accessing it getter = _create_getter(f.dataType, i) else: @@ -864,6 +1157,12 @@ def Dict(d): return Dict + elif isinstance(dataType, DateType): + return datetime.date + + elif isinstance(dataType, UserDefinedType): + return lambda datum: dataType.deserialize(datum) + elif not isinstance(dataType, StructType): raise Exception("unexpected data type: %s" % dataType) @@ -877,6 +1176,10 @@ class Row(tuple): # create property for fast access locals().update(_create_properties(dataType.fields)) + def asDict(self): + """ Return as a dict """ + return dict((n, getattr(self, n)) for n in self.__FIELDS__) + def __repr__(self): # call collect __repr__ for nested objects return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) @@ -899,8 +1202,8 @@ class SQLContext(object): def __init__(self, sparkContext, sqlContext=None): """Create a new SQLContext. - @param sparkContext: The SparkContext to wrap. - @param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new + :param sparkContext: The SparkContext to wrap. + :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new SQLContext in the JVM, instead we make all calls to this object. 
>>> srdd = sqlCtx.inferSchema(rdd) @@ -931,7 +1234,6 @@ def __init__(self, sparkContext, sqlContext=None): self._sc = sparkContext self._jsc = self._sc._jsc self._jvm = self._sc._jvm - self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray self._scala_SQLContext = sqlContext @property @@ -961,8 +1263,8 @@ def registerFunction(self, name, f, returnType=StringType()): """ func = lambda _, it: imap(lambda x: f(*x), it) command = (func, None, - BatchedSerializer(PickleSerializer(), 1024), - BatchedSerializer(PickleSerializer(), 1024)) + AutoBatchedSerializer(PickleSerializer()), + AutoBatchedSerializer(PickleSerializer())) ser = CloudPickleSerializer() pickled_command = ser.dumps(command) if len(pickled_command) > (1 << 20): # 1M @@ -983,20 +1285,22 @@ def registerFunction(self, name, f, returnType=StringType()): self._sc.pythonExec, broadcast_vars, self._sc._javaAccumulator, - str(returnType)) + returnType.json()) - def inferSchema(self, rdd): + def inferSchema(self, rdd, samplingRatio=None): """Infer and apply a schema to an RDD of L{Row}. - We peek at the first row of the RDD to determine the fields' names - and types. Nested collections are supported, which include array, - dict, list, Row, tuple, namedtuple, or object. + When samplingRatio is specified, the schema is inferred by looking + at the types of each row in the sampled dataset. Otherwise, the + first 100 rows of the RDD are inspected. Nested collections are + supported, which can include array, dict, list, Row, tuple, + namedtuple, or object. - All the rows in `rdd` should have the same type with the first one, - or it will cause runtime exceptions. + Each row could be L{pyspark.sql.Row} object or namedtuple or objects. + Using top level dicts is deprecated, as dict is used to represent Maps. - Each row could be L{pyspark.sql.Row} object or namedtuple or objects, - using dict is deprecated. + If a single column has multiple distinct inferred types, it may cause + runtime exceptions. >>> rdd = sc.parallelize( ... [Row(field1=1, field2="row1"), @@ -1033,8 +1337,23 @@ def inferSchema(self, rdd): warnings.warn("Using RDD of dict to inferSchema is deprecated," "please use pyspark.sql.Row instead") - schema = _infer_schema(first) - rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema)) + if samplingRatio is None: + schema = _infer_schema(first) + if _has_nulltype(schema): + for row in rdd.take(100)[1:]: + schema = _merge_type(schema, _infer_schema(row)) + if not _has_nulltype(schema): + break + else: + warnings.warn("Some of types cannot be determined by the " + "first 100 rows, please try again with sampling") + else: + if samplingRatio > 0.99: + rdd = rdd.sample(False, float(samplingRatio)) + schema = rdd.map(_infer_schema).reduce(_merge_type) + + converter = _create_converter(schema) + rdd = rdd.map(converter) return self.applySchema(rdd, schema) def applySchema(self, rdd, schema): @@ -1058,8 +1377,9 @@ def applySchema(self, rdd, schema): >>> srdd2.collect() [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] - >>> from datetime import datetime + >>> from datetime import date, datetime >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0, + ... date(2010, 1, 1), ... datetime(2010, 1, 1, 1, 1, 1), ... {"a": 1}, (2,), [1, 2, 3], None)]) >>> schema = StructType([ @@ -1069,6 +1389,7 @@ def applySchema(self, rdd, schema): ... StructField("short2", ShortType(), False), ... StructField("int", IntegerType(), False), ... StructField("float", FloatType(), False), + ... 
StructField("date", DateType(), False), ... StructField("time", TimestampType(), False), ... StructField("map", ... MapType(StringType(), IntegerType(), False), False), @@ -1078,10 +1399,11 @@ def applySchema(self, rdd, schema): ... StructField("null", DoubleType(), True)]) >>> srdd = sqlCtx.applySchema(rdd, schema) >>> results = srdd.map( - ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.time, - ... x.map["a"], x.struct.b, x.list, x.null)) - >>> results.collect()[0] - (127, -128, -32768, 32767, 2147483647, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date, + ... x.time, x.map["a"], x.struct.b, x.list, x.null)) + >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE + (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1), + datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) >>> srdd.registerTempTable("table2") >>> sqlCtx.sql( @@ -1117,9 +1439,12 @@ def applySchema(self, rdd, schema): for row in rows: _verify_type(row, schema) - batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer) - jrdd = self._pythonToJava(rdd._jrdd, batched) - srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema)) + # convert python objects to sql data + converter = _python_to_sql_converter(schema) + rdd = rdd.map(converter) + + jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) + srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) return SchemaRDD(srdd.toJavaSchemaRDD(), self) def registerRDDAsTable(self, rdd, tableName): @@ -1152,7 +1477,7 @@ def parquetFile(self, path): jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD() return SchemaRDD(jschema_rdd, self) - def jsonFile(self, path, schema=None): + def jsonFile(self, path, schema=None, samplingRatio=1.0): """ Loads a text file storing one JSON object per line as a L{SchemaRDD}. @@ -1160,8 +1485,8 @@ def jsonFile(self, path, schema=None): If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it goes through the entire dataset once to determine - the schema. + Otherwise, it samples the dataset with ratio `samplingRatio` to + determine the schema. >>> import tempfile, shutil >>> jsonFile = tempfile.mkdtemp() @@ -1207,20 +1532,20 @@ def jsonFile(self, path, schema=None): [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] """ if schema is None: - srdd = self._ssql_ctx.jsonFile(path) + srdd = self._ssql_ctx.jsonFile(path, samplingRatio) else: - scala_datatype = self._ssql_ctx.parseDataType(str(schema)) + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonFile(path, scala_datatype) return SchemaRDD(srdd.toJavaSchemaRDD(), self) - def jsonRDD(self, rdd, schema=None): + def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it goes through the entire dataset once to determine - the schema. + Otherwise, it samples the dataset with ratio `samplingRatio` to + determine the schema. 
>>> srdd1 = sqlCtx.jsonRDD(json) >>> sqlCtx.registerRDDAsTable(srdd1, "table1") @@ -1277,9 +1602,9 @@ def func(iterator): keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) if schema is None: - srdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) + srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) else: - scala_datatype = self._ssql_ctx.parseDataType(str(schema)) + scala_datatype = self._ssql_ctx.parseDataType(schema.json()) srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) return SchemaRDD(srdd.toJavaSchemaRDD(), self) @@ -1325,8 +1650,8 @@ class HiveContext(SQLContext): def __init__(self, sparkContext, hiveContext=None): """Create a new HiveContext. - @param sparkContext: The SparkContext to wrap. - @param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new + :param sparkContext: The SparkContext to wrap. + :param hiveContext: An optional JVM Scala HiveContext. If set, we do not instatiate a new HiveContext in the JVM, instead we make all calls to this object. """ SQLContext.__init__(self, sparkContext) @@ -1369,33 +1694,6 @@ def hql(self, hqlQuery): class LocalHiveContext(HiveContext): - """Starts up an instance of hive where metadata is stored locally. - - An in-process metadata data is created with data stored in ./metadata. - Warehouse data is stored in in ./warehouse. - - >>> import os - >>> hiveCtx = LocalHiveContext(sc) - >>> try: - ... supress = hiveCtx.sql("DROP TABLE src") - ... except Exception: - ... pass - >>> kv1 = os.path.join(os.environ["SPARK_HOME"], - ... 'examples/src/main/resources/kv1.txt') - >>> supress = hiveCtx.sql( - ... "CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - >>> supress = hiveCtx.sql("LOAD DATA LOCAL INPATH '%s' INTO TABLE src" - ... % kv1) - >>> results = hiveCtx.sql("FROM src SELECT value" - ... ).map(lambda r: int(r.value.split('_')[1])) - >>> num = results.count() - >>> reduce_sum = results.reduce(lambda x, y: x + y) - >>> num - 500 - >>> reduce_sum - 130091 - """ - def __init__(self, sparkContext, sqlContext=None): HiveContext.__init__(self, sparkContext, sqlContext) warnings.warn("LocalHiveContext is deprecated. " @@ -1460,6 +1758,14 @@ def __new__(self, *args, **kwargs): else: raise ValueError("No args or kwargs") + def asDict(self): + """ + Return as an dict + """ + if not hasattr(self, "__FIELDS__"): + raise TypeError("Cannot convert a Row class into dict") + return dict(zip(self.__FIELDS__, self)) + # let obect acs like class def __call__(self, *args): """create new Row object""" @@ -1534,7 +1840,7 @@ def __init__(self, jschema_rdd, sql_ctx): self.is_checkpointed = False self.ctx = self.sql_ctx._sc # the _jrdd is created by javaToPython(), serialized by pickle - self._jrdd_deserializer = BatchedSerializer(PickleSerializer()) + self._jrdd_deserializer = AutoBatchedSerializer(PickleSerializer()) @property def _jrdd(self): @@ -1564,6 +1870,21 @@ def limit(self, num): rdd = self._jschema_rdd.baseSchemaRDD().limit(num).toJavaSchemaRDD() return SchemaRDD(rdd, self.sql_ctx) + def toJSON(self, use_unicode=False): + """Convert a SchemaRDD into a MappedRDD of JSON documents; one document per row. 
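jsonRDD gains the same samplingRatio knob as jsonFile; a sketch, with json being the RDD-of-strings fixture used in these doctests:

srdd = sqlCtx.jsonRDD(json, samplingRatio=0.9)   # sample ~90% of the strings to infer the schema
srdd.schema()                                    # the inferred StructType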
+ + >>> srdd1 = sqlCtx.jsonRDD(json) + >>> sqlCtx.registerRDDAsTable(srdd1, "table1") + >>> srdd2 = sqlCtx.sql( "SELECT * from table1") + >>> srdd2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}' + True + >>> srdd3 = sqlCtx.sql( "SELECT field3.field4 from table1") + >>> srdd3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}'] + True + """ + rdd = self._jschema_rdd.baseSchemaRDD().toJSON() + return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) + def saveAsParquetFile(self, path): """Save the contents as a Parquet file, preserving the schema. @@ -1614,7 +1935,7 @@ def saveAsTable(self, tableName): def schema(self): """Returns the schema of this SchemaRDD (represented by a L{StructType}).""" - return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString()) + return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json()) def schemaString(self): """Returns the output schema in the tree format.""" @@ -1764,15 +2085,13 @@ def subtract(self, other, numPartitions=None): def _test(): import doctest - from array import array from pyspark.context import SparkContext # let doctest run in pyspark.sql, so DataTypes can be picklable import pyspark.sql from pyspark.sql import Row, SQLContext + from pyspark.tests import ExamplePoint, ExamplePointUDT globs = pyspark.sql.__dict__.copy() - # The small batch size here ensures that we see multiple batches, - # even in these small test examples: - sc = SparkContext('local[4]', 'PythonTest', batchSize=2) + sc = SparkContext('local[4]', 'PythonTest') globs['sc'] = sc globs['sqlCtx'] = SQLContext(sc) globs['rdd'] = sc.parallelize( @@ -1780,6 +2099,8 @@ def _test(): Row(field1=2, field2="row2"), Row(field1=3, field2="row3")] ) + globs['ExamplePoint'] = ExamplePoint + globs['ExamplePointUDT'] = ExamplePointUDT jsonStrings = [ '{"field1": 1, "field2": "row1", "field3":{"field4":11}}', '{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},' diff --git a/python/pyspark/streaming/__init__.py b/python/pyspark/streaming/__init__.py new file mode 100644 index 0000000000000..d2644a1d4ffab --- /dev/null +++ b/python/pyspark/streaming/__init__.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.streaming.context import StreamingContext +from pyspark.streaming.dstream import DStream + +__all__ = ['StreamingContext', 'DStream'] diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py new file mode 100644 index 0000000000000..d48f3598e33b2 --- /dev/null +++ b/python/pyspark/streaming/context.py @@ -0,0 +1,325 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import sys + +from py4j.java_collections import ListConverter +from py4j.java_gateway import java_import, JavaObject + +from pyspark import RDD, SparkConf +from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer +from pyspark.context import SparkContext +from pyspark.storagelevel import StorageLevel +from pyspark.streaming.dstream import DStream +from pyspark.streaming.util import TransformFunction, TransformFunctionSerializer + +__all__ = ["StreamingContext"] + + +def _daemonize_callback_server(): + """ + Hack Py4J to daemonize callback server + + The thread of callback server has daemon=False, it will block the driver + from exiting if it's not shutdown. The following code replace `start()` + of CallbackServer with a new version, which set daemon=True for this + thread. + + Also, it will update the port number (0) with real port + """ + # TODO: create a patch for Py4J + import socket + import py4j.java_gateway + logger = py4j.java_gateway.logger + from py4j.java_gateway import Py4JNetworkError + from threading import Thread + + def start(self): + """Starts the CallbackServer. This method should be called by the + client instead of run().""" + self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, + 1) + try: + self.server_socket.bind((self.address, self.port)) + if not self.port: + # update port with real port + self.port = self.server_socket.getsockname()[1] + except Exception as e: + msg = 'An error occurred while trying to start the callback server: %s' % e + logger.exception(msg) + raise Py4JNetworkError(msg) + + # Maybe thread needs to be cleanup up? + self.thread = Thread(target=self.run) + self.thread.daemon = True + self.thread.start() + + py4j.java_gateway.CallbackServer.start = start + + +class StreamingContext(object): + """ + Main entry point for Spark Streaming functionality. A StreamingContext + represents the connection to a Spark cluster, and can be used to create + L{DStream} various input sources. It can be from an existing L{SparkContext}. + After creating and transforming DStreams, the streaming computation can + be started and stopped using `context.start()` and `context.stop()`, + respectively. `context.awaitTermination()` allows the current thread + to wait for the termination of the context by `stop()` or by an exception. + """ + _transformerSerializer = None + + def __init__(self, sparkContext, batchDuration=None, jssc=None): + """ + Create a new StreamingContext. + + @param sparkContext: L{SparkContext} object. 
+ @param batchDuration: the time interval (in seconds) at which streaming + data will be divided into batches + """ + + self._sc = sparkContext + self._jvm = self._sc._jvm + self._jssc = jssc or self._initialize_context(self._sc, batchDuration) + + def _initialize_context(self, sc, duration): + self._ensure_initialized() + return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration)) + + def _jduration(self, seconds): + """ + Create Duration object given number of seconds + """ + return self._jvm.Duration(int(seconds * 1000)) + + @classmethod + def _ensure_initialized(cls): + SparkContext._ensure_initialized() + gw = SparkContext._gateway + + java_import(gw.jvm, "org.apache.spark.streaming.*") + java_import(gw.jvm, "org.apache.spark.streaming.api.java.*") + java_import(gw.jvm, "org.apache.spark.streaming.api.python.*") + + # start callback server + # getattr will fallback to JVM, so we cannot test by hasattr() + if "_callback_server" not in gw.__dict__: + _daemonize_callback_server() + # use random port + gw._start_callback_server(0) + # gateway with real port + gw._python_proxy_port = gw._callback_server.port + # get the GatewayServer object in JVM by ID + jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) + # update the port of CallbackClient with real port + gw.jvm.PythonDStream.updatePythonGatewayPort(jgws, gw._python_proxy_port) + + # register serializer for TransformFunction + # it happens before creating SparkContext when loading from checkpointing + cls._transformerSerializer = TransformFunctionSerializer( + SparkContext._active_spark_context, CloudPickleSerializer(), gw) + + @classmethod + def getOrCreate(cls, checkpointPath, setupFunc): + """ + Either recreate a StreamingContext from checkpoint data or create a new StreamingContext. + If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be + recreated from the checkpoint data. If the data does not exist, then the provided setupFunc + will be used to create a JavaStreamingContext. + + @param checkpointPath: Checkpoint directory used in an earlier JavaStreamingContext program + @param setupFunc: Function to create a new JavaStreamingContext and setup DStreams + """ + # TODO: support checkpoint in HDFS + if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath): + ssc = setupFunc() + ssc.checkpoint(checkpointPath) + return ssc + + cls._ensure_initialized() + gw = SparkContext._gateway + + try: + jssc = gw.jvm.JavaStreamingContext(checkpointPath) + except Exception: + print >>sys.stderr, "failed to load StreamingContext from checkpoint" + raise + + jsc = jssc.sparkContext() + conf = SparkConf(_jconf=jsc.getConf()) + sc = SparkContext(conf=conf, gateway=gw, jsc=jsc) + # update ctx in serializer + SparkContext._active_spark_context = sc + cls._transformerSerializer.ctx = sc + return StreamingContext(sc, None, jssc) + + @property + def sparkContext(self): + """ + Return SparkContext which is associated with this StreamingContext. + """ + return self._sc + + def start(self): + """ + Start the execution of the streams. + """ + self._jssc.start() + + def awaitTermination(self, timeout=None): + """ + Wait for the execution to stop. + @param timeout: time to wait in seconds + """ + if timeout is None: + self._jssc.awaitTermination() + else: + self._jssc.awaitTermination(int(timeout * 1000)) + + def stop(self, stopSparkContext=True, stopGraceFully=False): + """ + Stop the execution of the streams, with option of ensuring all + received data has been processed. 
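Putting the lifecycle methods above together with the socket input defined just below, a minimal driver sketch (local master; the 1-second batch interval and the port are arbitrary):

from pyspark import SparkContext
from pyspark.streaming import StreamingContext

sc = SparkContext("local[2]", "StreamingSketch")   # at least 2 cores: one for the receiver
ssc = StreamingContext(sc, 1)                      # 1-second batch interval
lines = ssc.socketTextStream("localhost", 9999)    # any input DStream would do
lines.pprint()
ssc.start()
ssc.awaitTermination(30)                           # or block forever and call ssc.stop() elsewhere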
+ + @param stopSparkContext: Stop the associated SparkContext or not + @param stopGracefully: Stop gracefully by waiting for the processing + of all received data to be completed + """ + self._jssc.stop(stopSparkContext, stopGraceFully) + if stopSparkContext: + self._sc.stop() + + def remember(self, duration): + """ + Set each DStream in this context to remember RDDs it generated + in the last given duration. DStreams remember RDDs only for a + limited duration of time and release them for garbage collection. + This method allows the developer to specify how long to remember + the RDDs (if the developer wishes to query old data outside the + DStream computation). + + @param duration: Minimum duration (in seconds) that each DStream + should remember its RDDs + """ + self._jssc.remember(self._jduration(duration)) + + def checkpoint(self, directory): + """ + Sets the context to periodically checkpoint the DStream operations for master + fault-tolerance. The graph will be checkpointed every batch interval. + + @param directory: HDFS-compatible directory where the checkpoint data + will be reliably stored + """ + self._jssc.checkpoint(directory) + + def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): + """ + Create an input stream from a TCP source hostname:port. Data is received using + a TCP socket and the received bytes are interpreted as UTF8 encoded ``\\n`` delimited + lines. + + @param hostname: Hostname to connect to for receiving data + @param port: Port to connect to for receiving data + @param storageLevel: Storage level to use for storing the received objects + """ + jlevel = self._sc._getJavaStorageLevel(storageLevel) + return DStream(self._jssc.socketTextStream(hostname, port, jlevel), self, + UTF8Deserializer()) + + def textFileStream(self, directory): + """ + Create an input stream that monitors a Hadoop-compatible file system + for new files and reads them as text files. Files must be written to the + monitored directory by "moving" them from another location within the same + file system. File names starting with . are ignored. + """ + return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer()) + + def _check_serializers(self, rdds): + # make sure they have the same serializer + if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1: + for i in range(len(rdds)): + # reset them to sc.serializer + rdds[i] = rdds[i]._reserialize() + + def queueStream(self, rdds, oneAtATime=True, default=None): + """ + Create an input stream from a queue of RDDs or lists. In each batch, + it will process either one or all of the RDDs returned by the queue. + + NOTE: changes to the queue after the stream is created will not be recognized. + + @param rdds: Queue of RDDs + @param oneAtATime: pick one RDD each time or pick all of them once. 
+ @param default: The default rdd if no more in rdds + """ + if default and not isinstance(default, RDD): + default = self._sc.parallelize(default) + + if not rdds and default: + rdds = [rdds] + + if rdds and not isinstance(rdds[0], RDD): + rdds = [self._sc.parallelize(input) for input in rdds] + self._check_serializers(rdds) + + jrdds = ListConverter().convert([r._jrdd for r in rdds], + SparkContext._gateway._gateway_client) + queue = self._jvm.PythonDStream.toRDDQueue(jrdds) + if default: + default = default._reserialize(rdds[0]._jrdd_deserializer) + jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd) + else: + jdstream = self._jssc.queueStream(queue, oneAtATime) + return DStream(jdstream, self, rdds[0]._jrdd_deserializer) + + def transform(self, dstreams, transformFunc): + """ + Create a new DStream in which each RDD is generated by applying + a function on RDDs of the DStreams. The order of the JavaRDDs in + the transform function parameter will be the same as the order + of corresponding DStreams in the list. + """ + jdstreams = ListConverter().convert([d._jdstream for d in dstreams], + SparkContext._gateway._gateway_client) + # change the final serializer to sc.serializer + func = TransformFunction(self._sc, + lambda t, *rdds: transformFunc(rdds).map(lambda x: x), + *[d._jrdd_deserializer for d in dstreams]) + jfunc = self._jvm.TransformFunction(func) + jdstream = self._jssc.transform(jdstreams, jfunc) + return DStream(jdstream, self, self._sc.serializer) + + def union(self, *dstreams): + """ + Create a unified DStream from multiple DStreams of the same + type and same slide duration. + """ + if not dstreams: + raise ValueError("should have at least one DStream to union") + if len(dstreams) == 1: + return dstreams[0] + if len(set(s._jrdd_deserializer for s in dstreams)) > 1: + raise ValueError("All DStreams should have same serializer") + if len(set(s._slideDuration for s in dstreams)) > 1: + raise ValueError("All DStreams should have same slide duration") + first = dstreams[0] + jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]], + SparkContext._gateway._gateway_client) + return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py new file mode 100644 index 0000000000000..0826ddc56e844 --- /dev/null +++ b/python/pyspark/streaming/dstream.py @@ -0,0 +1,623 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from itertools import chain, ifilter, imap +import operator +import time +from datetime import datetime + +from py4j.protocol import Py4JJavaError + +from pyspark import RDD +from pyspark.storagelevel import StorageLevel +from pyspark.streaming.util import rddToFileName, TransformFunction +from pyspark.rdd import portable_hash +from pyspark.resultiterable import ResultIterable + +__all__ = ["DStream"] + + +class DStream(object): + """ + A Discretized Stream (DStream), the basic abstraction in Spark Streaming, + is a continuous sequence of RDDs (of the same type) representing a + continuous stream of data (see L{RDD} in the Spark core documentation + for more details on RDDs). + + DStreams can either be created from live data (such as, data from TCP + sockets, Kafka, Flume, etc.) using a L{StreamingContext} or it can be + generated by transforming existing DStreams using operations such as + `map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming + program is running, each DStream periodically generates a RDD, either + from live data or by transforming the RDD generated by a parent DStream. + + DStreams internally is characterized by a few basic properties: + - A list of other DStreams that the DStream depends on + - A time interval at which the DStream generates an RDD + - A function that is used to generate an RDD after each time interval + """ + def __init__(self, jdstream, ssc, jrdd_deserializer): + self._jdstream = jdstream + self._ssc = ssc + self._sc = ssc._sc + self._jrdd_deserializer = jrdd_deserializer + self.is_cached = False + self.is_checkpointed = False + + def context(self): + """ + Return the StreamingContext associated with this DStream + """ + return self._ssc + + def count(self): + """ + Return a new DStream in which each RDD has a single element + generated by counting each RDD of this DStream. + """ + return self.mapPartitions(lambda i: [sum(1 for _ in i)]).reduce(operator.add) + + def filter(self, f): + """ + Return a new DStream containing only the elements that satisfy predicate. + """ + def func(iterator): + return ifilter(f, iterator) + return self.mapPartitions(func, True) + + def flatMap(self, f, preservesPartitioning=False): + """ + Return a new DStream by applying a function to all elements of + this DStream, and then flattening the results + """ + def func(s, iterator): + return chain.from_iterable(imap(f, iterator)) + return self.mapPartitionsWithIndex(func, preservesPartitioning) + + def map(self, f, preservesPartitioning=False): + """ + Return a new DStream by applying a function to each element of DStream. + """ + def func(iterator): + return imap(f, iterator) + return self.mapPartitions(func, preservesPartitioning) + + def mapPartitions(self, f, preservesPartitioning=False): + """ + Return a new DStream in which each RDD is generated by applying + mapPartitions() to each RDDs of this DStream. + """ + def func(s, iterator): + return f(iterator) + return self.mapPartitionsWithIndex(func, preservesPartitioning) + + def mapPartitionsWithIndex(self, f, preservesPartitioning=False): + """ + Return a new DStream in which each RDD is generated by applying + mapPartitionsWithIndex() to each RDDs of this DStream. + """ + return self.transform(lambda rdd: rdd.mapPartitionsWithIndex(f, preservesPartitioning)) + + def reduce(self, func): + """ + Return a new DStream in which each RDD has a single element + generated by reducing each RDD of this DStream. 
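The primitives above compose as in the RDD API; a word-count sketch over an assumed input DStream lines (for example from socketTextStream):

counts = (lines.flatMap(lambda line: line.split(" "))
               .map(lambda word: (word, 1))
               .reduceByKey(lambda a, b: a + b))
counts.pprint()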
+ """ + return self.map(lambda x: (None, x)).reduceByKey(func, 1).map(lambda x: x[1]) + + def reduceByKey(self, func, numPartitions=None): + """ + Return a new DStream by applying reduceByKey to each RDD. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.combineByKey(lambda x: x, func, func, numPartitions) + + def combineByKey(self, createCombiner, mergeValue, mergeCombiners, + numPartitions=None): + """ + Return a new DStream by applying combineByKey to each RDD. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + + def func(rdd): + return rdd.combineByKey(createCombiner, mergeValue, mergeCombiners, numPartitions) + return self.transform(func) + + def partitionBy(self, numPartitions, partitionFunc=portable_hash): + """ + Return a copy of the DStream in which each RDD are partitioned + using the specified partitioner. + """ + return self.transform(lambda rdd: rdd.partitionBy(numPartitions, partitionFunc)) + + def foreachRDD(self, func): + """ + Apply a function to each RDD in this DStream. + """ + if func.func_code.co_argcount == 1: + old_func = func + func = lambda t, rdd: old_func(rdd) + jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer) + api = self._ssc._jvm.PythonDStream + api.callForeachRDD(self._jdstream, jfunc) + + def pprint(self): + """ + Print the first ten elements of each RDD generated in this DStream. + """ + def takeAndPrint(time, rdd): + taken = rdd.take(11) + print "-------------------------------------------" + print "Time: %s" % time + print "-------------------------------------------" + for record in taken[:10]: + print record + if len(taken) > 10: + print "..." + print + + self.foreachRDD(takeAndPrint) + + def mapValues(self, f): + """ + Return a new DStream by applying a map function to the value of + each key-value pairs in this DStream without changing the key. + """ + map_values_fn = lambda (k, v): (k, f(v)) + return self.map(map_values_fn, preservesPartitioning=True) + + def flatMapValues(self, f): + """ + Return a new DStream by applying a flatmap function to the value + of each key-value pairs in this DStream without changing the key. + """ + flat_map_fn = lambda (k, v): ((k, x) for x in f(v)) + return self.flatMap(flat_map_fn, preservesPartitioning=True) + + def glom(self): + """ + Return a new DStream in which RDD is generated by applying glom() + to RDD of this DStream. + """ + def func(iterator): + yield list(iterator) + return self.mapPartitions(func) + + def cache(self): + """ + Persist the RDDs of this DStream with the default storage level + (C{MEMORY_ONLY_SER}). + """ + self.is_cached = True + self.persist(StorageLevel.MEMORY_ONLY_SER) + return self + + def persist(self, storageLevel): + """ + Persist the RDDs of this DStream with the given storage level + """ + self.is_cached = True + javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) + self._jdstream.persist(javaStorageLevel) + return self + + def checkpoint(self, interval): + """ + Enable periodic checkpointing of RDDs of this DStream + + @param interval: time in seconds, after each period of that, generated + RDD will be checkpointed + """ + self.is_checkpointed = True + self._jdstream.checkpoint(self._ssc._jduration(interval)) + return self + + def groupByKey(self, numPartitions=None): + """ + Return a new DStream by applying groupByKey on each RDD. 
+ """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transform(lambda rdd: rdd.groupByKey(numPartitions)) + + def countByValue(self): + """ + Return a new DStream in which each RDD contains the counts of each + distinct value in each RDD of this DStream. + """ + return self.map(lambda x: (x, None)).reduceByKey(lambda x, y: None).count() + + def saveAsTextFiles(self, prefix, suffix=None): + """ + Save each RDD in this DStream as at text file, using string + representation of elements. + """ + def saveAsTextFile(t, rdd): + path = rddToFileName(prefix, suffix, t) + try: + rdd.saveAsTextFile(path) + except Py4JJavaError as e: + # after recovered from checkpointing, the foreachRDD may + # be called twice + if 'FileAlreadyExistsException' not in str(e): + raise + return self.foreachRDD(saveAsTextFile) + + # TODO: uncomment this until we have ssc.pickleFileStream() + # def saveAsPickleFiles(self, prefix, suffix=None): + # """ + # Save each RDD in this DStream as at binary file, the elements are + # serialized by pickle. + # """ + # def saveAsPickleFile(t, rdd): + # path = rddToFileName(prefix, suffix, t) + # try: + # rdd.saveAsPickleFile(path) + # except Py4JJavaError as e: + # # after recovered from checkpointing, the foreachRDD may + # # be called twice + # if 'FileAlreadyExistsException' not in str(e): + # raise + # return self.foreachRDD(saveAsPickleFile) + + def transform(self, func): + """ + Return a new DStream in which each RDD is generated by applying a function + on each RDD of this DStream. + + `func` can have one argument of `rdd`, or have two arguments of + (`time`, `rdd`) + """ + if func.func_code.co_argcount == 1: + oldfunc = func + func = lambda t, rdd: oldfunc(rdd) + assert func.func_code.co_argcount == 2, "func should take one or two arguments" + return TransformedDStream(self, func) + + def transformWith(self, func, other, keepSerializer=False): + """ + Return a new DStream in which each RDD is generated by applying a function + on each RDD of this DStream and 'other' DStream. + + `func` can have two arguments of (`rdd_a`, `rdd_b`) or have three + arguments of (`time`, `rdd_a`, `rdd_b`) + """ + if func.func_code.co_argcount == 2: + oldfunc = func + func = lambda t, a, b: oldfunc(a, b) + assert func.func_code.co_argcount == 3, "func should take two or three arguments" + jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer, other._jrdd_deserializer) + dstream = self._sc._jvm.PythonTransformed2DStream(self._jdstream.dstream(), + other._jdstream.dstream(), jfunc) + jrdd_serializer = self._jrdd_deserializer if keepSerializer else self._sc.serializer + return DStream(dstream.asJavaDStream(), self._ssc, jrdd_serializer) + + def repartition(self, numPartitions): + """ + Return a new DStream with an increased or decreased level of parallelism. + """ + return self.transform(lambda rdd: rdd.repartition(numPartitions)) + + @property + def _slideDuration(self): + """ + Return the slideDuration in seconds of this DStream + """ + return self._jdstream.dstream().slideDuration().milliseconds() / 1000.0 + + def union(self, other): + """ + Return a new DStream by unifying data of another DStream with this DStream. + + @param other: Another DStream having the same interval (i.e., slideDuration) + as this DStream. 
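+
+        For example (illustrative; `stream1` and `stream2` are assumed to come
+        from the same StreamingContext, so they share one slide duration):
+
+            merged = stream1.union(stream2)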
+ """ + if self._slideDuration != other._slideDuration: + raise ValueError("the two DStream should have same slide duration") + return self.transformWith(lambda a, b: a.union(b), other, True) + + def cogroup(self, other, numPartitions=None): + """ + Return a new DStream by applying 'cogroup' between RDDs of this + DStream and `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.cogroup(b, numPartitions), other) + + def join(self, other, numPartitions=None): + """ + Return a new DStream by applying 'join' between RDDs of this DStream and + `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.join(b, numPartitions), other) + + def leftOuterJoin(self, other, numPartitions=None): + """ + Return a new DStream by applying 'left outer join' between RDDs of this DStream and + `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.leftOuterJoin(b, numPartitions), other) + + def rightOuterJoin(self, other, numPartitions=None): + """ + Return a new DStream by applying 'right outer join' between RDDs of this DStream and + `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.rightOuterJoin(b, numPartitions), other) + + def fullOuterJoin(self, other, numPartitions=None): + """ + Return a new DStream by applying 'full outer join' between RDDs of this DStream and + `other` DStream. + + Hash partitioning is used to generate the RDDs with `numPartitions` + partitions. + """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + return self.transformWith(lambda a, b: a.fullOuterJoin(b, numPartitions), other) + + def _jtime(self, timestamp): + """ Convert datetime or unix_timestamp into Time + """ + if isinstance(timestamp, datetime): + timestamp = time.mktime(timestamp.timetuple()) + return self._sc._jvm.Time(long(timestamp * 1000)) + + def slice(self, begin, end): + """ + Return all the RDDs between 'begin' to 'end' (both included) + + `begin`, `end` could be datetime.datetime() or unix_timestamp + """ + jrdds = self._jdstream.slice(self._jtime(begin), self._jtime(end)) + return [RDD(jrdd, self._sc, self._jrdd_deserializer) for jrdd in jrdds] + + def _validate_window_param(self, window, slide): + duration = self._jdstream.dstream().slideDuration().milliseconds() + if int(window * 1000) % duration != 0: + raise ValueError("windowDuration must be multiple of the slide duration (%d ms)" + % duration) + if slide and int(slide * 1000) % duration != 0: + raise ValueError("slideDuration must be multiple of the slide duration (%d ms)" + % duration) + + def window(self, windowDuration, slideDuration=None): + """ + Return a new DStream in which each RDD contains all the elements in seen in a + sliding window of time over this DStream. 
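+
+        For example (illustrative; assuming a batch interval of 1 second and an
+        existing DStream `stream`):
+
+            # every 2 seconds, produce an RDD with the last 3 seconds of data
+            windowed = stream.window(3, 2)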
+ + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + """ + self._validate_window_param(windowDuration, slideDuration) + d = self._ssc._jduration(windowDuration) + if slideDuration is None: + return DStream(self._jdstream.window(d), self._ssc, self._jrdd_deserializer) + s = self._ssc._jduration(slideDuration) + return DStream(self._jdstream.window(d, s), self._ssc, self._jrdd_deserializer) + + def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuration): + """ + Return a new DStream in which each RDD has a single element generated by reducing all + elements in a sliding window over this DStream. + + if `invReduceFunc` is not None, the reduction is done incrementally + using the old window's reduced value : + + 1. reduce the new values that entered the window (e.g., adding new counts) + + 2. "inverse reduce" the old values that left the window (e.g., subtracting old counts) + This is more efficient than `invReduceFunc` is None. + + @param reduceFunc: associative reduce function + @param invReduceFunc: inverse reduce function of `reduceFunc` + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + """ + keyed = self.map(lambda x: (1, x)) + reduced = keyed.reduceByKeyAndWindow(reduceFunc, invReduceFunc, + windowDuration, slideDuration, 1) + return reduced.map(lambda (k, v): v) + + def countByWindow(self, windowDuration, slideDuration): + """ + Return a new DStream in which each RDD has a single element generated + by counting the number of elements in a window over this DStream. + windowDuration and slideDuration are as defined in the window() operation. + + This is equivalent to window(windowDuration, slideDuration).count(), + but will be more efficient if window is large. + """ + return self.map(lambda x: 1).reduceByWindow(operator.add, operator.sub, + windowDuration, slideDuration) + + def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=None): + """ + Return a new DStream in which each RDD contains the count of distinct elements in + RDDs in a sliding window over this DStream. + + @param windowDuration: width of the window; must be a multiple of this DStream's + batching interval + @param slideDuration: sliding interval of the window (i.e., the interval after which + the new DStream will generate RDDs); must be a multiple of this + DStream's batching interval + @param numPartitions: number of partitions of each RDD in the new DStream. + """ + keyed = self.map(lambda x: (x, 1)) + counted = keyed.reduceByKeyAndWindow(operator.add, operator.sub, + windowDuration, slideDuration, numPartitions) + return counted.filter(lambda (k, v): v > 0).count() + + def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None): + """ + Return a new DStream by applying `groupByKey` over a sliding window. + Similar to `DStream.groupByKey()`, but applies it over a sliding window. 
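+
+        For example (illustrative; `pairs` is assumed to be an existing DStream
+        of (key, value) tuples with a batch interval of 1 second):
+
+            # every 10 seconds, group the values seen per key over the last 30 seconds
+            grouped = pairs.groupByKeyAndWindow(30, 10).mapValues(list)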
+
+        @param windowDuration: width of the window; must be a multiple of this DStream's
+                               batching interval
+        @param slideDuration: sliding interval of the window (i.e., the interval after which
+                              the new DStream will generate RDDs); must be a multiple of this
+                              DStream's batching interval
+        @param numPartitions: number of partitions of each RDD in the new DStream.
+        """
+        ls = self.mapValues(lambda x: [x])
+        grouped = ls.reduceByKeyAndWindow(lambda a, b: a.extend(b) or a, lambda a, b: a[len(b):],
+                                          windowDuration, slideDuration, numPartitions)
+        return grouped.mapValues(ResultIterable)
+
+    def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None,
+                             numPartitions=None, filterFunc=None):
+        """
+        Return a new DStream by applying incremental `reduceByKey` over a sliding window.
+
+        The reduced value for a new window is calculated using the old window's reduced value:
+        1. reduce the new values that entered the window (e.g., adding new counts)
+        2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
+
+        `invFunc` can be None; in that case all the RDDs in the window are re-reduced,
+        which can be slower than providing `invFunc`.
+
+        @param func: associative reduce function
+        @param invFunc: inverse function of `func`
+        @param windowDuration: width of the window; must be a multiple of this DStream's
+                               batching interval
+        @param slideDuration: sliding interval of the window (i.e., the interval after which
+                              the new DStream will generate RDDs); must be a multiple of this
+                              DStream's batching interval
+        @param numPartitions: number of partitions of each RDD in the new DStream.
+        @param filterFunc: function to filter expired key-value pairs;
+                           only pairs that satisfy the function are retained;
+                           set this to None if you do not want to filter
+        """
+        self._validate_window_param(windowDuration, slideDuration)
+        if numPartitions is None:
+            numPartitions = self._sc.defaultParallelism
+
+        reduced = self.reduceByKey(func, numPartitions)
+
+        def reduceFunc(t, a, b):
+            b = b.reduceByKey(func, numPartitions)
+            r = a.union(b).reduceByKey(func, numPartitions) if a else b
+            if filterFunc:
+                r = r.filter(filterFunc)
+            return r
+
+        def invReduceFunc(t, a, b):
+            b = b.reduceByKey(func, numPartitions)
+            joined = a.leftOuterJoin(b, numPartitions)
+            return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1)
+
+        jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer)
+        # only build the inverse function when the caller supplied one
+        if invFunc:
+            jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer)
+        else:
+            jinvReduceFunc = None
+        if slideDuration is None:
+            slideDuration = self._slideDuration
+        dstream = self._sc._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(),
+                                                             jreduceFunc, jinvReduceFunc,
+                                                             self._ssc._jduration(windowDuration),
+                                                             self._ssc._jduration(slideDuration))
+        return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer)
+
+    def updateStateByKey(self, updateFunc, numPartitions=None):
+        """
+        Return a new "state" DStream where the state for each key is updated by applying
+        the given function on the previous state of the key and the new values of the key.
+
+        @param updateFunc: State update function. If this function returns None, then
+                           the corresponding state key-value pair will be eliminated.
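+
+        For example, a running count per key (illustrative; `pairs` is assumed
+        to be an existing DStream of (key, count) tuples):
+
+            def updater(new_values, last_sum):
+                # new_values: values for the key in the current batch;
+                # last_sum: previous state, or None the first time the key is seen
+                return sum(new_values) + (last_sum or 0)
+
+            running_counts = pairs.updateStateByKey(updater)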
+ """ + if numPartitions is None: + numPartitions = self._sc.defaultParallelism + + def reduceFunc(t, a, b): + if a is None: + g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None)) + else: + g = a.cogroup(b, numPartitions) + g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None)) + state = g.mapValues(lambda (vs, s): updateFunc(vs, s)) + return state.filter(lambda (k, v): v is not None) + + jreduceFunc = TransformFunction(self._sc, reduceFunc, + self._sc.serializer, self._jrdd_deserializer) + dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc) + return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer) + + +class TransformedDStream(DStream): + """ + TransformedDStream is an DStream generated by an Python function + transforming each RDD of an DStream to another RDDs. + + Multiple continuous transformations of DStream can be combined into + one transformation. + """ + def __init__(self, prev, func): + self._ssc = prev._ssc + self._sc = self._ssc._sc + self._jrdd_deserializer = self._sc.serializer + self.is_cached = False + self.is_checkpointed = False + self._jdstream_val = None + + if (isinstance(prev, TransformedDStream) and + not prev.is_cached and not prev.is_checkpointed): + prev_func = prev.func + self.func = lambda t, rdd: func(t, prev_func(t, rdd)) + self.prev = prev.prev + else: + self.prev = prev + self.func = func + + @property + def _jdstream(self): + if self._jdstream_val is not None: + return self._jdstream_val + + jfunc = TransformFunction(self._sc, self.func, self.prev._jrdd_deserializer) + dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc) + self._jdstream_val = dstream.asJavaDStream() + return self._jdstream_val diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py new file mode 100644 index 0000000000000..a8d876d0fa3b3 --- /dev/null +++ b/python/pyspark/streaming/tests.py @@ -0,0 +1,545 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +from itertools import chain +import time +import operator +import unittest +import tempfile + +from pyspark.context import SparkConf, SparkContext, RDD +from pyspark.streaming.context import StreamingContext + + +class PySparkStreamingTestCase(unittest.TestCase): + + timeout = 10 # seconds + duration = 1 + + def setUp(self): + class_name = self.__class__.__name__ + conf = SparkConf().set("spark.default.parallelism", 1) + self.sc = SparkContext(appName=class_name, conf=conf) + self.sc.setCheckpointDir("/tmp") + # TODO: decrease duration to speed up tests + self.ssc = StreamingContext(self.sc, self.duration) + + def tearDown(self): + self.ssc.stop() + + def wait_for(self, result, n): + start_time = time.time() + while len(result) < n and time.time() - start_time < self.timeout: + time.sleep(0.01) + if len(result) < n: + print "timeout after", self.timeout + + def _take(self, dstream, n): + """ + Return the first `n` elements in the stream (will start and stop). + """ + results = [] + + def take(_, rdd): + if rdd and len(results) < n: + results.extend(rdd.take(n - len(results))) + + dstream.foreachRDD(take) + + self.ssc.start() + self.wait_for(results, n) + return results + + def _collect(self, dstream, n, block=True): + """ + Collect each RDDs into the returned list. + + :return: list, which will have the collected items. + """ + result = [] + + def get_output(_, rdd): + if rdd and len(result) < n: + r = rdd.collect() + if r: + result.append(r) + + dstream.foreachRDD(get_output) + + if not block: + return result + + self.ssc.start() + self.wait_for(result, n) + return result + + def _test_func(self, input, func, expected, sort=False, input2=None): + """ + @param input: dataset for the test. This should be list of lists. + @param func: wrapped function. This function should return PythonDStream object. + @param expected: expected output for this testcase. + """ + if not isinstance(input[0], RDD): + input = [self.sc.parallelize(d, 1) for d in input] + input_stream = self.ssc.queueStream(input) + if input2 and not isinstance(input2[0], RDD): + input2 = [self.sc.parallelize(d, 1) for d in input2] + input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None + + # Apply test function to stream. 
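+        # If a second input stream was supplied, the function under test is a
+        # binary operation (e.g. union or join); otherwise it is unary.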
+ if input2: + stream = func(input_stream, input_stream2) + else: + stream = func(input_stream) + + result = self._collect(stream, len(expected)) + if sort: + self._sort_result_based_on_key(result) + self._sort_result_based_on_key(expected) + self.assertEqual(expected, result) + + def _sort_result_based_on_key(self, outputs): + """Sort the list based on first value.""" + for output in outputs: + output.sort(key=lambda x: x[0]) + + +class BasicOperationTests(PySparkStreamingTestCase): + + def test_map(self): + """Basic operation test for DStream.map.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.map(str) + expected = map(lambda x: map(str, x), input) + self._test_func(input, func, expected) + + def test_flatMap(self): + """Basic operation test for DStream.faltMap.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.flatMap(lambda x: (x, x * 2)) + expected = map(lambda x: list(chain.from_iterable((map(lambda y: [y, y * 2], x)))), + input) + self._test_func(input, func, expected) + + def test_filter(self): + """Basic operation test for DStream.filter.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.filter(lambda x: x % 2 == 0) + expected = map(lambda x: filter(lambda y: y % 2 == 0, x), input) + self._test_func(input, func, expected) + + def test_count(self): + """Basic operation test for DStream.count.""" + input = [range(5), range(10), range(20)] + + def func(dstream): + return dstream.count() + expected = map(lambda x: [len(x)], input) + self._test_func(input, func, expected) + + def test_reduce(self): + """Basic operation test for DStream.reduce.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + + def func(dstream): + return dstream.reduce(operator.add) + expected = map(lambda x: [reduce(operator.add, x)], input) + self._test_func(input, func, expected) + + def test_reduceByKey(self): + """Basic operation test for DStream.reduceByKey.""" + input = [[("a", 1), ("a", 1), ("b", 1), ("b", 1)], + [("", 1), ("", 1), ("", 1), ("", 1)], + [(1, 1), (1, 1), (2, 1), (2, 1), (3, 1)]] + + def func(dstream): + return dstream.reduceByKey(operator.add) + expected = [[("a", 2), ("b", 2)], [("", 4)], [(1, 2), (2, 2), (3, 1)]] + self._test_func(input, func, expected, sort=True) + + def test_mapValues(self): + """Basic operation test for DStream.mapValues.""" + input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], + [("", 4), (1, 1), (2, 2), (3, 3)], + [(1, 1), (2, 1), (3, 1), (4, 1)]] + + def func(dstream): + return dstream.mapValues(lambda x: x + 10) + expected = [[("a", 12), ("b", 12), ("c", 11), ("d", 11)], + [("", 14), (1, 11), (2, 12), (3, 13)], + [(1, 11), (2, 11), (3, 11), (4, 11)]] + self._test_func(input, func, expected, sort=True) + + def test_flatMapValues(self): + """Basic operation test for DStream.flatMapValues.""" + input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], + [("", 4), (1, 1), (2, 1), (3, 1)], + [(1, 1), (2, 1), (3, 1), (4, 1)]] + + def func(dstream): + return dstream.flatMapValues(lambda x: (x, x + 10)) + expected = [[("a", 2), ("a", 12), ("b", 2), ("b", 12), + ("c", 1), ("c", 11), ("d", 1), ("d", 11)], + [("", 4), ("", 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)], + [(1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11), (4, 1), (4, 11)]] + self._test_func(input, func, expected) + + def test_glom(self): + """Basic operation test for DStream.glom.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + rdds = [self.sc.parallelize(r, 2) for 
r in input] + + def func(dstream): + return dstream.glom() + expected = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]] + self._test_func(rdds, func, expected) + + def test_mapPartitions(self): + """Basic operation test for DStream.mapPartitions.""" + input = [range(1, 5), range(5, 9), range(9, 13)] + rdds = [self.sc.parallelize(r, 2) for r in input] + + def func(dstream): + def f(iterator): + yield sum(iterator) + return dstream.mapPartitions(f) + expected = [[3, 7], [11, 15], [19, 23]] + self._test_func(rdds, func, expected) + + def test_countByValue(self): + """Basic operation test for DStream.countByValue.""" + input = [range(1, 5) * 2, range(5, 7) + range(5, 9), ["a", "a", "b", ""]] + + def func(dstream): + return dstream.countByValue() + expected = [[4], [4], [3]] + self._test_func(input, func, expected) + + def test_groupByKey(self): + """Basic operation test for DStream.groupByKey.""" + input = [[(1, 1), (2, 1), (3, 1), (4, 1)], + [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]] + + def func(dstream): + return dstream.groupByKey().mapValues(list) + + expected = [[(1, [1]), (2, [1]), (3, [1]), (4, [1])], + [(1, [1, 1, 1]), (2, [1, 1]), (3, [1])], + [("a", [1, 1]), ("b", [1]), ("", [1, 1, 1])]] + self._test_func(input, func, expected, sort=True) + + def test_combineByKey(self): + """Basic operation test for DStream.combineByKey.""" + input = [[(1, 1), (2, 1), (3, 1), (4, 1)], + [(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]] + + def func(dstream): + def add(a, b): + return a + str(b) + return dstream.combineByKey(str, add, add) + expected = [[(1, "1"), (2, "1"), (3, "1"), (4, "1")], + [(1, "111"), (2, "11"), (3, "1")], + [("a", "11"), ("b", "1"), ("", "111")]] + self._test_func(input, func, expected, sort=True) + + def test_repartition(self): + input = [range(1, 5), range(5, 9)] + rdds = [self.sc.parallelize(r, 2) for r in input] + + def func(dstream): + return dstream.repartition(1).glom() + expected = [[[1, 2, 3, 4]], [[5, 6, 7, 8]]] + self._test_func(rdds, func, expected) + + def test_union(self): + input1 = [range(3), range(5), range(6)] + input2 = [range(3, 6), range(5, 6)] + + def func(d1, d2): + return d1.union(d2) + + expected = [range(6), range(6), range(6)] + self._test_func(input1, func, expected, input2=input2) + + def test_cogroup(self): + input = [[(1, 1), (2, 1), (3, 1)], + [(1, 1), (1, 1), (1, 1), (2, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1)]] + input2 = [[(1, 2)], + [(4, 1)], + [("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 2)]] + + def func(d1, d2): + return d1.cogroup(d2).mapValues(lambda vs: tuple(map(list, vs))) + + expected = [[(1, ([1], [2])), (2, ([1], [])), (3, ([1], []))], + [(1, ([1, 1, 1], [])), (2, ([1], [])), (4, ([], [1]))], + [("a", ([1, 1], [1, 1])), ("b", ([1], [1])), ("", ([1, 1], [1, 2]))]] + self._test_func(input, func, expected, sort=True, input2=input2) + + def test_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.join(b) + + expected = [[('b', (2, 3))]] + self._test_func(input, func, expected, True, input2) + + def test_left_outer_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.leftOuterJoin(b) + + expected = [[('a', (1, None)), ('b', (2, 3))]] + self._test_func(input, func, expected, True, input2) + + def test_right_outer_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = 
[[('b', 3), ('c', 4)]] + + def func(a, b): + return a.rightOuterJoin(b) + + expected = [[('b', (2, 3)), ('c', (None, 4))]] + self._test_func(input, func, expected, True, input2) + + def test_full_outer_join(self): + input = [[('a', 1), ('b', 2)]] + input2 = [[('b', 3), ('c', 4)]] + + def func(a, b): + return a.fullOuterJoin(b) + + expected = [[('a', (1, None)), ('b', (2, 3)), ('c', (None, 4))]] + self._test_func(input, func, expected, True, input2) + + def test_update_state_by_key(self): + + def updater(vs, s): + if not s: + s = [] + s.extend(vs) + return s + + input = [[('k', i)] for i in range(5)] + + def func(dstream): + return dstream.updateStateByKey(updater) + + expected = [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]] + expected = [[('k', v)] for v in expected] + self._test_func(input, func, expected) + + +class WindowFunctionTests(PySparkStreamingTestCase): + + timeout = 20 + + def test_window(self): + input = [range(1), range(2), range(3), range(4), range(5)] + + def func(dstream): + return dstream.window(3, 1).count() + + expected = [[1], [3], [6], [9], [12], [9], [5]] + self._test_func(input, func, expected) + + def test_count_by_window(self): + input = [range(1), range(2), range(3), range(4), range(5)] + + def func(dstream): + return dstream.countByWindow(3, 1) + + expected = [[1], [3], [6], [9], [12], [9], [5]] + self._test_func(input, func, expected) + + def test_count_by_window_large(self): + input = [range(1), range(2), range(3), range(4), range(5), range(6)] + + def func(dstream): + return dstream.countByWindow(5, 1) + + expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]] + self._test_func(input, func, expected) + + def test_count_by_value_and_window(self): + input = [range(1), range(2), range(3), range(4), range(5), range(6)] + + def func(dstream): + return dstream.countByValueAndWindow(5, 1) + + expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]] + self._test_func(input, func, expected) + + def test_group_by_key_and_window(self): + input = [[('a', i)] for i in range(5)] + + def func(dstream): + return dstream.groupByKeyAndWindow(3, 1).mapValues(list) + + expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])], + [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]] + self._test_func(input, func, expected) + + def test_reduce_by_invalid_window(self): + input1 = [range(3), range(5), range(1), range(6)] + d1 = self.ssc.queueStream(input1) + self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1)) + self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1)) + + +class StreamingContextTests(PySparkStreamingTestCase): + + duration = 0.1 + + def _add_input_stream(self): + inputs = map(lambda x: range(1, x), range(101)) + stream = self.ssc.queueStream(inputs) + self._collect(stream, 1, block=False) + + def test_stop_only_streaming_context(self): + self._add_input_stream() + self.ssc.start() + self.ssc.stop(False) + self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5) + + def test_stop_multiple_times(self): + self._add_input_stream() + self.ssc.start() + self.ssc.stop() + self.ssc.stop() + + def test_queue_stream(self): + input = [range(i + 1) for i in range(3)] + dstream = self.ssc.queueStream(input) + result = self._collect(dstream, 3) + self.assertEqual(input, result) + + def test_text_file_stream(self): + d = tempfile.mkdtemp() + self.ssc = StreamingContext(self.sc, self.duration) + dstream2 = self.ssc.textFileStream(d).map(int) + result = 
self._collect(dstream2, 2, block=False) + self.ssc.start() + for name in ('a', 'b'): + time.sleep(1) + with open(os.path.join(d, name), "w") as f: + f.writelines(["%d\n" % i for i in range(10)]) + self.wait_for(result, 2) + self.assertEqual([range(10), range(10)], result) + + def test_union(self): + input = [range(i + 1) for i in range(3)] + dstream = self.ssc.queueStream(input) + dstream2 = self.ssc.queueStream(input) + dstream3 = self.ssc.union(dstream, dstream2) + result = self._collect(dstream3, 3) + expected = [i * 2 for i in input] + self.assertEqual(expected, result) + + def test_transform(self): + dstream1 = self.ssc.queueStream([[1]]) + dstream2 = self.ssc.queueStream([[2]]) + dstream3 = self.ssc.queueStream([[3]]) + + def func(rdds): + rdd1, rdd2, rdd3 = rdds + return rdd2.union(rdd3).union(rdd1) + + dstream = self.ssc.transform([dstream1, dstream2, dstream3], func) + + self.assertEqual([2, 3, 1], self._take(dstream, 3)) + + +class CheckpointTests(PySparkStreamingTestCase): + + def setUp(self): + pass + + def test_get_or_create(self): + inputd = tempfile.mkdtemp() + outputd = tempfile.mkdtemp() + "/" + + def updater(vs, s): + return sum(vs, s or 0) + + def setup(): + conf = SparkConf().set("spark.default.parallelism", 1) + sc = SparkContext(conf=conf) + ssc = StreamingContext(sc, 0.5) + dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1)) + wc = dstream.updateStateByKey(updater) + wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test") + wc.checkpoint(.5) + return ssc + + cpd = tempfile.mkdtemp("test_streaming_cps") + self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) + ssc.start() + + def check_output(n): + while not os.listdir(outputd): + time.sleep(0.1) + time.sleep(1) # make sure mtime is larger than the previous one + with open(os.path.join(inputd, str(n)), 'w') as f: + f.writelines(["%d\n" % i for i in range(10)]) + + while True: + p = os.path.join(outputd, max(os.listdir(outputd))) + if '_SUCCESS' not in os.listdir(p): + # not finished + time.sleep(0.01) + continue + ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(",")) + d = ordd.values().map(int).collect() + if not d: + time.sleep(0.01) + continue + self.assertEqual(10, len(d)) + s = set(d) + self.assertEqual(1, len(s)) + m = s.pop() + if n > m: + continue + self.assertEqual(n, m) + break + + check_output(1) + check_output(2) + ssc.stop(True, True) + + time.sleep(1) + self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup) + ssc.start() + check_output(3) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py new file mode 100644 index 0000000000000..86ee5aa04f252 --- /dev/null +++ b/python/pyspark/streaming/util.py @@ -0,0 +1,128 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +import time +from datetime import datetime +import traceback + +from pyspark import SparkContext, RDD + + +class TransformFunction(object): + """ + This class wraps a function RDD[X] -> RDD[Y] that was passed to + DStream.transform(), allowing it to be called from Java via Py4J's + callback server. + + Java calls this function with a sequence of JavaRDDs and this function + returns a single JavaRDD pointer back to Java. + """ + _emptyRDD = None + + def __init__(self, ctx, func, *deserializers): + self.ctx = ctx + self.func = func + self.deserializers = deserializers + + def call(self, milliseconds, jrdds): + try: + if self.ctx is None: + self.ctx = SparkContext._active_spark_context + if not self.ctx or not self.ctx._jsc: + # stopped + return + + # extend deserializers with the first one + sers = self.deserializers + if len(sers) < len(jrdds): + sers += (sers[0],) * (len(jrdds) - len(sers)) + + rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None + for jrdd, ser in zip(jrdds, sers)] + t = datetime.fromtimestamp(milliseconds / 1000.0) + r = self.func(t, *rdds) + if r: + return r._jrdd + except Exception: + traceback.print_exc() + + def __repr__(self): + return "TransformFunction(%s)" % self.func + + class Java: + implements = ['org.apache.spark.streaming.api.python.PythonTransformFunction'] + + +class TransformFunctionSerializer(object): + """ + This class implements a serializer for PythonTransformFunction Java + objects. + + This is necessary because the Java PythonTransformFunction objects are + actually Py4J references to Python objects and thus are not directly + serializable. When Java needs to serialize a PythonTransformFunction, + it uses this class to invoke Python, which returns the serialized function + as a byte array. + """ + def __init__(self, ctx, serializer, gateway=None): + self.ctx = ctx + self.serializer = serializer + self.gateway = gateway or self.ctx._gateway + self.gateway.jvm.PythonDStream.registerSerializer(self) + + def dumps(self, id): + try: + func = self.gateway.gateway_property.pool[id] + return bytearray(self.serializer.dumps((func.func, func.deserializers))) + except Exception: + traceback.print_exc() + + def loads(self, bytes): + try: + f, deserializers = self.serializer.loads(str(bytes)) + return TransformFunction(self.ctx, f, *deserializers) + except Exception: + traceback.print_exc() + + def __repr__(self): + return "TransformFunctionSerializer(%s)" % self.serializer + + class Java: + implements = ['org.apache.spark.streaming.api.python.PythonTransformFunctionSerializer'] + + +def rddToFileName(prefix, suffix, timestamp): + """ + Return string prefix-time(.suffix) + + >>> rddToFileName("spark", None, 12345678910) + 'spark-12345678910' + >>> rddToFileName("spark", "tmp", 12345678910) + 'spark-12345678910.tmp' + """ + if isinstance(timestamp, datetime): + seconds = time.mktime(timestamp.timetuple()) + timestamp = long(seconds * 1000) + timestamp.microsecond / 1000 + if suffix is None: + return prefix + "-" + str(timestamp) + else: + return prefix + "-" + str(timestamp) + "." 
+ suffix + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 6fb6bc998c752..32645778c2b8f 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -31,10 +31,15 @@ import time import zipfile import random -from platform import python_implementation +import threading +import hashlib if sys.version_info[:2] <= (2, 6): - import unittest2 as unittest + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) else: import unittest @@ -43,9 +48,10 @@ from pyspark.context import SparkContext from pyspark.files import SparkFiles from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ - CloudPickleSerializer + CloudPickleSerializer, CompressedSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter -from pyspark.sql import SQLContext, IntegerType, Row +from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ + UserDefinedType, DoubleType from pyspark import shuffle _have_scipy = False @@ -67,10 +73,10 @@ SPARK_HOME = os.environ["SPARK_HOME"] -class TestMerger(unittest.TestCase): +class MergerTests(unittest.TestCase): def setUp(self): - self.N = 1 << 16 + self.N = 1 << 14 self.l = [i for i in xrange(self.N)] self.data = zip(self.l, self.l) self.agg = Aggregator(lambda x: [x], @@ -115,7 +121,7 @@ def test_medium_dataset(self): sum(xrange(self.N)) * 3) def test_huge_dataset(self): - m = ExternalMerger(self.agg, 10) + m = ExternalMerger(self.agg, 10, partitions=3) m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10)) self.assertTrue(m.spills >= 1) self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)), @@ -123,7 +129,7 @@ def test_huge_dataset(self): m._cleanup() -class TestSorter(unittest.TestCase): +class SorterTests(unittest.TestCase): def test_in_memory_sort(self): l = range(1024) random.shuffle(l) @@ -231,29 +237,49 @@ def foo(): self.assertTrue("exit" in foo.func_code.co_names) ser.dumps(foo) + def test_compressed_serializer(self): + ser = CompressedSerializer(PickleSerializer()) + from StringIO import StringIO + io = StringIO() + ser.dump_stream(["abc", u"123", range(5)], io) + io.seek(0) + self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io))) + ser.dump_stream(range(1000), io) + io.seek(0) + self.assertEqual(["abc", u"123", range(5)] + range(1000), list(ser.load_stream(io))) + class PySparkTestCase(unittest.TestCase): def setUp(self): self._old_sys_path = list(sys.path) class_name = self.__class__.__name__ - self.sc = SparkContext('local[4]', class_name, batchSize=2) + self.sc = SparkContext('local[4]', class_name) def tearDown(self): self.sc.stop() sys.path = self._old_sys_path -class TestCheckpoint(PySparkTestCase): +class ReusedPySparkTestCase(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.sc = SparkContext('local[4]', cls.__name__) + + @classmethod + def tearDownClass(cls): + cls.sc.stop() + + +class CheckpointTests(ReusedPySparkTestCase): def setUp(self): - PySparkTestCase.setUp(self) self.checkpointDir = tempfile.NamedTemporaryFile(delete=False) os.unlink(self.checkpointDir.name) self.sc.setCheckpointDir(self.checkpointDir.name) def tearDown(self): - PySparkTestCase.tearDown(self) shutil.rmtree(self.checkpointDir.name) def test_basic_checkpointing(self): @@ -288,7 +314,7 @@ def test_checkpoint_and_restore(self): 
self.assertEquals([1, 2, 3, 4], recovered.collect()) -class TestAddFile(PySparkTestCase): +class AddFileTests(PySparkTestCase): def test_add_py_file(self): # To ensure that we're actually testing addPyFile's effects, check that @@ -354,7 +380,7 @@ def func(x): self.assertEqual(["My Server"], self.sc.parallelize(range(1)).map(func).collect()) -class TestRDDFunctions(PySparkTestCase): +class RDDTests(ReusedPySparkTestCase): def test_id(self): rdd = self.sc.parallelize(range(10)) @@ -365,12 +391,6 @@ def test_id(self): self.assertEqual(id + 1, id2) self.assertEqual(id2, rdd2.id()) - def test_failed_sparkcontext_creation(self): - # Regression test for SPARK-1550 - self.sc.stop() - self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name")) - self.sc = SparkContext("local") - def test_save_as_textfile_with_unicode(self): # Regression test for SPARK-970 x = u"\u00A1Hola, mundo!" @@ -426,7 +446,13 @@ def test_deleting_input_files(self): os.unlink(tempFile.name) self.assertRaises(Exception, lambda: filtered_data.count()) - def testAggregateByKey(self): + def test_sampling_default_seed(self): + # Test for SPARK-3995 (default seed setting) + data = self.sc.parallelize(range(1000), 1) + subset = data.takeSample(False, 10) + self.assertEqual(len(subset), 10) + + def test_aggregate_by_key(self): data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2) def seqOp(x, y): @@ -464,6 +490,32 @@ def test_large_broadcast(self): m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() self.assertEquals(N, m) + def test_multiple_broadcasts(self): + N = 1 << 21 + b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM + r = range(1 << 15) + random.shuffle(r) + s = str(r) + checksum = hashlib.md5(s).hexdigest() + b2 = self.sc.broadcast(s) + r = list(set(self.sc.parallelize(range(10), 10).map( + lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) + self.assertEqual(1, len(r)) + size, csum = r[0] + self.assertEqual(N, size) + self.assertEqual(checksum, csum) + + random.shuffle(r) + s = str(r) + checksum = hashlib.md5(s).hexdigest() + b2 = self.sc.broadcast(s) + r = list(set(self.sc.parallelize(range(10), 10).map( + lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect())) + self.assertEqual(1, len(r)) + size, csum = r[0] + self.assertEqual(N, size) + self.assertEqual(checksum, csum) + def test_large_closure(self): N = 1000000 data = [float(i) for i in xrange(N)] @@ -635,14 +687,32 @@ def test_distinct(self): self.assertEquals(result.getNumPartitions(), 5) self.assertEquals(result.count(), 3) + def test_sort_on_empty_rdd(self): + self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect()) + + def test_sample(self): + rdd = self.sc.parallelize(range(0, 100), 4) + wo = rdd.sample(False, 0.1, 2).collect() + wo_dup = rdd.sample(False, 0.1, 2).collect() + self.assertSetEqual(set(wo), set(wo_dup)) + wr = rdd.sample(True, 0.2, 5).collect() + wr_dup = rdd.sample(True, 0.2, 5).collect() + self.assertSetEqual(set(wr), set(wr_dup)) + wo_s10 = rdd.sample(False, 0.3, 10).collect() + wo_s20 = rdd.sample(False, 0.3, 20).collect() + self.assertNotEqual(set(wo_s10), set(wo_s20)) + wr_s11 = rdd.sample(True, 0.4, 11).collect() + wr_s21 = rdd.sample(True, 0.4, 21).collect() + self.assertNotEqual(set(wr_s11), set(wr_s21)) -class TestProfiler(PySparkTestCase): + +class ProfilerTests(PySparkTestCase): def setUp(self): self._old_sys_path = list(sys.path) class_name = self.__class__.__name__ conf = 
SparkConf().set("spark.python.profile", "true") - self.sc = SparkContext('local[4]', class_name, batchSize=2, conf=conf) + self.sc = SparkContext('local[4]', class_name, conf=conf) def test_profiler(self): @@ -666,10 +736,66 @@ def heavy_foo(x): self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) -class TestSQL(PySparkTestCase): +class ExamplePointUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. + """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return 'pyspark.tests' + + @classmethod + def scalaUDT(cls): + return 'org.apache.spark.sql.test.ExamplePointUDT' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return ExamplePoint(datum[0], datum[1]) + + +class ExamplePoint: + """ + An example class to demonstrate UDT in Scala, Java, and Python. + """ + + __UDT__ = ExamplePointUDT() + + def __init__(self, x, y): + self.x = x + self.y = y + + def __repr__(self): + return "ExamplePoint(%s,%s)" % (self.x, self.y) + + def __str__(self): + return "(%s,%s)" % (self.x, self.y) + + def __eq__(self, other): + return isinstance(other, ExamplePoint) and \ + other.x == self.x and other.y == self.y + + +class SQLTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name, ignore_errors=True) def setUp(self): - PySparkTestCase.setUp(self) self.sqlCtx = SQLContext(self.sc) def test_udf(self): @@ -677,6 +803,22 @@ def test_udf(self): [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() self.assertEqual(row[0], 5) + def test_udf2(self): + self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType()) + self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test") + [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() + self.assertEqual(4, res[0]) + + def test_udf_with_array_type(self): + d = [Row(l=range(3), d={"key": range(5)})] + rdd = self.sc.parallelize(d) + srdd = self.sqlCtx.inferSchema(rdd).registerTempTable("test") + self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) + self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) + [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() + self.assertEqual(range(3), l1) + self.assertEqual(1, l2) + def test_broadcast_in_udf(self): bar = {"a": "aa", "b": "bb", "c": "abc"} foo = self.sc.broadcast(bar) @@ -753,28 +895,82 @@ def test_serialize_nested_array_and_map(self): self.assertEqual(1.0, row.c) self.assertEqual("2", row.d) - -class TestIO(PySparkTestCase): - - def test_stdout_redirection(self): - import subprocess - - def func(x): - subprocess.check_call('ls', shell=True) - self.sc.parallelize([1]).foreach(func) - - -class TestInputFormat(PySparkTestCase): - - def setUp(self): - PySparkTestCase.setUp(self) - self.tempdir = tempfile.NamedTemporaryFile(delete=False) - os.unlink(self.tempdir.name) - self.sc._jvm.WriteInputFormatTestDataGenerator.generateData(self.tempdir.name, self.sc._jsc) - - def tearDown(self): - PySparkTestCase.tearDown(self) - shutil.rmtree(self.tempdir.name) + def test_infer_schema(self): + d = [Row(l=[], d={}), + Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] + rdd = 
self.sc.parallelize(d) + srdd = self.sqlCtx.inferSchema(rdd) + self.assertEqual([], srdd.map(lambda r: r.l).first()) + self.assertEqual([None, ""], srdd.map(lambda r: r.s).collect()) + srdd.registerTempTable("test") + result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") + self.assertEqual(1, result.first()[0]) + + srdd2 = self.sqlCtx.inferSchema(rdd, 1.0) + self.assertEqual(srdd.schema(), srdd2.schema()) + self.assertEqual({}, srdd2.map(lambda r: r.d).first()) + self.assertEqual([None, ""], srdd2.map(lambda r: r.s).collect()) + srdd2.registerTempTable("test2") + result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") + self.assertEqual(1, result.first()[0]) + + def test_convert_row_to_dict(self): + row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) + self.assertEqual(1, row.asDict()['l'][0].a) + rdd = self.sc.parallelize([row]) + srdd = self.sqlCtx.inferSchema(rdd) + srdd.registerTempTable("test") + row = self.sqlCtx.sql("select l, d from test").first() + self.assertEqual(1, row.asDict()["l"][0].a) + self.assertEqual(1.0, row.asDict()['d']['key'].c) + + def test_infer_schema_with_udt(self): + from pyspark.tests import ExamplePoint, ExamplePointUDT + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + srdd = self.sqlCtx.inferSchema(rdd) + schema = srdd.schema() + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), ExamplePointUDT) + srdd.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point + self.assertEqual(point, ExamplePoint(1.0, 2.0)) + + def test_apply_schema_with_udt(self): + from pyspark.tests import ExamplePoint, ExamplePointUDT + row = (1.0, ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + srdd = self.sqlCtx.applySchema(rdd, schema) + point = srdd.first().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + def test_parquet_with_udt(self): + from pyspark.tests import ExamplePoint + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + srdd0 = self.sqlCtx.inferSchema(rdd) + output_dir = os.path.join(self.tempdir.name, "labeled_point") + srdd0.saveAsParquetFile(output_dir) + srdd1 = self.sqlCtx.parquetFile(output_dir) + point = srdd1.first().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + +class InputFormatTests(ReusedPySparkTestCase): + + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + cls.sc._jvm.WriteInputFormatTestDataGenerator.generateData(cls.tempdir.name, cls.sc._jsc) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name) def test_sequencefiles(self): basepath = self.tempdir.name @@ -858,16 +1054,19 @@ def test_sequencefiles(self): clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", "org.apache.hadoop.io.Text", "org.apache.spark.api.python.TestWritable").collect()) - ec = (u'1', - {u'__class__': u'org.apache.spark.api.python.TestWritable', - u'double': 54.0, u'int': 123, u'str': u'test1'}) - self.assertEqual(clazz[0], ec) + cname = u'org.apache.spark.api.python.TestWritable' + ec = [(u'1', {u'__class__': cname, u'double': 1.0, u'int': 1, u'str': u'test1'}), + (u'2', {u'__class__': cname, u'double': 2.3, u'int': 2, 
u'str': u'test2'}), + (u'3', {u'__class__': cname, u'double': 3.1, u'int': 3, u'str': u'test3'}), + (u'4', {u'__class__': cname, u'double': 4.2, u'int': 4, u'str': u'test4'}), + (u'5', {u'__class__': cname, u'double': 5.5, u'int': 5, u'str': u'test56'})] + self.assertEqual(clazz, ec) unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", "org.apache.hadoop.io.Text", "org.apache.spark.api.python.TestWritable", - batchSize=1).collect()) - self.assertEqual(unbatched_clazz[0], ec) + ).collect()) + self.assertEqual(unbatched_clazz, ec) def test_oldhadoop(self): basepath = self.tempdir.name @@ -953,16 +1152,33 @@ def test_converters(self): (u'\x03', [2.0])] self.assertEqual(maps, em) + def test_binary_files(self): + path = os.path.join(self.tempdir.name, "binaryfiles") + os.mkdir(path) + data = "short binary data" + with open(os.path.join(path, "part-0000"), 'w') as f: + f.write(data) + [(p, d)] = self.sc.binaryFiles(path).collect() + self.assertTrue(p.endswith("part-0000")) + self.assertEqual(d, data) + + def test_binary_records(self): + path = os.path.join(self.tempdir.name, "binaryrecords") + os.mkdir(path) + with open(os.path.join(path, "part-0000"), 'w') as f: + for i in range(100): + f.write('%04d' % i) + result = self.sc.binaryRecords(path, 4).map(int).collect() + self.assertEqual(range(100), result) -class TestOutputFormat(PySparkTestCase): + +class OutputFormatTests(ReusedPySparkTestCase): def setUp(self): - PySparkTestCase.setUp(self) self.tempdir = tempfile.NamedTemporaryFile(delete=False) os.unlink(self.tempdir.name) def tearDown(self): - PySparkTestCase.tearDown(self) shutil.rmtree(self.tempdir.name, ignore_errors=True) def test_sequencefiles(self): @@ -1189,51 +1405,6 @@ def test_reserialization(self): result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect()) self.assertEqual(result5, data) - def test_unbatched_save_and_read(self): - basepath = self.tempdir.name - ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] - self.sc.parallelize(ei, len(ei)).saveAsSequenceFile( - basepath + "/unbatched/") - - unbatched_sequence = sorted(self.sc.sequenceFile( - basepath + "/unbatched/", - batchSize=1).collect()) - self.assertEqual(unbatched_sequence, ei) - - unbatched_hadoopFile = sorted(self.sc.hadoopFile( - basepath + "/unbatched/", - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - batchSize=1).collect()) - self.assertEqual(unbatched_hadoopFile, ei) - - unbatched_newAPIHadoopFile = sorted(self.sc.newAPIHadoopFile( - basepath + "/unbatched/", - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - batchSize=1).collect()) - self.assertEqual(unbatched_newAPIHadoopFile, ei) - - oldconf = {"mapred.input.dir": basepath + "/unbatched/"} - unbatched_hadoopRDD = sorted(self.sc.hadoopRDD( - "org.apache.hadoop.mapred.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - conf=oldconf, - batchSize=1).collect()) - self.assertEqual(unbatched_hadoopRDD, ei) - - newconf = {"mapred.input.dir": basepath + "/unbatched/"} - unbatched_newAPIHadoopRDD = sorted(self.sc.newAPIHadoopRDD( - "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", - "org.apache.hadoop.io.IntWritable", - "org.apache.hadoop.io.Text", - conf=newconf, - batchSize=1).collect()) - self.assertEqual(unbatched_newAPIHadoopRDD, ei) - def 
test_malformed_RDD(self): basepath = self.tempdir.name # non-batch-serialized RDD[[(K, V)]] should be rejected @@ -1243,8 +1414,7 @@ def test_malformed_RDD(self): basepath + "/malformed/sequence")) -class TestDaemon(unittest.TestCase): - +class DaemonTests(unittest.TestCase): def connect(self, port): from socket import socket, AF_INET, SOCK_STREAM sock = socket(AF_INET, SOCK_STREAM) @@ -1290,7 +1460,7 @@ def test_termination_sigterm(self): self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM)) -class TestWorker(PySparkTestCase): +class WorkerTests(PySparkTestCase): def test_cancel_task(self): temp = tempfile.NamedTemporaryFile(delete=True) @@ -1342,11 +1512,6 @@ def run(): rdd = self.sc.parallelize(range(100), 1) self.assertEqual(100, rdd.map(str).count()) - def test_fd_leak(self): - N = 1100 # fd limit is 1024 by default - rdd = self.sc.parallelize(range(N), N) - self.assertEquals(N, rdd.count()) - def test_after_exception(self): def raise_exception(_): raise Exception() @@ -1378,8 +1543,25 @@ def test_accumulator_when_reuse_worker(self): self.assertEqual(sum(range(100)), acc2.value) self.assertEqual(sum(range(100)), acc1.value) + def test_reuse_worker_after_take(self): + rdd = self.sc.parallelize(range(100000), 1) + self.assertEqual(0, rdd.first()) + + def count(): + try: + rdd.count() + except Exception: + pass + + t = threading.Thread(target=count) + t.daemon = True + t.start() + t.join(5) + self.assertTrue(not t.isAlive()) + self.assertEqual(100000, rdd.count()) -class TestSparkSubmit(unittest.TestCase): + +class SparkSubmitTests(unittest.TestCase): def setUp(self): self.programDir = tempfile.mkdtemp() @@ -1492,6 +1674,8 @@ def test_single_script_on_cluster(self): |sc = SparkContext() |print sc.parallelize([1, 2, 3]).map(foo).collect() """) + # this will fail if you have different spark.executor.memory + # in conf/spark-defaults.conf proc = subprocess.Popen( [self.sparkSubmit, "--master", "local-cluster[1,1,512]", script], stdout=subprocess.PIPE) @@ -1500,7 +1684,11 @@ def test_single_script_on_cluster(self): self.assertIn("[2, 4, 6]", out) -class ContextStopTests(unittest.TestCase): +class ContextTests(unittest.TestCase): + + def test_failed_sparkcontext_creation(self): + # Regression test for SPARK-1550 + self.assertRaises(Exception, lambda: SparkContext("an-invalid-master-name")) def test_stop(self): sc = SparkContext() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 8257dddfee1c3..7e5343c973dc5 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -30,8 +30,7 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ - write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ - CompressedSerializer + write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer from pyspark import shuffle pickleSer = PickleSerializer() @@ -57,7 +56,7 @@ def main(infile, outfile): boot_time = time.time() split_index = read_int(infile) if split_index == -1: # for unit tests - return + exit(-1) # initialize global state shuffle.MemoryBytesSpilled = 0 @@ -78,12 +77,11 @@ def main(infile, outfile): # fetch names and values of broadcast variables num_broadcast_variables = read_int(infile) - ser = CompressedSerializer(pickleSer) for _ in range(num_broadcast_variables): bid = read_long(infile) if bid >= 0: - value = ser._read_with_length(infile) - _broadcastRegistry[bid] = Broadcast(bid, value) + path = 
utf8_deserializer.loads(infile) + _broadcastRegistry[bid] = Broadcast(path=path) else: bid = - bid - 1 _broadcastRegistry.pop(bid) @@ -111,7 +109,6 @@ def process(): try: write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) write_with_length(traceback.format_exc(), outfile) - outfile.flush() except IOError: # JVM close the socket pass @@ -131,6 +128,14 @@ def process(): for (aid, accum) in _accumulatorRegistry.items(): pickleSer._write_with_length((aid, accum._value), outfile) + # check end of stream + if read_int(infile) == SpecialLengths.END_OF_STREAM: + write_int(SpecialLengths.END_OF_STREAM, outfile) + else: + # write a different value to tell JVM to not reuse this worker + write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) + exit(-1) + if __name__ == '__main__': # Read a local port to connect to from stdin diff --git a/python/run-tests b/python/run-tests index a7ec270c7da21..9ee19ed6e6b26 100755 --- a/python/run-tests +++ b/python/run-tests @@ -25,22 +25,23 @@ FWDIR="$(cd "`dirname "$0"`"; cd ../; pwd)" cd "$FWDIR/python" FAILED=0 +LOG_FILE=unit-tests.log -rm -f unit-tests.log +rm -f $LOG_FILE # Remove the metastore and warehouse directory created by the HiveContext tests in Spark SQL rm -rf metastore warehouse function run_test() { - echo "Running test: $1" + echo "Running test: $1" | tee -a $LOG_FILE - SPARK_TESTING=1 "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log + SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a $LOG_FILE FAILED=$((PIPESTATUS[0]||$FAILED)) # Fail and exit on the first test failure. if [[ $FAILED != 0 ]]; then - cat unit-tests.log | grep -v "^[0-9][0-9]*" # filter all lines starting with a number. + cat $LOG_FILE | grep -v "^[0-9][0-9]*" # filter all lines starting with a number. echo -en "\033[31m" # Red echo "Had test failures; see logs." echo -en "\033[0m" # No color @@ -48,7 +49,45 @@ function run_test() { fi } -echo "Running PySpark tests. Output is in python/unit-tests.log." +function run_core_tests() { + echo "Run core tests ..." + run_test "pyspark/rdd.py" + run_test "pyspark/context.py" + run_test "pyspark/conf.py" + PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py" + PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py" + run_test "pyspark/serializers.py" + run_test "pyspark/shuffle.py" + run_test "pyspark/tests.py" +} + +function run_sql_tests() { + echo "Run sql tests ..." + run_test "pyspark/sql.py" +} + +function run_mllib_tests() { + echo "Run mllib tests ..." + run_test "pyspark/mllib/classification.py" + run_test "pyspark/mllib/clustering.py" + run_test "pyspark/mllib/feature.py" + run_test "pyspark/mllib/linalg.py" + run_test "pyspark/mllib/rand.py" + run_test "pyspark/mllib/recommendation.py" + run_test "pyspark/mllib/regression.py" + run_test "pyspark/mllib/stat.py" + run_test "pyspark/mllib/tree.py" + run_test "pyspark/mllib/util.py" + run_test "pyspark/mllib/tests.py" +} + +function run_streaming_tests() { + echo "Run streaming tests ..." + run_test "pyspark/streaming/util.py" + run_test "pyspark/streaming/tests.py" +} + +echo "Running PySpark tests. Output is in python/$LOG_FILE." export PYSPARK_PYTHON="python" @@ -60,29 +99,10 @@ fi echo "Testing with Python version:" $PYSPARK_PYTHON --version -run_test "pyspark/rdd.py" -run_test "pyspark/context.py" -run_test "pyspark/conf.py" -run_test "pyspark/sql.py" -# These tests are included in the module-level docs, and so must -# be handled on a higher level rather than within the python file. 
-export PYSPARK_DOC_TEST=1 -run_test "pyspark/broadcast.py" -run_test "pyspark/accumulators.py" -run_test "pyspark/serializers.py" -unset PYSPARK_DOC_TEST -run_test "pyspark/shuffle.py" -run_test "pyspark/tests.py" -run_test "pyspark/mllib/classification.py" -run_test "pyspark/mllib/clustering.py" -run_test "pyspark/mllib/linalg.py" -run_test "pyspark/mllib/random.py" -run_test "pyspark/mllib/recommendation.py" -run_test "pyspark/mllib/regression.py" -run_test "pyspark/mllib/stat.py" -run_test "pyspark/mllib/tests.py" -run_test "pyspark/mllib/tree.py" -run_test "pyspark/mllib/util.py" +run_core_tests +run_sql_tests +run_mllib_tests +run_streaming_tests # Try to test with PyPy if [ $(which pypy) ]; then @@ -90,19 +110,9 @@ if [ $(which pypy) ]; then echo "Testing with PyPy version:" $PYSPARK_PYTHON --version - run_test "pyspark/rdd.py" - run_test "pyspark/context.py" - run_test "pyspark/conf.py" - run_test "pyspark/sql.py" - # These tests are included in the module-level docs, and so must - # be handled on a higher level rather than within the python file. - export PYSPARK_DOC_TEST=1 - run_test "pyspark/broadcast.py" - run_test "pyspark/accumulators.py" - run_test "pyspark/serializers.py" - unset PYSPARK_DOC_TEST - run_test "pyspark/shuffle.py" - run_test "pyspark/tests.py" + run_core_tests + run_sql_tests + run_streaming_tests fi if [[ $FAILED == 0 ]]; then diff --git a/repl/pom.xml b/repl/pom.xml index af528c8914335..9b2290429fee5 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.2.0-SNAPSHOT + 1.3.0-SNAPSHOT ../pom.xml @@ -35,9 +35,16 @@ repl /usr/share/spark root + scala-2.10/src/main/scala + scala-2.10/src/test/scala + + ${jline.groupid} + jline + ${jline.version} + org.apache.spark spark-core_${scala.binary.version} @@ -75,11 +82,6 @@ scala-reflect ${scala.version} - - org.scala-lang - jline - ${scala.version} - org.slf4j jul-to-slf4j @@ -122,6 +124,51 @@ + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-scala-sources + generate-sources + + add-source + + + + src/main/scala + ${extra.source.dir} + + + + + add-scala-test-sources + generate-test-sources + + add-test-source + + + + src/test/scala + ${extra.testsource.dir} + + + + + + + + scala-2.11 + + scala-2.11 + + + scala-2.11/src/main/scala + scala-2.11/src/test/scala + + + diff --git a/repl/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/Main.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkHelper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkHelper.scala rename to 
repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala similarity index 100% rename from repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala similarity index 95% rename from repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index 7667a9c11979e..da4286c5e4874 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -121,11 +121,14 @@ trait SparkILoopInit { def initializeSpark() { intp.beQuietDuring { command(""" - @transient val sc = org.apache.spark.repl.Main.interp.createSparkContext(); + @transient val sc = { + val _sc = org.apache.spark.repl.Main.interp.createSparkContext() + println("Spark context available as sc.") + _sc + } """) command("import org.apache.spark.SparkContext._") } - echo("Spark context available as sc.") } // code to be executed only after the interpreter is initialized diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala new file mode 100644 index 0000000000000..646c68e60c2e9 --- /dev/null +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -0,0 +1,1445 @@ +// scalastyle:off + +/* NSC -- new Scala compiler + * Copyright 2005-2013 LAMP/EPFL + * @author Martin Odersky + */ + +package org.apache.spark.repl + +import java.io.File + +import scala.tools.nsc._ +import scala.tools.nsc.backend.JavaPlatform +import scala.tools.nsc.interpreter._ + +import Predef.{ println => _, _ } +import scala.tools.nsc.util.{MergedClassPath, stringFromWriter, ScalaClassLoader, stackTraceString} +import scala.reflect.internal.util._ +import java.net.URL +import scala.sys.BooleanProp +import io.{AbstractFile, PlainFile, VirtualDirectory} + +import reporters._ +import symtab.Flags +import scala.reflect.internal.Names +import scala.tools.util.PathResolver +import ScalaClassLoader.URLClassLoader +import scala.tools.nsc.util.Exceptional.unwrap +import scala.collection.{ mutable, immutable } +import scala.util.control.Exception.{ ultimately } +import SparkIMain._ +import java.util.concurrent.Future +import typechecker.Analyzer +import scala.language.implicitConversions +import scala.reflect.runtime.{ universe => ru } +import scala.reflect.{ ClassTag, classTag } +import scala.tools.reflect.StdRuntimeTags._ +import scala.util.control.ControlThrowable + +import org.apache.spark.{Logging, HttpServer, SecurityManager, SparkConf} +import org.apache.spark.util.Utils + +// /** directory to save .class files to */ +// private class ReplVirtualDirectory(out: JPrintWriter) extends VirtualDirectory("((memory))", None) { +// private def pp(root: AbstractFile, indentLevel: Int) { +// val spaces = " " * indentLevel +// out.println(spaces + root.name) +// if (root.isDirectory) +// root.toList sortBy (_.name) foreach (x => pp(x, indentLevel + 1)) +// } +// // print the contents hierarchically +// def show() = pp(this, 0) +// } + + /** An interpreter for Scala code. 
+ * + * The main public entry points are compile(), interpret(), and bind(). + * The compile() method loads a complete Scala file. The interpret() method + * executes one line of Scala code at the request of the user. The bind() + * method binds an object to a variable that can then be used by later + * interpreted code. + * + * The overall approach is based on compiling the requested code and then + * using a Java classloader and Java reflection to run the code + * and access its results. + * + * In more detail, a single compiler instance is used + * to accumulate all successfully compiled or interpreted Scala code. To + * "interpret" a line of code, the compiler generates a fresh object that + * includes the line of code and which has public member(s) to export + * all variables defined by that code. To extract the result of an + * interpreted line to show the user, a second "result object" is created + * which imports the variables exported by the above object and then + * exports members called "$eval" and "$print". To accomodate user expressions + * that read from variables or methods defined in previous statements, "import" + * statements are used. + * + * This interpreter shares the strengths and weaknesses of using the + * full compiler-to-Java. The main strength is that interpreted code + * behaves exactly as does compiled code, including running at full speed. + * The main weakness is that redefining classes and methods is not handled + * properly, because rebinding at the Java level is technically difficult. + * + * @author Moez A. Abdel-Gawad + * @author Lex Spoon + */ + class SparkIMain( + initialSettings: Settings, + val out: JPrintWriter, + propagateExceptions: Boolean = false) + extends SparkImports with Logging { imain => + + val conf = new SparkConf() + + val SPARK_DEBUG_REPL: Boolean = (System.getenv("SPARK_DEBUG_REPL") == "1") + /** Local directory to save .class files too */ + lazy val outputDir = { + val tmp = System.getProperty("java.io.tmpdir") + val rootDir = conf.get("spark.repl.classdir", tmp) + Utils.createTempDir(rootDir) + } + if (SPARK_DEBUG_REPL) { + echo("Output directory: " + outputDir) + } + + val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles + /** Jetty server that will serve our classes to worker nodes */ + val classServerPort = conf.getInt("spark.replClassServer.port", 0) + val classServer = new HttpServer(outputDir, new SecurityManager(conf), classServerPort, "HTTP class server") + private var currentSettings: Settings = initialSettings + var printResults = true // whether to print result lines + var totalSilence = false // whether to print anything + private var _initializeComplete = false // compiler is initialized + private var _isInitialized: Future[Boolean] = null // set up initialization future + private var bindExceptions = true // whether to bind the lastException variable + private var _executionWrapper = "" // code to be wrapped around all lines + + + // Start the classServer and store its URI in a spark system property + // (which will be passed to executors so that they can connect to it) + classServer.start() + if (SPARK_DEBUG_REPL) { + echo("Class server started, URI = " + classServer.uri) + } + + /** We're going to go to some trouble to initialize the compiler asynchronously. + * It's critical that nothing call into it until it's been initialized or we will + * run into unrecoverable issues, but the perceived repl startup time goes + * through the roof if we wait for it. 
So we initialize it with a future and + * use a lazy val to ensure that any attempt to use the compiler object waits + * on the future. + */ + private var _classLoader: AbstractFileClassLoader = null // active classloader + private val _compiler: Global = newCompiler(settings, reporter) // our private compiler + + private trait ExposeAddUrl extends URLClassLoader { def addNewUrl(url: URL) = this.addURL(url) } + private var _runtimeClassLoader: URLClassLoader with ExposeAddUrl = null // wrapper exposing addURL + + private val nextReqId = { + var counter = 0 + () => { counter += 1 ; counter } + } + + def compilerClasspath: Seq[URL] = ( + if (isInitializeComplete) global.classPath.asURLs + else new PathResolver(settings).result.asURLs // the compiler's classpath + ) + def settings = currentSettings + def mostRecentLine = prevRequestList match { + case Nil => "" + case req :: _ => req.originalLine + } + // Run the code body with the given boolean settings flipped to true. + def withoutWarnings[T](body: => T): T = beQuietDuring { + val saved = settings.nowarn.value + if (!saved) + settings.nowarn.value = true + + try body + finally if (!saved) settings.nowarn.value = false + } + + /** construct an interpreter that reports to Console */ + def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true)) + def this() = this(new Settings()) + + lazy val repllog: Logger = new Logger { + val out: JPrintWriter = imain.out + val isInfo: Boolean = BooleanProp keyExists "scala.repl.info" + val isDebug: Boolean = BooleanProp keyExists "scala.repl.debug" + val isTrace: Boolean = BooleanProp keyExists "scala.repl.trace" + } + lazy val formatting: Formatting = new Formatting { + val prompt = Properties.shellPromptString + } + lazy val reporter: ConsoleReporter = new SparkIMain.ReplReporter(this) + + import formatting._ + import reporter.{ printMessage, withoutTruncating } + + // This exists mostly because using the reporter too early leads to deadlock. + private def echo(msg: String) { Console println msg } + private def _initSources = List(new BatchSourceFile("", "class $repl_$init { }")) + private def _initialize() = { + try { + // todo. if this crashes, REPL will hang + new _compiler.Run() compileSources _initSources + _initializeComplete = true + true + } + catch AbstractOrMissingHandler() + } + private def tquoted(s: String) = "\"\"\"" + s + "\"\"\"" + + // argument is a thunk to execute after init is done + def initialize(postInitSignal: => Unit) { + synchronized { + if (_isInitialized == null) { + _isInitialized = io.spawn { + try _initialize() + finally postInitSignal + } + } + } + } + def initializeSynchronous(): Unit = { + if (!isInitializeComplete) { + _initialize() + assert(global != null, global) + } + } + def isInitializeComplete = _initializeComplete + + /** the public, go through the future compiler */ + lazy val global: Global = { + if (isInitializeComplete) _compiler + else { + // If init hasn't been called yet you're on your own. + if (_isInitialized == null) { + logWarning("Warning: compiler accessed before init set up. 
Assuming no postInit code.") + initialize(()) + } + // // blocks until it is ; false means catastrophic failure + if (_isInitialized.get()) _compiler + else null + } + } + @deprecated("Use `global` for access to the compiler instance.", "2.9.0") + lazy val compiler: global.type = global + + import global._ + import definitions.{ScalaPackage, JavaLangPackage, termMember, typeMember} + import rootMirror.{RootClass, getClassIfDefined, getModuleIfDefined, getRequiredModule, getRequiredClass} + + implicit class ReplTypeOps(tp: Type) { + def orElse(other: => Type): Type = if (tp ne NoType) tp else other + def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp) + } + + // TODO: If we try to make naming a lazy val, we run into big time + // scalac unhappiness with what look like cycles. It has not been easy to + // reduce, but name resolution clearly takes different paths. + object naming extends { + val global: imain.global.type = imain.global + } with Naming { + // make sure we don't overwrite their unwisely named res3 etc. + def freshUserTermName(): TermName = { + val name = newTermName(freshUserVarName()) + if (definedNameMap contains name) freshUserTermName() + else name + } + def isUserTermName(name: Name) = isUserVarName("" + name) + def isInternalTermName(name: Name) = isInternalVarName("" + name) + } + import naming._ + + object deconstruct extends { + val global: imain.global.type = imain.global + } with StructuredTypeStrings + + lazy val memberHandlers = new { + val intp: imain.type = imain + } with SparkMemberHandlers + import memberHandlers._ + + /** Temporarily be quiet */ + def beQuietDuring[T](body: => T): T = { + val saved = printResults + printResults = false + try body + finally printResults = saved + } + def beSilentDuring[T](operation: => T): T = { + val saved = totalSilence + totalSilence = true + try operation + finally totalSilence = saved + } + + def quietRun[T](code: String) = beQuietDuring(interpret(code)) + + + private def logAndDiscard[T](label: String, alt: => T): PartialFunction[Throwable, T] = { + case t: ControlThrowable => throw t + case t: Throwable => + logDebug(label + ": " + unwrap(t)) + logDebug(stackTraceString(unwrap(t))) + alt + } + /** takes AnyRef because it may be binding a Throwable or an Exceptional */ + + private def withLastExceptionLock[T](body: => T, alt: => T): T = { + assert(bindExceptions, "withLastExceptionLock called incorrectly.") + bindExceptions = false + + try beQuietDuring(body) + catch logAndDiscard("withLastExceptionLock", alt) + finally bindExceptions = true + } + + def executionWrapper = _executionWrapper + def setExecutionWrapper(code: String) = _executionWrapper = code + def clearExecutionWrapper() = _executionWrapper = "" + + /** interpreter settings */ + lazy val isettings = new SparkISettings(this) + + /** Instantiate a compiler. Overridable. */ + protected def newCompiler(settings: Settings, reporter: Reporter): ReplGlobal = { + settings.outputDirs setSingleOutput virtualDirectory + settings.exposeEmptyPackage.value = true + new Global(settings, reporter) with ReplGlobal { + override def toString: String = "" + } + } + + /** + * Adds any specified jars to the compile and runtime classpaths. 
+ * + * @note Currently only supports jars, not directories + * @param urls The list of items to add to the compile and runtime classpaths + */ + def addUrlsToClassPath(urls: URL*): Unit = { + new Run // Needed to force initialization of "something" to correctly load Scala classes from jars + urls.foreach(_runtimeClassLoader.addNewUrl) // Add jars/classes to runtime for execution + updateCompilerClassPath(urls: _*) // Add jars/classes to compile time for compiling + } + + protected def updateCompilerClassPath(urls: URL*): Unit = { + require(!global.forMSIL) // Only support JavaPlatform + + val platform = global.platform.asInstanceOf[JavaPlatform] + + val newClassPath = mergeUrlsIntoClassPath(platform, urls: _*) + + // NOTE: Must use reflection until this is exposed/fixed upstream in Scala + val fieldSetter = platform.getClass.getMethods + .find(_.getName.endsWith("currentClassPath_$eq")).get + fieldSetter.invoke(platform, Some(newClassPath)) + + // Reload all jars specified into our compiler + global.invalidateClassPathEntries(urls.map(_.getPath): _*) + } + + protected def mergeUrlsIntoClassPath(platform: JavaPlatform, urls: URL*): MergedClassPath[AbstractFile] = { + // Collect our new jars/directories and add them to the existing set of classpaths + val allClassPaths = ( + platform.classPath.asInstanceOf[MergedClassPath[AbstractFile]].entries ++ + urls.map(url => { + platform.classPath.context.newClassPath( + if (url.getProtocol == "file") { + val f = new File(url.getPath) + if (f.isDirectory) + io.AbstractFile.getDirectory(f) + else + io.AbstractFile.getFile(f) + } else { + io.AbstractFile.getURL(url) + } + ) + }) + ).distinct + + // Combine all of our classpaths (old and new) into one merged classpath + new MergedClassPath(allClassPaths, platform.classPath.context) + } + + /** Parent classloader. Overridable. */ + protected def parentClassLoader: ClassLoader = + SparkHelper.explicitParentLoader(settings).getOrElse( this.getClass.getClassLoader() ) + + /* A single class loader is used for all commands interpreted by this Interpreter. + It would also be possible to create a new class loader for each command + to interpret. The advantages of the current approach are: + + - Expressions are only evaluated one time. This is especially + significant for I/O, e.g. "val x = Console.readLine" + + The main disadvantage is: + + - Objects, classes, and methods cannot be rebound. Instead, definitions + shadow the old ones, and old code objects refer to the old + definitions. + */ + def resetClassLoader() = { + logDebug("Setting new classloader: was " + _classLoader) + _classLoader = null + ensureClassLoader() + } + final def ensureClassLoader() { + if (_classLoader == null) + _classLoader = makeClassLoader() + } + def classLoader: AbstractFileClassLoader = { + ensureClassLoader() + _classLoader + } + private class TranslatingClassLoader(parent: ClassLoader) extends AbstractFileClassLoader(virtualDirectory, parent) { + /** Overridden here to try translating a simple name to the generated + * class name if the original attempt fails. This method is used by + * getResourceAsStream as well as findClass. 
+ */ + override protected def findAbstractFile(name: String): AbstractFile = { + super.findAbstractFile(name) match { + // deadlocks on startup if we try to translate names too early + case null if isInitializeComplete => + generatedName(name) map (x => super.findAbstractFile(x)) orNull + case file => + file + } + } + } + private def makeClassLoader(): AbstractFileClassLoader = + new TranslatingClassLoader(parentClassLoader match { + case null => ScalaClassLoader fromURLs compilerClasspath + case p => + _runtimeClassLoader = new URLClassLoader(compilerClasspath, p) with ExposeAddUrl + _runtimeClassLoader + }) + + def getInterpreterClassLoader() = classLoader + + // Set the current Java "context" class loader to this interpreter's class loader + def setContextClassLoader() = classLoader.setAsContext() + + /** Given a simple repl-defined name, returns the real name of + * the class representing it, e.g. for "Bippy" it may return + * {{{ + * $line19.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$Bippy + * }}} + */ + def generatedName(simpleName: String): Option[String] = { + if (simpleName endsWith nme.MODULE_SUFFIX_STRING) optFlatName(simpleName.init) map (_ + nme.MODULE_SUFFIX_STRING) + else optFlatName(simpleName) + } + def flatName(id: String) = optFlatName(id) getOrElse id + def optFlatName(id: String) = requestForIdent(id) map (_ fullFlatName id) + + def allDefinedNames = definedNameMap.keys.toList.sorted + def pathToType(id: String): String = pathToName(newTypeName(id)) + def pathToTerm(id: String): String = pathToName(newTermName(id)) + def pathToName(name: Name): String = { + if (definedNameMap contains name) + definedNameMap(name) fullPath name + else name.toString + } + + /** Most recent tree handled which wasn't wholly synthetic. */ + private def mostRecentlyHandledTree: Option[Tree] = { + prevRequests.reverse foreach { req => + req.handlers.reverse foreach { + case x: MemberDefHandler if x.definesValue && !isInternalTermName(x.name) => return Some(x.member) + case _ => () + } + } + None + } + + /** Stubs for work in progress. */ + def handleTypeRedefinition(name: TypeName, old: Request, req: Request) = { + for (t1 <- old.simpleNameOfType(name) ; t2 <- req.simpleNameOfType(name)) { + logDebug("Redefining type '%s'\n %s -> %s".format(name, t1, t2)) + } + } + + def handleTermRedefinition(name: TermName, old: Request, req: Request) = { + for (t1 <- old.compilerTypeOf get name ; t2 <- req.compilerTypeOf get name) { + // Printing the types here has a tendency to cause assertion errors, like + // assertion failed: fatal: has owner value x, but a class owner is required + // so DBG is by-name now to keep it in the family. (It also traps the assertion error, + // but we don't want to unnecessarily risk hosing the compiler's internal state.) + logDebug("Redefining term '%s'\n %s -> %s".format(name, t1, t2)) + } + } + + def recordRequest(req: Request) { + if (req == null || referencedNameMap == null) + return + + prevRequests += req + req.referencedNames foreach (x => referencedNameMap(x) = req) + + // warning about serially defining companions. It'd be easy + // enough to just redefine them together but that may not always + // be what people want so I'm waiting until I can do it better. 
+ for { + name <- req.definedNames filterNot (x => req.definedNames contains x.companionName) + oldReq <- definedNameMap get name.companionName + newSym <- req.definedSymbols get name + oldSym <- oldReq.definedSymbols get name.companionName + if Seq(oldSym, newSym).permutations exists { case Seq(s1, s2) => s1.isClass && s2.isModule } + } { + afterTyper(replwarn(s"warning: previously defined $oldSym is not a companion to $newSym.")) + replwarn("Companions must be defined together; you may wish to use :paste mode for this.") + } + + // Updating the defined name map + req.definedNames foreach { name => + if (definedNameMap contains name) { + if (name.isTypeName) handleTypeRedefinition(name.toTypeName, definedNameMap(name), req) + else handleTermRedefinition(name.toTermName, definedNameMap(name), req) + } + definedNameMap(name) = req + } + } + + def replwarn(msg: => String) { + if (!settings.nowarnings.value) + printMessage(msg) + } + + def isParseable(line: String): Boolean = { + beSilentDuring { + try parse(line) match { + case Some(xs) => xs.nonEmpty // parses as-is + case None => true // incomplete + } + catch { case x: Exception => // crashed the compiler + replwarn("Exception in isParseable(\"" + line + "\"): " + x) + false + } + } + } + + def compileSourcesKeepingRun(sources: SourceFile*) = { + val run = new Run() + reporter.reset() + run compileSources sources.toList + (!reporter.hasErrors, run) + } + + /** Compile an nsc SourceFile. Returns true if there are + * no compilation errors, or false otherwise. + */ + def compileSources(sources: SourceFile*): Boolean = + compileSourcesKeepingRun(sources: _*)._1 + + /** Compile a string. Returns true if there are no + * compilation errors, or false otherwise. + */ + def compileString(code: String): Boolean = + compileSources(new BatchSourceFile("
<table>
  <tr><th>Output Operation</th><th>Meaning</th></tr>
  <tr>
    <td><b>print</b>()</td>
    <td>Prints first ten elements of every batch of data in a DStream on the driver.
    This is useful for development and debugging.</td>
  </tr>
  <tr>
    <td><b>saveAsObjectFiles</b>(<i>prefix</i>, [<i>suffix</i>])</td>
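
For reference, below is a minimal, self-contained sketch of the print() output operation described in the table row above, using the Python streaming API exercised by the new run_streaming_tests block. The local[2] master, the 2-second batch interval, and the queue-backed input are illustrative assumptions rather than anything prescribed by this patch; pprint() is the PySpark counterpart of the Scala/Java print() operation.

    import time

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    # Illustrative settings only: local master with two threads, 2-second batches.
    sc = SparkContext("local[2]", "PrintOutputOperationSketch")
    ssc = StreamingContext(sc, 2)

    # Feed a queue of pre-built RDDs through the stream so no external source
    # (socket, file directory, etc.) is needed for the example to run.
    rdds = [sc.parallelize(range(i * 10, (i + 1) * 10)) for i in range(3)]
    stream = ssc.queueStream(rdds)

    # pprint() is the PySpark counterpart of print(): it prints the first ten
    # elements of every batch of the DStream on the driver.
    stream.pprint()

    ssc.start()
    time.sleep(6)   # let a few batches be processed
    ssc.stop(stopSparkContext=True)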