diff --git a/README.md b/README.md index 0a683a460ffac..5b09ad86849e7 100644 --- a/README.md +++ b/README.md @@ -4,8 +4,8 @@ Spark is a fast and general cluster computing system for Big Data. It provides high-level APIs in Scala, Java, and Python, and an optimized engine that supports general computation graphs for data analysis. It also supports a rich set of higher-level tools including Spark SQL for SQL and structured -data processing, MLLib for machine learning, GraphX for graph processing, -and Spark Streaming. +data processing, MLlib for machine learning, GraphX for graph processing, +and Spark Streaming for stream processing. diff --git a/assembly/pom.xml b/assembly/pom.xml index de7b75258e3c5..4146168fc804b 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml diff --git a/bagel/pom.xml b/bagel/pom.xml index bd51b112e26fa..93db0d5efda5f 100644 --- a/bagel/pom.xml +++ b/bagel/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml diff --git a/bin/beeline b/bin/beeline index 1bda4dba50605..3fcb6df34339d 100755 --- a/bin/beeline +++ b/bin/beeline @@ -24,7 +24,7 @@ set -o posix # Figure out where Spark is installed -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" CLASS="org.apache.hive.beeline.BeeLine" exec "$FWDIR/bin/spark-class" $CLASS "$@" diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index 16b794a1592e8..15c6779402994 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -23,9 +23,9 @@ SCALA_VERSION=2.10 # Figure out where Spark is installed -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" -. $FWDIR/bin/load-spark-env.sh +. "$FWDIR"/bin/load-spark-env.sh # Build up classpath CLASSPATH="$SPARK_CLASSPATH:$SPARK_SUBMIT_CLASSPATH:$FWDIR/conf" @@ -63,7 +63,7 @@ else assembly_folder="$ASSEMBLY_DIR" fi -num_jars=$(ls "$assembly_folder" | grep "spark-assembly.*hadoop.*\.jar" | wc -l) +num_jars="$(ls "$assembly_folder" | grep "spark-assembly.*hadoop.*\.jar" | wc -l)" if [ "$num_jars" -eq "0" ]; then echo "Failed to find Spark assembly in $assembly_folder" echo "You need to build Spark before running this program." @@ -77,7 +77,7 @@ if [ "$num_jars" -gt "1" ]; then exit 1 fi -ASSEMBLY_JAR=$(ls "$assembly_folder"/spark-assembly*hadoop*.jar 2>/dev/null) +ASSEMBLY_JAR="$(ls "$assembly_folder"/spark-assembly*hadoop*.jar 2>/dev/null)" # Verify that versions of java used to build the jars and run Spark are compatible jar_error_check=$("$JAR_CMD" -tf "$ASSEMBLY_JAR" nonexistent/class/path 2>&1) @@ -103,8 +103,8 @@ else datanucleus_dir="$FWDIR"/lib_managed/jars fi -datanucleus_jars=$(find "$datanucleus_dir" 2>/dev/null | grep "datanucleus-.*\\.jar") -datanucleus_jars=$(echo "$datanucleus_jars" | tr "\n" : | sed s/:$//g) +datanucleus_jars="$(find "$datanucleus_dir" 2>/dev/null | grep "datanucleus-.*\\.jar")" +datanucleus_jars="$(echo "$datanucleus_jars" | tr "\n" : | sed s/:$//g)" if [ -n "$datanucleus_jars" ]; then hive_files=$("$JAR_CMD" -tf "$ASSEMBLY_JAR" org/apache/hadoop/hive/ql/exec 2>/dev/null) diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index 493d3785a081b..6d4231b204595 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -25,9 +25,9 @@ if [ -z "$SPARK_ENV_LOADED" ]; then export SPARK_ENV_LOADED=1 # Returns the parent of the directory this script lives in. - parent_dir="$(cd `dirname $0`/..; pwd)" + parent_dir="$(cd "`dirname "$0"`"/..; pwd)" - user_conf_dir=${SPARK_CONF_DIR:-"$parent_dir/conf"} + user_conf_dir="${SPARK_CONF_DIR:-"$parent_dir"/conf}" if [ -f "${user_conf_dir}/spark-env.sh" ]; then # Promote all variable declarations to environment (exported) variables diff --git a/bin/pyspark b/bin/pyspark index f553b314c5991..5142411e36974 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -18,18 +18,18 @@ # # Figure out where Spark is installed -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" # Export this as SPARK_HOME export SPARK_HOME="$FWDIR" -source $FWDIR/bin/utils.sh +source "$FWDIR/bin/utils.sh" SCALA_VERSION=2.10 function usage() { echo "Usage: ./bin/pyspark [options]" 1>&2 - $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 exit 0 } @@ -48,7 +48,7 @@ if [ ! -f "$FWDIR/RELEASE" ]; then fi fi -. $FWDIR/bin/load-spark-env.sh +. "$FWDIR"/bin/load-spark-env.sh # Figure out which Python executable to use if [[ -z "$PYSPARK_PYTHON" ]]; then @@ -57,12 +57,12 @@ fi export PYSPARK_PYTHON # Add the PySpark classes to the Python path: -export PYTHONPATH=$SPARK_HOME/python/:$PYTHONPATH -export PYTHONPATH=$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH +export PYTHONPATH="$SPARK_HOME/python/:$PYTHONPATH" +export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: -export OLD_PYTHONSTARTUP=$PYTHONSTARTUP -export PYTHONSTARTUP=$FWDIR/python/pyspark/shell.py +export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" +export PYTHONSTARTUP="$FWDIR/python/pyspark/shell.py" # If IPython options are specified, assume user wants to run IPython if [[ -n "$IPYTHON_OPTS" ]]; then @@ -85,6 +85,8 @@ export PYSPARK_SUBMIT_ARGS # For pyspark tests if [[ -n "$SPARK_TESTING" ]]; then + unset YARN_CONF_DIR + unset HADOOP_CONF_DIR if [[ -n "$PYSPARK_DOC_TEST" ]]; then exec "$PYSPARK_PYTHON" -m doctest $1 else @@ -97,10 +99,10 @@ fi if [[ "$1" =~ \.py$ ]]; then echo -e "\nWARNING: Running python applications through ./bin/pyspark is deprecated as of Spark 1.0." 1>&2 echo -e "Use ./bin/spark-submit \n" 1>&2 - primary=$1 + primary="$1" shift gatherSparkSubmitOpts "$@" - exec $FWDIR/bin/spark-submit "${SUBMISSION_OPTS[@]}" $primary "${APPLICATION_OPTS[@]}" + exec "$FWDIR"/bin/spark-submit "${SUBMISSION_OPTS[@]}" "$primary" "${APPLICATION_OPTS[@]}" else # PySpark shell requires special handling downstream export PYSPARK_SHELL=1 diff --git a/bin/run-example b/bin/run-example index 68a35702eddd3..34dd71c71880e 100755 --- a/bin/run-example +++ b/bin/run-example @@ -19,7 +19,7 @@ SCALA_VERSION=2.10 -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" export SPARK_HOME="$FWDIR" EXAMPLES_DIR="$FWDIR"/examples @@ -35,12 +35,12 @@ else fi if [ -f "$FWDIR/RELEASE" ]; then - export SPARK_EXAMPLES_JAR=`ls "$FWDIR"/lib/spark-examples-*hadoop*.jar` + export SPARK_EXAMPLES_JAR="`ls "$FWDIR"/lib/spark-examples-*hadoop*.jar`" elif [ -e "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/spark-examples-*hadoop*.jar ]; then - export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/spark-examples-*hadoop*.jar` + export SPARK_EXAMPLES_JAR="`ls "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/spark-examples-*hadoop*.jar`" fi -if [[ -z $SPARK_EXAMPLES_JAR ]]; then +if [[ -z "$SPARK_EXAMPLES_JAR" ]]; then echo "Failed to find Spark examples assembly in $FWDIR/lib or $FWDIR/examples/target" 1>&2 echo "You need to build Spark before running this program" 1>&2 exit 1 diff --git a/bin/spark-class b/bin/spark-class index c6543545a5e64..5f5f9ea74888d 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -27,12 +27,12 @@ esac SCALA_VERSION=2.10 # Figure out where Spark is installed -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" # Export this as SPARK_HOME export SPARK_HOME="$FWDIR" -. $FWDIR/bin/load-spark-env.sh +. "$FWDIR"/bin/load-spark-env.sh if [ -z "$1" ]; then echo "Usage: spark-class []" 1>&2 @@ -105,7 +105,7 @@ else exit 1 fi fi -JAVA_VERSION=$($RUNNER -version 2>&1 | sed 's/java version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q') +JAVA_VERSION=$("$RUNNER" -version 2>&1 | sed 's/java version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q') # Set JAVA_OPTS to be able to load native libraries and to set heap size if [ "$JAVA_VERSION" -ge 18 ]; then @@ -117,7 +117,7 @@ JAVA_OPTS="$JAVA_OPTS -Xms$OUR_JAVA_MEM -Xmx$OUR_JAVA_MEM" # Load extra JAVA_OPTS from conf/java-opts, if it exists if [ -e "$FWDIR/conf/java-opts" ] ; then - JAVA_OPTS="$JAVA_OPTS `cat $FWDIR/conf/java-opts`" + JAVA_OPTS="$JAVA_OPTS `cat "$FWDIR"/conf/java-opts`" fi # Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in CommandUtils.scala! @@ -126,21 +126,21 @@ TOOLS_DIR="$FWDIR"/tools SPARK_TOOLS_JAR="" if [ -e "$TOOLS_DIR"/target/scala-$SCALA_VERSION/spark-tools*[0-9Tg].jar ]; then # Use the JAR from the SBT build - export SPARK_TOOLS_JAR=`ls "$TOOLS_DIR"/target/scala-$SCALA_VERSION/spark-tools*[0-9Tg].jar` + export SPARK_TOOLS_JAR="`ls "$TOOLS_DIR"/target/scala-$SCALA_VERSION/spark-tools*[0-9Tg].jar`" fi if [ -e "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar ]; then # Use the JAR from the Maven build # TODO: this also needs to become an assembly! - export SPARK_TOOLS_JAR=`ls "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar` + export SPARK_TOOLS_JAR="`ls "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar`" fi # Compute classpath using external script -classpath_output=$($FWDIR/bin/compute-classpath.sh) +classpath_output=$("$FWDIR"/bin/compute-classpath.sh) if [[ "$?" != "0" ]]; then echo "$classpath_output" exit 1 else - CLASSPATH=$classpath_output + CLASSPATH="$classpath_output" fi if [[ "$1" =~ org.apache.spark.tools.* ]]; then @@ -153,9 +153,9 @@ if [[ "$1" =~ org.apache.spark.tools.* ]]; then fi if $cygwin; then - CLASSPATH=`cygpath -wp $CLASSPATH` + CLASSPATH="`cygpath -wp "$CLASSPATH"`" if [ "$1" == "org.apache.spark.tools.JavaAPICompletenessChecker" ]; then - export SPARK_TOOLS_JAR=`cygpath -w $SPARK_TOOLS_JAR` + export SPARK_TOOLS_JAR="`cygpath -w "$SPARK_TOOLS_JAR"`" fi fi export CLASSPATH diff --git a/bin/spark-shell b/bin/spark-shell index 0ab4e14f5b744..4a0670fc6c8aa 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -29,11 +29,11 @@ esac set -o posix ## Global script variables -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" function usage() { echo "Usage: ./bin/spark-shell [options]" - $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 exit 0 } @@ -41,7 +41,7 @@ if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then usage fi -source $FWDIR/bin/utils.sh +source "$FWDIR"/bin/utils.sh SUBMIT_USAGE_FUNCTION=usage gatherSparkSubmitOpts "$@" @@ -54,11 +54,11 @@ function main() { # (see https://github.com/sbt/sbt/issues/562). stty -icanon min 1 -echo > /dev/null 2>&1 export SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Djline.terminal=unix" - $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main "${SUBMISSION_OPTS[@]}" spark-shell "${APPLICATION_OPTS[@]}" + "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "${SUBMISSION_OPTS[@]}" spark-shell "${APPLICATION_OPTS[@]}" stty icanon echo > /dev/null 2>&1 else export SPARK_SUBMIT_OPTS - $FWDIR/bin/spark-submit --class org.apache.spark.repl.Main "${SUBMISSION_OPTS[@]}" spark-shell "${APPLICATION_OPTS[@]}" + "$FWDIR"/bin/spark-submit --class org.apache.spark.repl.Main "${SUBMISSION_OPTS[@]}" spark-shell "${APPLICATION_OPTS[@]}" fi } diff --git a/bin/spark-sql b/bin/spark-sql index 2a3cb31f58e8d..ae096530cad04 100755 --- a/bin/spark-sql +++ b/bin/spark-sql @@ -27,7 +27,7 @@ CLASS="org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" CLASS_NOT_FOUND_EXIT_STATUS=1 # Figure out where Spark is installed -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" function usage { echo "Usage: ./bin/spark-sql [options] [cli option]" @@ -38,10 +38,10 @@ function usage { pattern+="\|--help" pattern+="\|=======" - $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 echo echo "CLI options:" - $FWDIR/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 + "$FWDIR"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 } if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then @@ -49,7 +49,7 @@ if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then exit 0 fi -source $FWDIR/bin/utils.sh +source "$FWDIR"/bin/utils.sh SUBMIT_USAGE_FUNCTION=usage gatherSparkSubmitOpts "$@" diff --git a/bin/spark-submit b/bin/spark-submit index 277c4ce571ca2..c557311b4b20e 100755 --- a/bin/spark-submit +++ b/bin/spark-submit @@ -19,7 +19,7 @@ # NOTE: Any changes in this file must be reflected in SparkSubmitDriverBootstrapper.scala! -export SPARK_HOME="$(cd `dirname $0`/..; pwd)" +export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" ORIG_ARGS=("$@") while (($#)); do @@ -59,5 +59,5 @@ if [[ "$SPARK_SUBMIT_DEPLOY_MODE" == "client" && -f "$SPARK_SUBMIT_PROPERTIES_FI fi fi -exec $SPARK_HOME/bin/spark-class org.apache.spark.deploy.SparkSubmit "${ORIG_ARGS[@]}" +exec "$SPARK_HOME"/bin/spark-class org.apache.spark.deploy.SparkSubmit "${ORIG_ARGS[@]}" diff --git a/core/pom.xml b/core/pom.xml index 55bfe0b841ea4..b2b788a4bc13b 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index cb4fb7cfbd32f..24d1a8f9eceae 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -49,6 +49,7 @@ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkD import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage._ +import org.apache.spark.SPARK_VERSION import org.apache.spark.ui.SparkUI import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils} @@ -825,7 +826,7 @@ class SparkContext(config: SparkConf) extends Logging { } /** The version of Spark on which this application is running. */ - def version = SparkContext.SPARK_VERSION + def version = SPARK_VERSION /** * Return a map from the slave to the max memory available for caching and the remaining @@ -1261,7 +1262,10 @@ class SparkContext(config: SparkConf) extends Logging { /** Post the application start event */ private def postApplicationStart() { - listenerBus.post(SparkListenerApplicationStart(appName, startTime, sparkUser)) + // Note: this code assumes that the task scheduler has been initialized and has contacted + // the cluster manager to get an application ID (in case the cluster manager provides one). + listenerBus.post(SparkListenerApplicationStart(appName, taskScheduler.applicationId(), + startTime, sparkUser)) } /** Post the application end event */ @@ -1294,8 +1298,6 @@ class SparkContext(config: SparkConf) extends Logging { */ object SparkContext extends Logging { - private[spark] val SPARK_VERSION = "1.0.0" - private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description" private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id" diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 1642a2f8140c6..dd95e406f2a8e 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -220,7 +220,7 @@ object SparkEnv extends Logging { val shortShuffleMgrNames = Map( "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager") - val shuffleMgrName = conf.get("spark.shuffle.manager", "hash") + val shuffleMgrName = conf.get("spark.shuffle.manager", "sort") val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) @@ -230,7 +230,7 @@ object SparkEnv extends Logging { val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", - new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf) + new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver) val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer, conf, mapOutputTracker, shuffleManager, blockTransferService) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index feeb6c02caa78..880f61c49726e 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -758,6 +758,32 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) rdd.saveAsHadoopDataset(conf) } + /** + * Repartition the RDD according to the given partitioner and, within each resulting partition, + * sort records by their keys. + * + * This is more efficient than calling `repartition` and then sorting within each partition + * because it can push the sorting down into the shuffle machinery. + */ + def repartitionAndSortWithinPartitions(partitioner: Partitioner): JavaPairRDD[K, V] = { + val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]] + repartitionAndSortWithinPartitions(partitioner, comp) + } + + /** + * Repartition the RDD according to the given partitioner and, within each resulting partition, + * sort records by their keys. + * + * This is more efficient than calling `repartition` and then sorting within each partition + * because it can push the sorting down into the shuffle machinery. + */ + def repartitionAndSortWithinPartitions(partitioner: Partitioner, comp: Comparator[K]) + : JavaPairRDD[K, V] = { + implicit val ordering = comp // Allow implicit conversion of Comparator to Ordering. + fromRDD( + new OrderedRDDFunctions[K, V, (K, V)](rdd).repartitionAndSortWithinPartitions(partitioner)) + } + /** * Sort the RDD by key, so that each partition contains a sorted range of the elements in * ascending order. Calling `collect` or `save` on the resulting RDD will return or output an diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index a0e8bd403a41d..fbe39b27649f6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -34,15 +34,15 @@ private[spark] abstract class ApplicationHistoryProvider { * * @return List of all know applications. */ - def getListing(): Seq[ApplicationHistoryInfo] + def getListing(): Iterable[ApplicationHistoryInfo] /** * Returns the Spark UI for a specific application. * * @param appId The application ID. - * @return The application's UI, or null if application is not found. + * @return The application's UI, or None if application is not found. */ - def getAppUI(appId: String): SparkUI + def getAppUI(appId: String): Option[SparkUI] /** * Called when the server is shutting down. diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 05c8a90782c74..481f6c93c6a8d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -32,6 +32,8 @@ import org.apache.spark.util.Utils private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHistoryProvider with Logging { + private val NOT_STARTED = "" + // Interval between each check for event log updates private val UPDATE_INTERVAL_MS = conf.getInt("spark.history.fs.updateInterval", conf.getInt("spark.history.updateInterval", 10)) * 1000 @@ -47,8 +49,15 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis // A timestamp of when the disk was last accessed to check for log updates private var lastLogCheckTimeMs = -1L - // List of applications, in order from newest to oldest. - @volatile private var appList: Seq[ApplicationHistoryInfo] = Nil + // The modification time of the newest log detected during the last scan. This is used + // to ignore logs that are older during subsequent scans, to avoid processing data that + // is already known. + private var lastModifiedTime = -1L + + // Mapping of application IDs to their metadata, in descending end time order. Apps are inserted + // into the map in order, so the LinkedHashMap maintains the correct ordering. + @volatile private var applications: mutable.LinkedHashMap[String, FsApplicationHistoryInfo] + = new mutable.LinkedHashMap() /** * A background thread that periodically checks for event log updates on disk. @@ -93,15 +102,35 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis logCheckingThread.start() } - override def getListing() = appList + override def getListing() = applications.values - override def getAppUI(appId: String): SparkUI = { + override def getAppUI(appId: String): Option[SparkUI] = { try { - val appLogDir = fs.getFileStatus(new Path(resolvedLogDir.toString, appId)) - val (_, ui) = loadAppInfo(appLogDir, renderUI = true) - ui + applications.get(appId).map { info => + val (replayBus, appListener) = createReplayBus(fs.getFileStatus( + new Path(logDir, info.logDir))) + val ui = { + val conf = this.conf.clone() + val appSecManager = new SecurityManager(conf) + new SparkUI(conf, appSecManager, replayBus, appId, + s"${HistoryServer.UI_PATH_PREFIX}/$appId") + // Do not call ui.bind() to avoid creating a new server for each application + } + + replayBus.replay() + + ui.setAppName(s"${appListener.appName.getOrElse(NOT_STARTED)} ($appId)") + + val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) + ui.getSecurityManager.setAcls(uiAclsEnabled) + // make sure to set admin acls before view acls so they are properly picked up + ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse("")) + ui.getSecurityManager.setViewAcls(appListener.sparkUser.getOrElse(NOT_STARTED), + appListener.viewAcls.getOrElse("")) + ui + } } catch { - case e: FileNotFoundException => null + case e: FileNotFoundException => None } } @@ -119,84 +148,79 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis try { val logStatus = fs.listStatus(new Path(resolvedLogDir)) val logDirs = if (logStatus != null) logStatus.filter(_.isDir).toSeq else Seq[FileStatus]() - val logInfos = logDirs.filter { dir => - fs.isFile(new Path(dir.getPath, EventLoggingListener.APPLICATION_COMPLETE)) - } - val currentApps = Map[String, ApplicationHistoryInfo]( - appList.map(app => app.id -> app):_*) - - // For any application that either (i) is not listed or (ii) has changed since the last time - // the listing was created (defined by the log dir's modification time), load the app's info. - // Otherwise just reuse what's already in memory. - val newApps = new mutable.ArrayBuffer[ApplicationHistoryInfo](logInfos.size) - for (dir <- logInfos) { - val curr = currentApps.getOrElse(dir.getPath().getName(), null) - if (curr == null || curr.lastUpdated < getModificationTime(dir)) { + // Load all new logs from the log directory. Only directories that have a modification time + // later than the last known log directory will be loaded. + var newLastModifiedTime = lastModifiedTime + val logInfos = logDirs + .filter { dir => + if (fs.isFile(new Path(dir.getPath(), EventLoggingListener.APPLICATION_COMPLETE))) { + val modTime = getModificationTime(dir) + newLastModifiedTime = math.max(newLastModifiedTime, modTime) + modTime > lastModifiedTime + } else { + false + } + } + .flatMap { dir => try { - val (app, _) = loadAppInfo(dir, renderUI = false) - newApps += app + val (replayBus, appListener) = createReplayBus(dir) + replayBus.replay() + Some(new FsApplicationHistoryInfo( + dir.getPath().getName(), + appListener.appId.getOrElse(dir.getPath().getName()), + appListener.appName.getOrElse(NOT_STARTED), + appListener.startTime.getOrElse(-1L), + appListener.endTime.getOrElse(-1L), + getModificationTime(dir), + appListener.sparkUser.getOrElse(NOT_STARTED))) } catch { - case e: Exception => logError(s"Failed to load app info from directory $dir.") + case e: Exception => + logInfo(s"Failed to load application log data from $dir.", e) + None + } + } + .sortBy { info => -info.endTime } + + lastModifiedTime = newLastModifiedTime + + // When there are new logs, merge the new list with the existing one, maintaining + // the expected ordering (descending end time). Maintaining the order is important + // to avoid having to sort the list every time there is a request for the log list. + if (!logInfos.isEmpty) { + val newApps = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]() + def addIfAbsent(info: FsApplicationHistoryInfo) = { + if (!newApps.contains(info.id)) { + newApps += (info.id -> info) } - } else { - newApps += curr } - } - appList = newApps.sortBy { info => -info.endTime } + val newIterator = logInfos.iterator.buffered + val oldIterator = applications.values.iterator.buffered + while (newIterator.hasNext && oldIterator.hasNext) { + if (newIterator.head.endTime > oldIterator.head.endTime) { + addIfAbsent(newIterator.next) + } else { + addIfAbsent(oldIterator.next) + } + } + newIterator.foreach(addIfAbsent) + oldIterator.foreach(addIfAbsent) + + applications = newApps + } } catch { case t: Throwable => logError("Exception in checking for event log updates", t) } } - /** - * Parse the application's logs to find out the information we need to build the - * listing page. - * - * When creating the listing of available apps, there is no need to load the whole UI for the - * application. The UI is requested by the HistoryServer (by calling getAppInfo()) when the user - * clicks on a specific application. - * - * @param logDir Directory with application's log files. - * @param renderUI Whether to create the SparkUI for the application. - * @return A 2-tuple `(app info, ui)`. `ui` will be null if `renderUI` is false. - */ - private def loadAppInfo(logDir: FileStatus, renderUI: Boolean) = { - val path = logDir.getPath - val appId = path.getName + private def createReplayBus(logDir: FileStatus): (ReplayListenerBus, ApplicationEventListener) = { + val path = logDir.getPath() val elogInfo = EventLoggingListener.parseLoggingInfo(path, fs) val replayBus = new ReplayListenerBus(elogInfo.logPaths, fs, elogInfo.compressionCodec) val appListener = new ApplicationEventListener replayBus.addListener(appListener) - - val ui: SparkUI = if (renderUI) { - val conf = this.conf.clone() - val appSecManager = new SecurityManager(conf) - new SparkUI(conf, appSecManager, replayBus, appId, - HistoryServer.UI_PATH_PREFIX + s"/$appId") - // Do not call ui.bind() to avoid creating a new server for each application - } else { - null - } - - replayBus.replay() - val appInfo = ApplicationHistoryInfo( - appId, - appListener.appName, - appListener.startTime, - appListener.endTime, - getModificationTime(logDir), - appListener.sparkUser) - - if (ui != null) { - val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) - ui.getSecurityManager.setAcls(uiAclsEnabled) - // make sure to set admin acls before view acls so properly picked up - ui.getSecurityManager.setAdminAcls(appListener.adminAcls) - ui.getSecurityManager.setViewAcls(appListener.sparkUser, appListener.viewAcls) - } - (appInfo, ui) + (replayBus, appListener) } /** Return when this directory was last modified. */ @@ -219,3 +243,13 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis private def getMonotonicTimeMs() = System.nanoTime() / (1000 * 1000) } + +private class FsApplicationHistoryInfo( + val logDir: String, + id: String, + name: String, + startTime: Long, + endTime: Long, + lastUpdated: Long, + sparkUser: String) + extends ApplicationHistoryInfo(id, name, startTime, endTime, lastUpdated, sparkUser) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index d1a64c1912cb8..ce00c0ffd21e0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -52,10 +52,7 @@ class HistoryServer( private val appLoader = new CacheLoader[String, SparkUI] { override def load(key: String): SparkUI = { - val ui = provider.getAppUI(key) - if (ui == null) { - throw new NoSuchElementException() - } + val ui = provider.getAppUI(key).getOrElse(throw new NoSuchElementException()) attachSparkUI(ui) ui } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index d7d19f6fa3b96..dd903dc65d204 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -123,6 +123,9 @@ private[spark] class Executor( env.metricsSystem.report() isStopped = true threadPool.shutdown() + if (!isLocal) { + env.stop() + } } class TaskRunner( diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 5cdbc306e56a0..e2fc9c649925e 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -44,4 +44,5 @@ package org.apache package object spark { // For package docs only + val SPARK_VERSION = "1.2.0-SNAPSHOT" } diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index e98bad2026e32..d0dbfef35d03c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -19,7 +19,7 @@ package org.apache.spark.rdd import scala.reflect.ClassTag -import org.apache.spark.{Logging, RangePartitioner} +import org.apache.spark.{Logging, Partitioner, RangePartitioner} import org.apache.spark.annotation.DeveloperApi /** @@ -64,4 +64,16 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, new ShuffledRDD[K, V, V](self, part) .setKeyOrdering(if (ascending) ordering else ordering.reverse) } + + /** + * Repartition the RDD according to the given partitioner and, within each resulting partition, + * sort records by their keys. + * + * This is more efficient than calling `repartition` and then sorting within each partition + * because it can push the sorting down into the shuffle machinery. + */ + def repartitionAndSortWithinPartitions(partitioner: Partitioner): RDD[(K, V)] = { + new ShuffledRDD[K, V, V](self, partitioner).setKeyOrdering(ordering) + } + } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index daea2617e62ea..a9b905b0d1a63 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -993,7 +993,7 @@ abstract class RDD[T: ClassTag]( */ @Experimental def countApproxDistinct(p: Int, sp: Int): Long = { - require(p >= 4, s"p ($p) must be greater than 0") + require(p >= 4, s"p ($p) must be at least 4") require(sp <= 32, s"sp ($sp) cannot be greater than 32") require(sp == 0 || p <= sp, s"p ($p) cannot be greater than sp ($sp)") val zeroCounter = new HyperLogLogPlus(p, sp) @@ -1064,11 +1064,10 @@ abstract class RDD[T: ClassTag]( // greater than totalParts because we actually cap it at totalParts in runJob. var numPartsToTry = 1 if (partsScanned > 0) { - // If we didn't find any rows after the first iteration, just try all partitions next. - // Otherwise, interpolate the number of partitions we need to try, but overestimate it - // by 50%. + // If we didn't find any rows after the previous iteration, quadruple and retry. Otherwise, + // interpolate the number of partitions we need to try, but overestimate it by 50%. if (buf.size == 0) { - numPartsToTry = totalParts - 1 + numPartsToTry = partsScanned * 4 } else { numPartsToTry = (1.5 * num * partsScanned / buf.size).toInt } @@ -1128,15 +1127,19 @@ abstract class RDD[T: ClassTag]( * @return an array of top elements */ def takeOrdered(num: Int)(implicit ord: Ordering[T]): Array[T] = { - mapPartitions { items => - // Priority keeps the largest elements, so let's reverse the ordering. - val queue = new BoundedPriorityQueue[T](num)(ord.reverse) - queue ++= util.collection.Utils.takeOrdered(items, num)(ord) - Iterator.single(queue) - }.reduce { (queue1, queue2) => - queue1 ++= queue2 - queue1 - }.toArray.sorted(ord) + if (num == 0) { + Array.empty + } else { + mapPartitions { items => + // Priority keeps the largest elements, so let's reverse the ordering. + val queue = new BoundedPriorityQueue[T](num)(ord.reverse) + queue ++= util.collection.Utils.takeOrdered(items, num)(ord) + Iterator.single(queue) + }.reduce { (queue1, queue2) => + queue1 ++= queue2 + queue1 + }.toArray.sorted(ord) + } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala index 162158babc35b..6d39a5e3fa64c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala @@ -24,38 +24,31 @@ package org.apache.spark.scheduler * from multiple applications are seen, the behavior is unspecified. */ private[spark] class ApplicationEventListener extends SparkListener { - var appName = "" - var sparkUser = "" - var startTime = -1L - var endTime = -1L - var viewAcls = "" - var adminAcls = "" - - def applicationStarted = startTime != -1 - - def applicationCompleted = endTime != -1 - - def applicationDuration: Long = { - val difference = endTime - startTime - if (applicationStarted && applicationCompleted && difference > 0) difference else -1L - } + var appName: Option[String] = None + var appId: Option[String] = None + var sparkUser: Option[String] = None + var startTime: Option[Long] = None + var endTime: Option[Long] = None + var viewAcls: Option[String] = None + var adminAcls: Option[String] = None override def onApplicationStart(applicationStart: SparkListenerApplicationStart) { - appName = applicationStart.appName - startTime = applicationStart.time - sparkUser = applicationStart.sparkUser + appName = Some(applicationStart.appName) + appId = applicationStart.appId + startTime = Some(applicationStart.time) + sparkUser = Some(applicationStart.sparkUser) } override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd) { - endTime = applicationEnd.time + endTime = Some(applicationEnd.time) } override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { synchronized { val environmentDetails = environmentUpdate.environmentDetails val allProperties = environmentDetails("Spark Properties").toMap - viewAcls = allProperties.getOrElse("spark.ui.view.acls", "") - adminAcls = allProperties.getOrElse("spark.admin.acls", "") + viewAcls = allProperties.get("spark.ui.view.acls") + adminAcls = allProperties.get("spark.admin.acls") } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 2ccc27324ac8c..6fcf9e31543ed 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -241,9 +241,9 @@ class DAGScheduler( callSite: CallSite) : Stage = { + val parentStages = getParentStages(rdd, jobId) val id = nextStageId.getAndIncrement() - val stage = - new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite) + val stage = new Stage(id, rdd, numTasks, shuffleDep, parentStages, jobId, callSite) stageIdToStage(id) = stage updateJobIdStageIdMaps(jobId, stage) stage diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 4b99f630440ad..64b32ae0edaac 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -29,6 +29,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec +import org.apache.spark.SPARK_VERSION import org.apache.spark.util.{FileLogger, JsonProtocol, Utils} /** @@ -86,7 +87,7 @@ private[spark] class EventLoggingListener( sparkConf.get("spark.io.compression.codec", CompressionCodec.DEFAULT_COMPRESSION_CODEC) logger.newFile(COMPRESSION_CODEC_PREFIX + codec) } - logger.newFile(SPARK_VERSION_PREFIX + SparkContext.SPARK_VERSION) + logger.newFile(SPARK_VERSION_PREFIX + SPARK_VERSION) logger.newFile(LOG_PREFIX + logger.fileIndex) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala index e41e0a9841691..a0be8307eff27 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala @@ -31,4 +31,12 @@ private[spark] trait SchedulerBackend { def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = throw new UnsupportedOperationException def isReady(): Boolean = true + + /** + * The application ID associated with the job, if any. + * + * @return The application ID, or None if the backend does not provide an ID. + */ + def applicationId(): Option[String] = None + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 86ca8445a1124..86afe3bd5265f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -67,11 +67,11 @@ case class SparkListenerEnvironmentUpdate(environmentDetails: Map[String, Seq[(S extends SparkListenerEvent @DeveloperApi -case class SparkListenerBlockManagerAdded(blockManagerId: BlockManagerId, maxMem: Long) +case class SparkListenerBlockManagerAdded(time: Long, blockManagerId: BlockManagerId, maxMem: Long) extends SparkListenerEvent @DeveloperApi -case class SparkListenerBlockManagerRemoved(blockManagerId: BlockManagerId) +case class SparkListenerBlockManagerRemoved(time: Long, blockManagerId: BlockManagerId) extends SparkListenerEvent @DeveloperApi @@ -89,8 +89,8 @@ case class SparkListenerExecutorMetricsUpdate( extends SparkListenerEvent @DeveloperApi -case class SparkListenerApplicationStart(appName: String, time: Long, sparkUser: String) - extends SparkListenerEvent +case class SparkListenerApplicationStart(appName: String, appId: Option[String], time: Long, + sparkUser: String) extends SparkListenerEvent @DeveloperApi case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 1a0b877c8a5e1..1c1ce666eab0f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -64,4 +64,12 @@ private[spark] trait TaskScheduler { */ def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)], blockManagerId: BlockManagerId): Boolean + + /** + * The application ID associated with the job, if any. + * + * @return The application ID, or None if the backend does not provide an ID. + */ + def applicationId(): Option[String] = None + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index ad051e59af86d..633e892554c50 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -491,6 +491,9 @@ private[spark] class TaskSchedulerImpl( } } } + + override def applicationId(): Option[String] = backend.applicationId() + } @@ -535,4 +538,5 @@ private[spark] object TaskSchedulerImpl { retval.toList } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 2a3711ae2a78c..5b5257269d92f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -51,12 +51,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A val conf = scheduler.sc.conf private val timeout = AkkaUtils.askTimeout(conf) private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) - // Submit tasks only after (registered resources / total expected resources) + // Submit tasks only after (registered resources / total expected resources) // is equal to at least this value, that is double between 0 and 1. var minRegisteredRatio = math.min(1, conf.getDouble("spark.scheduler.minRegisteredResourcesRatio", 0)) // Submit tasks after maxRegisteredWaitingTime milliseconds - // if minRegisteredRatio has not yet been reached + // if minRegisteredRatio has not yet been reached val maxRegisteredWaitingTime = conf.getInt("spark.scheduler.maxRegisteredResourcesWaitingTime", 30000) val createTime = System.currentTimeMillis() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index bc7670f4a804d..513d74a08a47f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -69,4 +69,5 @@ private[spark] class SimrSchedulerBackend( fs.delete(new Path(driverFilePath), false) super.stop() } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 32138e5246700..06872ace2ecf4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -34,6 +34,10 @@ private[spark] class SparkDeploySchedulerBackend( var client: AppClient = null var stopping = false var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _ + var appId: String = _ + + val registrationLock = new Object() + var registrationDone = false val maxCores = conf.getOption("spark.cores.max").map(_.toInt) val totalExpectedCores = maxCores.getOrElse(0) @@ -68,6 +72,8 @@ private[spark] class SparkDeploySchedulerBackend( client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) client.start() + + waitForRegistration() } override def stop() { @@ -81,15 +87,19 @@ private[spark] class SparkDeploySchedulerBackend( override def connected(appId: String) { logInfo("Connected to Spark cluster with app ID " + appId) + this.appId = appId + notifyContext() } override def disconnected() { + notifyContext() if (!stopping) { logWarning("Disconnected from Spark cluster! Waiting for reconnection...") } } override def dead(reason: String) { + notifyContext() if (!stopping) { logError("Application has been killed. Reason: " + reason) scheduler.error(reason) @@ -116,4 +126,22 @@ private[spark] class SparkDeploySchedulerBackend( override def sufficientResourcesRegistered(): Boolean = { totalCoreCount.get() >= totalExpectedCores * minRegisteredRatio } + + override def applicationId(): Option[String] = Option(appId) + + private def waitForRegistration() = { + registrationLock.synchronized { + while (!registrationDone) { + registrationLock.wait() + } + } + } + + private def notifyContext() = { + registrationLock.synchronized { + registrationDone = true + registrationLock.notifyAll() + } + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 87e181e773fdf..64568409dbafd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -71,11 +71,6 @@ private[spark] class CoarseMesosSchedulerBackend( val taskIdToSlaveId = new HashMap[Int, String] val failuresBySlaveId = new HashMap[String, Int] // How many times tasks on each slave failed - val executorSparkHome = conf.getOption("spark.mesos.executor.home") - .orElse(sc.getSparkHome()) - .getOrElse { - throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") - } val extraCoresPerSlave = conf.getInt("spark.mesos.extra.cores", 0) @@ -112,6 +107,11 @@ private[spark] class CoarseMesosSchedulerBackend( } def createCommand(offer: Offer, numCores: Int): CommandInfo = { + val executorSparkHome = conf.getOption("spark.mesos.executor.home") + .orElse(sc.getSparkHome()) + .getOrElse { + throw new SparkException("Executor Spark home `spark.mesos.executor.home` is not set!") + } val environment = Environment.newBuilder() val extraClassPath = conf.getOption("spark.executor.extraClassPath") extraClassPath.foreach { cp => @@ -309,4 +309,5 @@ private[spark] class CoarseMesosSchedulerBackend( logInfo("Executor lost: %s, marking slave %s as lost".format(e.getValue, s.getValue)) slaveLost(d, s) } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 67ee4d66f151b..a9ef126f5de0e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -349,4 +349,5 @@ private[spark] class MesosSchedulerBackend( // TODO: query Mesos for number of cores override def defaultParallelism() = sc.conf.getInt("spark.default.parallelism", 8) + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index bec9502f20466..9ea25c2bc7090 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -114,4 +114,5 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { localActor ! StatusUpdate(taskId, state, serializedData) } + } diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala index 292ac0d663665..439981d232349 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala @@ -156,7 +156,7 @@ class FileShuffleBlockManager(conf: SparkConf) val filename = physicalFileName(shuffleId, bucketId, fileId) blockManager.diskBlockManager.getFile(filename) } - val fileGroup = new ShuffleFileGroup(fileId, shuffleId, files) + val fileGroup = new ShuffleFileGroup(shuffleId, fileId, files) shuffleState.allFileGroups.add(fileGroup) fileGroup } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index e67b3dc5ce02e..2e262594b3538 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -27,7 +27,11 @@ import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.AkkaUtils private[spark] -class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Logging { +class BlockManagerMaster( + var driverActor: ActorRef, + conf: SparkConf, + isDriver: Boolean) + extends Logging { private val AKKA_RETRY_ATTEMPTS: Int = AkkaUtils.numRetries(conf) private val AKKA_RETRY_INTERVAL_MS: Int = AkkaUtils.retryWaitMs(conf) @@ -196,7 +200,7 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log /** Stop the driver actor, called only on the Spark driver node */ def stop() { - if (driverActor != null) { + if (driverActor != null && isDriver) { tell(StopBlockManagerMaster) driverActor = null logInfo("BlockManagerMaster stopped") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 3ab07703b6f85..1a6c7cb24f9ac 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -203,7 +203,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus blockLocations.remove(blockId) } } - listenerBus.post(SparkListenerBlockManagerRemoved(blockManagerId)) + listenerBus.post(SparkListenerBlockManagerRemoved(System.currentTimeMillis(), blockManagerId)) } private def expireDeadHosts() { @@ -325,6 +325,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus } private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + val time = System.currentTimeMillis() if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { case Some(manager) => @@ -340,9 +341,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus id.hostPort, Utils.bytesToString(maxMemSize))) blockManagerInfo(id) = - new BlockManagerInfo(id, System.currentTimeMillis(), maxMemSize, slaveActor) + new BlockManagerInfo(id, time, maxMemSize, slaveActor) } - listenerBus.post(SparkListenerBlockManagerAdded(id, maxMemSize)) + listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize)) } private def updateBlockInfo( diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index bee6dad3387e5..f0006b42aee4f 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -232,7 +232,7 @@ private[spark] object UIUtils extends Logging { def listingTable[T]( headers: Seq[String], generateDataRow: T => Seq[Node], - data: Seq[T], + data: Iterable[T], fixedWidth: Boolean = false): Seq[Node] = { var listingTableClass = TABLE_CLASS diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index a7543454eca1f..b0754e3ce10db 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -152,13 +152,15 @@ private[spark] object JsonProtocol { val blockManagerId = blockManagerIdToJson(blockManagerAdded.blockManagerId) ("Event" -> Utils.getFormattedClassName(blockManagerAdded)) ~ ("Block Manager ID" -> blockManagerId) ~ - ("Maximum Memory" -> blockManagerAdded.maxMem) + ("Maximum Memory" -> blockManagerAdded.maxMem) ~ + ("Timestamp" -> blockManagerAdded.time) } def blockManagerRemovedToJson(blockManagerRemoved: SparkListenerBlockManagerRemoved): JValue = { val blockManagerId = blockManagerIdToJson(blockManagerRemoved.blockManagerId) ("Event" -> Utils.getFormattedClassName(blockManagerRemoved)) ~ - ("Block Manager ID" -> blockManagerId) + ("Block Manager ID" -> blockManagerId) ~ + ("Timestamp" -> blockManagerRemoved.time) } def unpersistRDDToJson(unpersistRDD: SparkListenerUnpersistRDD): JValue = { @@ -169,6 +171,7 @@ private[spark] object JsonProtocol { def applicationStartToJson(applicationStart: SparkListenerApplicationStart): JValue = { ("Event" -> Utils.getFormattedClassName(applicationStart)) ~ ("App Name" -> applicationStart.appName) ~ + ("App ID" -> applicationStart.appId.map(JString(_)).getOrElse(JNothing)) ~ ("Timestamp" -> applicationStart.time) ~ ("User" -> applicationStart.sparkUser) } @@ -466,12 +469,14 @@ private[spark] object JsonProtocol { def blockManagerAddedFromJson(json: JValue): SparkListenerBlockManagerAdded = { val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID") val maxMem = (json \ "Maximum Memory").extract[Long] - SparkListenerBlockManagerAdded(blockManagerId, maxMem) + val time = Utils.jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) + SparkListenerBlockManagerAdded(time, blockManagerId, maxMem) } def blockManagerRemovedFromJson(json: JValue): SparkListenerBlockManagerRemoved = { val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID") - SparkListenerBlockManagerRemoved(blockManagerId) + val time = Utils.jsonOption(json \ "Timestamp").map(_.extract[Long]).getOrElse(-1L) + SparkListenerBlockManagerRemoved(time, blockManagerId) } def unpersistRDDFromJson(json: JValue): SparkListenerUnpersistRDD = { @@ -480,9 +485,10 @@ private[spark] object JsonProtocol { def applicationStartFromJson(json: JValue): SparkListenerApplicationStart = { val appName = (json \ "App Name").extract[String] + val appId = Utils.jsonOption(json \ "App ID").map(_.extract[String]) val time = (json \ "Timestamp").extract[Long] val sparkUser = (json \ "User").extract[String] - SparkListenerApplicationStart(appName, time, sparkUser) + SparkListenerApplicationStart(appName, appId, time, sparkUser) } def applicationEndFromJson(json: JValue): SparkListenerApplicationEnd = { diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index e1c13de04a0be..be99dc501c4b2 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -189,6 +189,36 @@ public void sortByKey() { Assert.assertEquals(new Tuple2(3, 2), sortedPairs.get(2)); } + @Test + public void repartitionAndSortWithinPartitions() { + List> pairs = new ArrayList>(); + pairs.add(new Tuple2(0, 5)); + pairs.add(new Tuple2(3, 8)); + pairs.add(new Tuple2(2, 6)); + pairs.add(new Tuple2(0, 8)); + pairs.add(new Tuple2(3, 8)); + pairs.add(new Tuple2(1, 3)); + + JavaPairRDD rdd = sc.parallelizePairs(pairs); + + Partitioner partitioner = new Partitioner() { + public int numPartitions() { + return 2; + } + public int getPartition(Object key) { + return ((Integer)key).intValue() % 2; + } + }; + + JavaPairRDD repartitioned = + rdd.repartitionAndSortWithinPartitions(partitioner); + List>> partitions = repartitioned.glom().collect(); + Assert.assertEquals(partitions.get(0), Arrays.asList(new Tuple2(0, 5), + new Tuple2(0, 8), new Tuple2(2, 6))); + Assert.assertEquals(partitions.get(1), Arrays.asList(new Tuple2(1, 3), + new Tuple2(3, 8), new Tuple2(3, 8))); + } + @Test public void emptyRDD() { JavaRDD rdd = sc.emptyRDD(); diff --git a/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala b/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala new file mode 100644 index 0000000000000..2acc02a54fa3d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.BeforeAndAfterAll + +class HashShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { + + // This test suite should run all tests in ShuffleSuite with hash-based shuffle. + + override def beforeAll() { + System.setProperty("spark.shuffle.manager", "hash") + } + + override def afterAll() { + System.clearProperty("spark.shuffle.manager") + } +} diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index b13ddf96bc77c..15aa4d83800fa 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.MutablePair -class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { +abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext { val conf = new SparkConf(loadDefaults = false) diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala index 5c02c00586ef4..639e56c488db4 100644 --- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala @@ -24,8 +24,7 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with sort-based shuffle. override def beforeAll() { - System.setProperty("spark.shuffle.manager", - "org.apache.spark.shuffle.sort.SortShuffleManager") + System.setProperty("spark.shuffle.manager", "sort") } override def afterAll() { diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 926d4fecb5b91..c1b501a75c8b8 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -521,6 +521,13 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(sortedLowerK === Array(1, 2, 3, 4, 5)) } + test("takeOrdered with limit 0") { + val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + val rdd = sc.makeRDD(nums, 2) + val sortedLowerK = rdd.takeOrdered(0) + assert(sortedLowerK.size === 0) + } + test("takeOrdered with custom ordering") { val nums = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) implicit val ord = implicitly[Ordering[Int]].reverse @@ -675,6 +682,20 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(data.sortBy(parse, true, 2)(NameOrdering, classTag[Person]).collect() === nameOrdered) } + test("repartitionAndSortWithinPartitions") { + val data = sc.parallelize(Seq((0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)), 2) + + val partitioner = new Partitioner { + def numPartitions: Int = 2 + def getPartition(key: Any): Int = key.asInstanceOf[Int] % 2 + } + + val repartitioned = data.repartitionAndSortWithinPartitions(partitioner) + val partitions = repartitioned.glom().collect() + assert(partitions(0) === Seq((0, 5), (0, 8), (2, 6))) + assert(partitions(1) === Seq((1, 3), (3, 8), (3, 8))) + } + test("intersection") { val all = sc.parallelize(1 to 10) val evens = sc.parallelize(2 to 10 by 2) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 1a42fc1b233ba..aa73469b6acd8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} @@ -97,10 +98,12 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F /** Length of time to wait while draining listener events. */ val WAIT_TIMEOUT_MILLIS = 10000 val sparkListener = new SparkListener() { - val successfulStages = new HashSet[Int]() - val failedStages = new ArrayBuffer[Int]() + val successfulStages = new HashSet[Int] + val failedStages = new ArrayBuffer[Int] + val stageByOrderOfExecution = new ArrayBuffer[Int] override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { val stageInfo = stageCompleted.stageInfo + stageByOrderOfExecution += stageInfo.stageId if (stageInfo.failureReason.isEmpty) { successfulStages += stageInfo.stageId } else { @@ -120,7 +123,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F */ val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]] // stub out BlockManagerMaster.getLocations to use our cacheLocations - val blockManagerMaster = new BlockManagerMaster(null, conf) { + val blockManagerMaster = new BlockManagerMaster(null, conf, true) { override def getLocations(blockIds: Array[BlockId]): Seq[Seq[BlockManagerId]] = { blockIds.map { _.asRDDId.map(id => (id.rddId -> id.splitIndex)).flatMap(key => cacheLocations.get(key)). @@ -231,6 +234,13 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F runEvent(JobCancelled(jobId)) } + test("[SPARK-3353] parent stage should have lower stage id") { + sparkListener.stageByOrderOfExecution.clear() + sc.parallelize(1 to 10).map(x => (x, x)).reduceByKey(_ + _, 4).count() + assert(sparkListener.stageByOrderOfExecution.length === 2) + assert(sparkListener.stageByOrderOfExecution(0) < sparkListener.stageByOrderOfExecution(1)) + } + test("zero split job") { var numResults = 0 val fakeListener = new JobListener() { @@ -457,7 +467,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F null, null)) assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) - assert(sparkListener.failedStages.contains(0)) + assert(sparkListener.failedStages.contains(1)) // The second ResultTask fails, with a fetch failure for the output from the second mapper. runEvent(CompletionEvent( @@ -515,8 +525,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F // Listener bus should get told about the map stage failing, but not the reduce stage // (since the reduce stage hasn't been started yet). assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) - assert(sparkListener.failedStages.contains(1)) - assert(sparkListener.failedStages.size === 1) + assert(sparkListener.failedStages.toSet === Set(0)) assertDataStructuresEmpty } @@ -563,14 +572,12 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F val stageFailureMessage = "Exception failure in map stage" failed(taskSets(0), stageFailureMessage) - assert(cancelledStages.contains(1)) + assert(cancelledStages.toSet === Set(0, 2)) // Make sure the listeners got told about both failed stages. assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) assert(sparkListener.successfulStages.isEmpty) - assert(sparkListener.failedStages.contains(1)) - assert(sparkListener.failedStages.contains(3)) - assert(sparkListener.failedStages.size === 2) + assert(sparkListener.failedStages.toSet === Set(0, 2)) assert(listener1.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage") assert(listener2.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage") diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 41e58a008c533..e5315bc93e217 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.{BeforeAndAfter, FunSuite} import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec +import org.apache.spark.SPARK_VERSION import org.apache.spark.util.{JsonProtocol, Utils} import java.io.File @@ -196,7 +197,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter { def assertInfoCorrect(info: EventLoggingInfo, loggerStopped: Boolean) { assert(info.logPaths.size > 0) - assert(info.sparkVersion === SparkContext.SPARK_VERSION) + assert(info.sparkVersion === SPARK_VERSION) assert(info.compressionCodec.isDefined === compressionCodec.isDefined) info.compressionCodec.foreach { codec => assert(compressionCodec.isDefined) @@ -229,7 +230,8 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter { val conf = getLoggingConf(logDirPath, compressionCodec) val eventLogger = new EventLoggingListener("test", conf) val listenerBus = new LiveListenerBus - val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", 125L, "Mickey") + val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, + 125L, "Mickey") val applicationEnd = SparkListenerApplicationEnd(1000L) // A comprehensive test on JSON de/serialization of all events is in JsonProtocolSuite @@ -380,7 +382,7 @@ class EventLoggingListenerSuite extends FunSuite with BeforeAndAfter { private def assertSparkVersionIsValid(logFiles: Array[FileStatus]) { val file = logFiles.map(_.getPath.getName).find(EventLoggingListener.isSparkVersionFile) assert(file.isDefined) - assert(EventLoggingListener.parseSparkVersion(file.get) === SparkContext.SPARK_VERSION) + assert(EventLoggingListener.parseSparkVersion(file.get) === SPARK_VERSION) } private def assertCompressionCodecIsValid(logFiles: Array[FileStatus], compressionCodec: String) { diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 8f0ee9f4dbafd..7ab351d1b4d24 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -83,7 +83,8 @@ class ReplayListenerSuite extends FunSuite with BeforeAndAfter { val fstream = fileSystem.create(logFilePath) val cstream = codec.map(_.compressedOutputStream(fstream)).getOrElse(fstream) val writer = new PrintWriter(cstream) - val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", 125L, "Mickey") + val applicationStart = SparkListenerApplicationStart("Greatest App (N)ever", None, + 125L, "Mickey") val applicationEnd = SparkListenerApplicationEnd(1000L) writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationStart)))) writer.println(compact(render(JsonProtocol.sparkEventToJson(applicationEnd)))) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 3b0b8e2f68c97..ab35e8edc4ebf 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -180,7 +180,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers rdd3.count() assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) listener.stageInfos.size should be {2} // Shuffle map stage + result stage - val stageInfo3 = listener.stageInfos.keys.find(_.stageId == 2).get + val stageInfo3 = listener.stageInfos.keys.find(_.stageId == 3).get stageInfo3.rddInfos.size should be {1} // ShuffledRDD stageInfo3.rddInfos.forall(_.numPartitions == 4) should be {true} stageInfo3.rddInfos.exists(_.name == "Trois") should be {true} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 5a015e2521916..e251660dae5de 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -92,7 +92,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter master = new BlockManagerMaster( actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), - conf) + conf, true) val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() diff --git a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala index 4e022a69c8212..3a45875391e29 100644 --- a/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/StorageStatusListenerSuite.scala @@ -36,13 +36,13 @@ class StorageStatusListenerSuite extends FunSuite { // Block manager add assert(listener.executorIdToStorageStatus.size === 0) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(bm1, 1000L)) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) assert(listener.executorIdToStorageStatus.size === 1) assert(listener.executorIdToStorageStatus.get("big").isDefined) assert(listener.executorIdToStorageStatus("big").blockManagerId === bm1) assert(listener.executorIdToStorageStatus("big").maxMem === 1000L) assert(listener.executorIdToStorageStatus("big").numBlocks === 0) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(bm2, 2000L)) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm2, 2000L)) assert(listener.executorIdToStorageStatus.size === 2) assert(listener.executorIdToStorageStatus.get("fat").isDefined) assert(listener.executorIdToStorageStatus("fat").blockManagerId === bm2) @@ -50,11 +50,11 @@ class StorageStatusListenerSuite extends FunSuite { assert(listener.executorIdToStorageStatus("fat").numBlocks === 0) // Block manager remove - listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(bm1)) + listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(1L, bm1)) assert(listener.executorIdToStorageStatus.size === 1) assert(!listener.executorIdToStorageStatus.get("big").isDefined) assert(listener.executorIdToStorageStatus.get("fat").isDefined) - listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(bm2)) + listener.onBlockManagerRemoved(SparkListenerBlockManagerRemoved(1L, bm2)) assert(listener.executorIdToStorageStatus.size === 0) assert(!listener.executorIdToStorageStatus.get("big").isDefined) assert(!listener.executorIdToStorageStatus.get("fat").isDefined) @@ -62,8 +62,8 @@ class StorageStatusListenerSuite extends FunSuite { test("task end without updated blocks") { val listener = new StorageStatusListener - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(bm1, 1000L)) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(bm2, 2000L)) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm2, 2000L)) val taskMetrics = new TaskMetrics // Task end with no updated blocks @@ -79,8 +79,8 @@ class StorageStatusListenerSuite extends FunSuite { test("task end with updated blocks") { val listener = new StorageStatusListener - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(bm1, 1000L)) - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(bm2, 2000L)) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm2, 2000L)) val taskMetrics1 = new TaskMetrics val taskMetrics2 = new TaskMetrics val block1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.DISK_ONLY, 0L, 100L, 0L)) @@ -128,7 +128,7 @@ class StorageStatusListenerSuite extends FunSuite { test("unpersist RDD") { val listener = new StorageStatusListener - listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(bm1, 1000L)) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) val taskMetrics1 = new TaskMetrics val taskMetrics2 = new TaskMetrics val block1 = (RDDBlockId(1, 1), BlockStatus(StorageLevel.DISK_ONLY, 0L, 100L, 0L)) diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index d9e9c70a8a9e7..e1bc1379b5d80 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -108,7 +108,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { val myRddInfo1 = rddInfo1 val myRddInfo2 = rddInfo2 val stageInfo0 = new StageInfo(0, 0, "0", 100, Seq(myRddInfo0, myRddInfo1, myRddInfo2), "details") - bus.postToAll(SparkListenerBlockManagerAdded(bm1, 1000L)) + bus.postToAll(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) assert(storageListener._rddInfoMap.size === 3) assert(storageListener.rddInfoList.size === 0) // not cached @@ -175,7 +175,7 @@ class StorageTabSuite extends FunSuite with BeforeAndAfter { val block1 = (RDDBlockId(1, 1), BlockStatus(memOnly, 200L, 0L, 0L)) taskMetrics0.updatedBlocks = Some(Seq(block0)) taskMetrics1.updatedBlocks = Some(Seq(block1)) - bus.postToAll(SparkListenerBlockManagerAdded(bm1, 1000L)) + bus.postToAll(SparkListenerBlockManagerAdded(1L, bm1, 1000L)) bus.postToAll(SparkListenerStageSubmitted(stageInfo0)) assert(storageListener.rddInfoList.size === 0) bus.postToAll(SparkListenerTaskEnd(0, 0, "big", Success, taskInfo, taskMetrics0)) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 66a17de9ec9ce..2b45d8b695853 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -21,6 +21,9 @@ import java.util.Properties import scala.collection.Map +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods._ import org.scalatest.FunSuite @@ -52,12 +55,12 @@ class JsonProtocolSuite extends FunSuite { "System Properties" -> Seq(("Username", "guest"), ("Password", "guest")), "Classpath Entries" -> Seq(("Super library", "/tmp/super_library")) )) - val blockManagerAdded = SparkListenerBlockManagerAdded( + val blockManagerAdded = SparkListenerBlockManagerAdded(1L, BlockManagerId("Stars", "In your multitude...", 300), 500) - val blockManagerRemoved = SparkListenerBlockManagerRemoved( + val blockManagerRemoved = SparkListenerBlockManagerRemoved(2L, BlockManagerId("Scarce", "to be counted...", 100)) val unpersistRdd = SparkListenerUnpersistRDD(12345) - val applicationStart = SparkListenerApplicationStart("The winner of all", 42L, "Garfield") + val applicationStart = SparkListenerApplicationStart("The winner of all", None, 42L, "Garfield") val applicationEnd = SparkListenerApplicationEnd(42L) testEvent(stageSubmitted, stageSubmittedJsonString) @@ -151,6 +154,35 @@ class JsonProtocolSuite extends FunSuite { assert(newMetrics.inputMetrics.isEmpty) } + test("BlockManager events backward compatibility") { + // SparkListenerBlockManagerAdded/Removed in Spark 1.0.0 do not have a "time" property. + val blockManagerAdded = SparkListenerBlockManagerAdded(1L, + BlockManagerId("Stars", "In your multitude...", 300), 500) + val blockManagerRemoved = SparkListenerBlockManagerRemoved(2L, + BlockManagerId("Scarce", "to be counted...", 100)) + + val oldBmAdded = JsonProtocol.blockManagerAddedToJson(blockManagerAdded) + .removeField({ _._1 == "Timestamp" }) + + val deserializedBmAdded = JsonProtocol.blockManagerAddedFromJson(oldBmAdded) + assert(SparkListenerBlockManagerAdded(-1L, blockManagerAdded.blockManagerId, + blockManagerAdded.maxMem) === deserializedBmAdded) + + val oldBmRemoved = JsonProtocol.blockManagerRemovedToJson(blockManagerRemoved) + .removeField({ _._1 == "Timestamp" }) + + val deserializedBmRemoved = JsonProtocol.blockManagerRemovedFromJson(oldBmRemoved) + assert(SparkListenerBlockManagerRemoved(-1L, blockManagerRemoved.blockManagerId) === + deserializedBmRemoved) + } + + test("SparkListenerApplicationStart backwards compatibility") { + // SparkListenerApplicationStart in Spark 1.0.0 do not have an "appId" property. + val applicationStart = SparkListenerApplicationStart("test", None, 1L, "user") + val oldEvent = JsonProtocol.applicationStartToJson(applicationStart) + .removeField({ _._1 == "App ID" }) + assert(applicationStart === JsonProtocol.applicationStartFromJson(oldEvent)) + } /** -------------------------- * | Helper test running methods | @@ -242,8 +274,10 @@ class JsonProtocolSuite extends FunSuite { assertEquals(e1.environmentDetails, e2.environmentDetails) case (e1: SparkListenerBlockManagerAdded, e2: SparkListenerBlockManagerAdded) => assert(e1.maxMem === e2.maxMem) + assert(e1.time === e2.time) assertEquals(e1.blockManagerId, e2.blockManagerId) case (e1: SparkListenerBlockManagerRemoved, e2: SparkListenerBlockManagerRemoved) => + assert(e1.time === e2.time) assertEquals(e1.blockManagerId, e2.blockManagerId) case (e1: SparkListenerUnpersistRDD, e2: SparkListenerUnpersistRDD) => assert(e1.rddId == e2.rddId) @@ -945,7 +979,8 @@ class JsonProtocolSuite extends FunSuite { | "Host": "In your multitude...", | "Port": 300 | }, - | "Maximum Memory": 500 + | "Maximum Memory": 500, + | "Timestamp": 1 |} """ @@ -957,7 +992,8 @@ class JsonProtocolSuite extends FunSuite { | "Executor ID": "Scarce", | "Host": "to be counted...", | "Port": 100 - | } + | }, + | "Timestamp": 2 |} """ diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index ac3931e3d0a73..511d76c9144cc 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -42,6 +42,7 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { conf.set("spark.serializer.objectStreamReset", "1") conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer") conf.set("spark.shuffle.spill.compress", codec.isDefined.toString) + conf.set("spark.shuffle.compress", codec.isDefined.toString) codec.foreach { c => conf.set("spark.io.compression.codec", c) } // Ensure that we actually have multiple batches per spill file conf.set("spark.shuffle.spill.batchSize", "10") diff --git a/dev/check-license b/dev/check-license index 625ec161bc571..9ff0929e9a5e8 100755 --- a/dev/check-license +++ b/dev/check-license @@ -23,18 +23,18 @@ acquire_rat_jar () { URL1="http://search.maven.org/remotecontent?filepath=org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar" URL2="http://repo1.maven.org/maven2/org/apache/rat/apache-rat/${RAT_VERSION}/apache-rat-${RAT_VERSION}.jar" - JAR=$rat_jar + JAR="$rat_jar" if [[ ! -f "$rat_jar" ]]; then # Download rat launch jar if it hasn't been downloaded yet if [ ! -f "$JAR" ]; then # Download printf "Attempting to fetch rat\n" - JAR_DL=${JAR}.part + JAR_DL="${JAR}.part" if hash curl 2>/dev/null; then - (curl --progress-bar ${URL1} > "$JAR_DL" || curl --progress-bar ${URL2} > "$JAR_DL") && mv "$JAR_DL" "$JAR" + (curl --silent "${URL1}" > "$JAR_DL" || curl --silent "${URL2}" > "$JAR_DL") && mv "$JAR_DL" "$JAR" elif hash wget 2>/dev/null; then - (wget --progress=bar ${URL1} -O "$JAR_DL" || wget --progress=bar ${URL2} -O "$JAR_DL") && mv "$JAR_DL" "$JAR" + (wget --quiet ${URL1} -O "$JAR_DL" || wget --quiet ${URL2} -O "$JAR_DL") && mv "$JAR_DL" "$JAR" else printf "You do not have curl or wget installed, please install rat manually.\n" exit -1 @@ -50,7 +50,7 @@ acquire_rat_jar () { } # Go to the Spark project root directory -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" if test -x "$JAVA_HOME/bin/java"; then @@ -60,17 +60,17 @@ else fi export RAT_VERSION=0.10 -export rat_jar=$FWDIR/lib/apache-rat-${RAT_VERSION}.jar -mkdir -p $FWDIR/lib +export rat_jar="$FWDIR"/lib/apache-rat-${RAT_VERSION}.jar +mkdir -p "$FWDIR"/lib [[ -f "$rat_jar" ]] || acquire_rat_jar || { echo "Download failed. Obtain the rat jar manually and place it at $rat_jar" exit 1 } -$java_cmd -jar $rat_jar -E $FWDIR/.rat-excludes -d $FWDIR > rat-results.txt +$java_cmd -jar "$rat_jar" -E "$FWDIR"/.rat-excludes -d "$FWDIR" > rat-results.txt -ERRORS=$(cat rat-results.txt | grep -e "??") +ERRORS="$(cat rat-results.txt | grep -e "??")" if test ! -z "$ERRORS"; then echo "Could not find Apache license headers in the following files:" diff --git a/dev/lint-python b/dev/lint-python index a1e890faa8fa6..772f856154ae0 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -18,10 +18,10 @@ # SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" -SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)" +SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" -cd $SPARK_ROOT_DIR +cd "$SPARK_ROOT_DIR" # Get pep8 at runtime so that we don't rely on it being installed on the build server. #+ See: https://github.com/apache/spark/pull/1744#issuecomment-50982162 @@ -30,6 +30,7 @@ cd $SPARK_ROOT_DIR #+ - Download this from a more reliable source. (GitHub raw can be flaky, apparently. (?)) PEP8_SCRIPT_PATH="$SPARK_ROOT_DIR/dev/pep8.py" PEP8_SCRIPT_REMOTE_PATH="https://raw.githubusercontent.com/jcrocholl/pep8/1.5.7/pep8.py" +PEP8_PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/" curl --silent -o "$PEP8_SCRIPT_PATH" "$PEP8_SCRIPT_REMOTE_PATH" curl_status=$? @@ -44,7 +45,7 @@ fi #+ first, but we do so so that the check status can #+ be output before the report, like with the #+ scalastyle and RAT checks. -python $PEP8_SCRIPT_PATH ./python/pyspark > "$PEP8_REPORT_PATH" +python "$PEP8_SCRIPT_PATH" $PEP8_PATHS_TO_CHECK > "$PEP8_REPORT_PATH" pep8_status=${PIPESTATUS[0]} #$? if [ $pep8_status -ne 0 ]; then @@ -54,7 +55,7 @@ else echo "PEP 8 checks passed." fi -rm -f "$PEP8_REPORT_PATH" +rm "$PEP8_REPORT_PATH" rm "$PEP8_SCRIPT_PATH" exit $pep8_status diff --git a/dev/mima b/dev/mima index 09e4482af5f3d..f9b9b03538f15 100755 --- a/dev/mima +++ b/dev/mima @@ -21,12 +21,12 @@ set -o pipefail set -e # Go to the Spark project root directory -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" cd "$FWDIR" echo -e "q\n" | sbt/sbt oldDeps/update -export SPARK_CLASSPATH=`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"` +export SPARK_CLASSPATH="`find lib_managed \( -name '*spark*jar' -a -type f \) | tr "\\n" ":"`" echo "SPARK_CLASSPATH=$SPARK_CLASSPATH" ./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore diff --git a/dev/run-tests b/dev/run-tests index d751961605dfd..49a88085c80f7 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -18,7 +18,7 @@ # # Go to the Spark project root directory -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname $0`"/..; pwd)" cd "$FWDIR" if [ -n "$AMPLAB_JENKINS_BUILD_PROFILE" ]; then @@ -89,7 +89,7 @@ echo "=========================================================================" echo "Running Spark unit tests" echo "=========================================================================" -# Build Spark; we always build with Hive because the PySpark SparkSQL tests need it. +# Build Spark; we always build with Hive because the PySpark Spark SQL tests need it. # echo "q" is needed because sbt on encountering a build file with failure # (either resolution or compilation) prompts the user for input either q, r, # etc to quit or retry. This echo is there to make it not block. diff --git a/dev/scalastyle b/dev/scalastyle index eb9b467965636..efb5f291ea3b7 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -19,7 +19,7 @@ echo -e "q\n" | sbt/sbt -Phive scalastyle > scalastyle.txt # Check style with YARN alpha built too -echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \ +echo -e "q\n" | sbt/sbt -Pyarn-alpha -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \ >> scalastyle.txt # Check style with YARN built too echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 yarn/scalastyle \ diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 2dbbbf6feb4b8..3b02e090aec28 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -25,8 +25,8 @@ curr_dir = pwd cd("..") - puts "Running 'sbt/sbt compile unidoc' from " + pwd + "; this may take a few minutes..." - puts `sbt/sbt compile unidoc` + puts "Running 'sbt/sbt -Pkinesis-asl compile unidoc' from " + pwd + "; this may take a few minutes..." + puts `sbt/sbt -Pkinesis-asl compile unidoc` puts "Moving back into docs dir." cd("docs") diff --git a/docs/configuration.md b/docs/configuration.md index 65a422caabb7e..36178efb97103 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -293,12 +293,11 @@ Apart from these, the following properties are also available, and may be useful spark.shuffle.manager - HASH + sort - Implementation to use for shuffling data. A hash-based shuffle manager is the default, but - starting in Spark 1.1 there is an experimental sort-based shuffle manager that is more - memory-efficient in environments with small executors, such as YARN. To use that, change - this value to SORT. + Implementation to use for shuffling data. There are two implementations available: + sort and hash. Sort-based shuffle is more memory-efficient and is + the default option starting in 1.2. diff --git a/docs/img/streaming-arch.png b/docs/img/streaming-arch.png index bc57b460fdf8b..ac35f1d34cf3d 100644 Binary files a/docs/img/streaming-arch.png and b/docs/img/streaming-arch.png differ diff --git a/docs/img/streaming-figures.pptx b/docs/img/streaming-figures.pptx index 1b18c2ee0ea3e..d1cc25e379f46 100644 Binary files a/docs/img/streaming-figures.pptx and b/docs/img/streaming-figures.pptx differ diff --git a/docs/img/streaming-kinesis-arch.png b/docs/img/streaming-kinesis-arch.png new file mode 100644 index 0000000000000..bea5fa88df985 Binary files /dev/null and b/docs/img/streaming-kinesis-arch.png differ diff --git a/docs/index.md b/docs/index.md index 4ac0982ae54f1..7fe6b43d32af7 100644 --- a/docs/index.md +++ b/docs/index.md @@ -103,6 +103,8 @@ options for deployment: * [Security](security.html): Spark security support * [Hardware Provisioning](hardware-provisioning.html): recommendations for cluster hardware * [3rd Party Hadoop Distributions](hadoop-third-party-distributions.html): using common Hadoop distributions +* Integration with other storage systems: + * [OpenStack Swift](storage-openstack-swift.html) * [Building Spark with Maven](building-with-maven.html): build Spark using the Maven system * [Contributing to Spark](https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark) diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 6ae780d94046a..624cc744dfd51 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -385,7 +385,7 @@ Apart from text files, Spark's Python API also supports several other data forma * SequenceFile and Hadoop Input/Output Formats -**Note** this feature is currently marked ```Experimental``` and is intended for advanced users. It may be replaced in future with read/write support based on SparkSQL, in which case SparkSQL is the preferred approach. +**Note** this feature is currently marked ```Experimental``` and is intended for advanced users. It may be replaced in future with read/write support based on Spark SQL, in which case Spark SQL is the preferred approach. **Writable Support** diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 8f7fb5431cfb6..1814fef465cac 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -68,6 +68,16 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc) import sqlContext.createSchemaRDD {% endhighlight %} +In addition to the basic SQLContext, you can also create a HiveContext, which provides a strict +super set of the functionality provided by the basic SQLContext. Additional features include +the ability to write queries using the more complete HiveQL parser, access to HiveUDFs, and the +ability to read data from Hive tables. To use a HiveContext, you do not need to have an +existing hive setup, and all of the data sources available to a SQLContext are still available. +HiveContext is only packaged separately to avoid including all of Hive's dependencies in the default +Spark build. If these dependencies are not a problem for your application then using HiveContext +is recommended for the 1.2 release of Spark. Future releases will focus on bringing SQLContext up to +feature parity with a HiveContext. +
@@ -81,6 +91,16 @@ JavaSparkContext sc = ...; // An existing JavaSparkContext. JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); {% endhighlight %} +In addition to the basic SQLContext, you can also create a HiveContext, which provides a strict +super set of the functionality provided by the basic SQLContext. Additional features include +the ability to write queries using the more complete HiveQL parser, access to HiveUDFs, and the +ability to read data from Hive tables. To use a HiveContext, you do not need to have an +existing hive setup, and all of the data sources available to a SQLContext are still available. +HiveContext is only packaged separately to avoid including all of Hive's dependencies in the default +Spark build. If these dependencies are not a problem for your application then using HiveContext +is recommended for the 1.2 release of Spark. Future releases will focus on bringing SQLContext up to +feature parity with a HiveContext. +
@@ -94,36 +114,52 @@ from pyspark.sql import SQLContext sqlContext = SQLContext(sc) {% endhighlight %} -
+In addition to the basic SQLContext, you can also create a HiveContext, which provides a strict +super set of the functionality provided by the basic SQLContext. Additional features include +the ability to write queries using the more complete HiveQL parser, access to HiveUDFs, and the +ability to read data from Hive tables. To use a HiveContext, you do not need to have an +existing hive setup, and all of the data sources available to a SQLContext are still available. +HiveContext is only packaged separately to avoid including all of Hive's dependencies in the default +Spark build. If these dependencies are not a problem for your application then using HiveContext +is recommended for the 1.2 release of Spark. Future releases will focus on bringing SQLContext up to +feature parity with a HiveContext. -# Data Sources - -
-
-Spark SQL supports operating on a variety of data sources through the `SchemaRDD` interface. -Once a dataset has been loaded, it can be registered as a table and even joined with data from other sources.
-
-Spark SQL supports operating on a variety of data sources through the `JavaSchemaRDD` interface. -Once a dataset has been loaded, it can be registered as a table and even joined with data from other sources. -
+The specific variant of SQL that is used to parse queries can also be selected using the +`spark.sql.dialect` option. This parameter can be changed using either the `setConf` method on +a SQLContext or by using a `SET key=value` command in SQL. For a SQLContext, the only dialect +available is "sql" which uses a simple SQL parser provided by Spark SQL. In a HiveContext, the +default is "hiveql", though "sql" is also available. Since the HiveQL parser is much more complete, + this is recommended for most use cases. + +# Data Sources -
Spark SQL supports operating on a variety of data sources through the `SchemaRDD` interface. -Once a dataset has been loaded, it can be registered as a table and even joined with data from other sources. -
-
+A SchemaRDD can be operated on as normal RDDs and can also be registered as a temporary table. +Registering a SchemaRDD as a table allows you to run SQL queries over its data. This section +describes the various methods for loading data into a SchemaRDD. ## RDDs +Spark SQL supports two different methods for converting existing RDDs into SchemaRDDs. The first +method uses reflection to infer the schema of an RDD that contains specific types of objects. This +reflection based approach leads to more concise code and works well went the schema is known ahead +of time, while you are writing your Spark application. + +The second method for creating SchemaRDDs is through a programmatic interface that allows you to +construct a schema and then apply it to and existing RDD. While this method is more verbose, it allows +you to construct SchemaRDDs when the columns and their types are not known until runtime. + +### Inferring the Schema Using Reflection
-One type of table that is supported by Spark SQL is an RDD of Scala case classes. The case class +The Scala interaface for Spark SQL supports automatically converting an RDD containing case classes +to a SchemaRDD. The case class defines the schema of the table. The names of the arguments to the case class are read using reflection and become the names of the columns. Case classes can also be nested or contain complex types such as Sequences or Arrays. This RDD can be implicitly converted to a SchemaRDD and then be @@ -156,8 +192,9 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
-One type of table that is supported by Spark SQL is an RDD of [JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly). The BeanInfo -defines the schema of the table. Currently, Spark SQL does not support JavaBeans that contain +Spark SQL supports automatically converting an RDD of [JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly) +into a Schema RDD. The BeanInfo, obtained using reflection, defines the schema of the table. +Currently, Spark SQL does not support JavaBeans that contain nested or contain complex types such as Lists or Arrays. You can create a JavaBean by creating a class that implements Serializable and has getters and setters for all of its fields. @@ -192,7 +229,7 @@ for the JavaBean. {% highlight java %} // sc is an existing JavaSparkContext. -JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc) +JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); // Load a text file and convert each line to a JavaBean. JavaRDD people = sc.textFile("examples/src/main/resources/people.txt").map( @@ -229,24 +266,24 @@ List teenagerNames = teenagers.map(new Function() {
-One type of table that is supported by Spark SQL is an RDD of dictionaries. The keys of the -dictionary define the columns names of the table, and the types are inferred by looking at the first -row. Any RDD of dictionaries can converted to a SchemaRDD and then registered as a table. Tables -can be used in subsequent SQL statements. +Spark SQL can convert an RDD of Row objects to a SchemaRDD, inferring the datatypes . Rows are constructed by passing a list of +key/value pairs as kwargs to the Row class. The keys of this list define the columns names of the table, +and the types are inferred by looking at the first row. Since we currently only look at the first +row, it is important that there is no missing data in the first row of the RDD. In future version we +plan to more completely infer the schema by looking at more data, similar to the inference that is +performed on JSON files. {% highlight python %} # sc is an existing SparkContext. -from pyspark.sql import SQLContext +from pyspark.sql import SQLContext, Row sqlContext = SQLContext(sc) # Load a text file and convert each line to a dictionary. lines = sc.textFile("examples/src/main/resources/people.txt") parts = lines.map(lambda l: l.split(",")) -people = parts.map(lambda p: {"name": p[0], "age": int(p[1])}) +people = parts.map(lambda p: Row(name=p[0], age=int(p[1]))) # Infer the schema, and register the SchemaRDD as a table. -# In future versions of PySpark we would like to add support for registering RDDs with other -# datatypes as tables schemaPeople = sqlContext.inferSchema(people) schemaPeople.registerTempTable("people") @@ -263,15 +300,191 @@ for teenName in teenNames.collect():
-**Note that Spark SQL currently uses a very basic SQL parser.** -Users that want a more complete dialect of SQL should look at the HiveQL support provided by -`HiveContext`. +### Programmatically Specifying the Schema + +
+ +
+ +In cases that case classes cannot be defined ahead of time (for example, +the structure of records is encoded in a string or a text dataset will be parsed +and fields will be projected differently for different users), +a `SchemaRDD` can be created programmatically with three steps. + +1. Create an RDD of `Row`s from the original RDD; +2. Create the schema represented by a `StructType` matching the structure of +`Row`s in the RDD created in the step 1. +3. Apply the schema to the RDD of `Row`s via `applySchema` method provided +by `SQLContext`. + +For example: +{% highlight scala %} +// sc is an existing SparkContext. +val sqlContext = new org.apache.spark.sql.SQLContext(sc) + +// Create an RDD +val people = sc.textFile("examples/src/main/resources/people.txt") + +// The schema is encoded in a string +val schemaString = "name age" + +// Import Spark SQL data types and Row. +import org.apache.spark.sql._ + +// Generate the schema based on the string of schema +val schema = + StructType( + schemaString.split(" ").map(fieldName => StructField(fieldName, StringType, true))) + +// Convert records of the RDD (people) to Rows. +val rowRDD = people.map(_.split(",")).map(p => Row(p(0), p(1).trim)) + +// Apply the schema to the RDD. +val peopleSchemaRDD = sqlContext.applySchema(rowRDD, schema) + +// Register the SchemaRDD as a table. +peopleSchemaRDD.registerTempTable("people") + +// SQL statements can be run by using the sql methods provided by sqlContext. +val results = sqlContext.sql("SELECT name FROM people") + +// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. +// The columns of a row in the result can be accessed by ordinal. +results.map(t => "Name: " + t(0)).collect().foreach(println) +{% endhighlight %} + + +
+ +
+ +In cases that JavaBean classes cannot be defined ahead of time (for example, +the structure of records is encoded in a string or a text dataset will be parsed and +fields will be projected differently for different users), +a `SchemaRDD` can be created programmatically with three steps. + +1. Create an RDD of `Row`s from the original RDD; +2. Create the schema represented by a `StructType` matching the structure of +`Row`s in the RDD created in the step 1. +3. Apply the schema to the RDD of `Row`s via `applySchema` method provided +by `JavaSQLContext`. + +For example: +{% highlight java %} +// Import factory methods provided by DataType. +import org.apache.spark.sql.api.java.DataType +// Import StructType and StructField +import org.apache.spark.sql.api.java.StructType +import org.apache.spark.sql.api.java.StructField +// Import Row. +import org.apache.spark.sql.api.java.Row + +// sc is an existing JavaSparkContext. +JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc); + +// Load a text file and convert each line to a JavaBean. +JavaRDD people = sc.textFile("examples/src/main/resources/people.txt"); + +// The schema is encoded in a string +String schemaString = "name age"; + +// Generate the schema based on the string of schema +List fields = new ArrayList(); +for (String fieldName: schemaString.split(" ")) { + fields.add(DataType.createStructField(fieldName, DataType.StringType, true)); +} +StructType schema = DataType.createStructType(fields); + +// Convert records of the RDD (people) to Rows. +JavaRDD rowRDD = people.map( + new Function() { + public Row call(String record) throws Exception { + String[] fields = record.split(","); + return Row.create(fields[0], fields[1].trim()); + } + }); + +// Apply the schema to the RDD. +JavaSchemaRDD peopleSchemaRDD = sqlContext.applySchema(rowRDD, schema); + +// Register the SchemaRDD as a table. +peopleSchemaRDD.registerTempTable("people"); + +// SQL can be run over RDDs that have been registered as tables. +JavaSchemaRDD results = sqlContext.sql("SELECT name FROM people"); + +// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. +// The columns of a row in the result can be accessed by ordinal. +List names = results.map(new Function() { + public String call(Row row) { + return "Name: " + row.getString(0); + } +}).collect(); + +{% endhighlight %} + +
+ +
+ +For some cases (for example, the structure of records is encoded in a string or +a text dataset will be parsed and fields will be projected differently for +different users), it is desired to create `SchemaRDD` with a programmatically way. +It can be done with three steps. + +1. Create an RDD of tuples or lists from the original RDD; +2. Create the schema represented by a `StructType` matching the structure of +tuples or lists in the RDD created in the step 1. +3. Apply the schema to the RDD via `applySchema` method provided by `SQLContext`. + +For example: +{% highlight python %} +# Import SQLContext and data types +from pyspark.sql import * + +# sc is an existing SparkContext. +sqlContext = SQLContext(sc) + +# Load a text file and convert each line to a tuple. +lines = sc.textFile("examples/src/main/resources/people.txt") +parts = lines.map(lambda l: l.split(",")) +people = parts.map(lambda p: (p[0], p[1].strip())) + +# The schema is encoded in a string. +schemaString = "name age" + +fields = [StructField(field_name, StringType(), True) for field_name in schemaString.split()] +schema = StructType(fields) + +# Apply the schema to the RDD. +schemaPeople = sqlContext.applySchema(people, schema) + +# Register the SchemaRDD as a table. +schemaPeople.registerTempTable("people") + +# SQL can be run over SchemaRDDs that have been registered as a table. +results = sqlContext.sql("SELECT name FROM people") + +# The results of SQL queries are RDDs and support all the normal RDD operations. +names = results.map(lambda p: "Name: " + p.name) +for name in names.collect(): + print name +{% endhighlight %} + + +
+ +
## Parquet Files [Parquet](http://parquet.io) is a columnar format that is supported by many other data processing systems. Spark SQL provides support for both reading and writing Parquet files that automatically preserves the schema -of the original data. Using the data from the above example: +of the original data. + +### Loading Data Programmatically + +Using the data from the above example:
@@ -349,7 +562,40 @@ for teenName in teenNames.collect():
-
+
+ +### Configuration + +Configuration of parquet can be done using the `setConf` method on SQLContext or by running +`SET key=value` commands using SQL. + + + + + + + + + + + + + + + + + + +
Property NameDefaultMeaning
spark.sql.parquet.binaryAsStringfalse + Some other parquet producing systems, in particular Impala and older versions of Spark SQL, do + not differentiate between binary data and strings when writing out the parquet schema. This + flag tells Spark SQL to interpret binary data as a string to provide compatibility with these systems. +
spark.sql.parquet.cacheMetadatafalse + Turns on caching of parquet schema metadata. Can speed up querying +
spark.sql.parquet.compression.codecsnappy + Sets the compression codec use when writing parquet files. Acceptable values include: + uncompressed, snappy, gzip, lzo. +
## JSON Datasets
@@ -493,13 +739,13 @@ directory. {% highlight scala %} // sc is an existing SparkContext. -val hiveContext = new org.apache.spark.sql.hive.HiveContext(sc) +val sqlContext = new org.apache.spark.sql.hive.HiveContext(sc) -hiveContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -hiveContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") +sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") // Queries are expressed in HiveQL -hiveContext.sql("FROM src SELECT key, value").collect().foreach(println) +sqlContext.sql("FROM src SELECT key, value").collect().foreach(println) {% endhighlight %}
@@ -513,13 +759,13 @@ expressed in HiveQL. {% highlight java %} // sc is an existing JavaSparkContext. -JavaHiveContext hiveContext = new org.apache.spark.sql.hive.api.java.HiveContext(sc); +JavaHiveContext sqlContext = new org.apache.spark.sql.hive.api.java.HiveContext(sc); -hiveContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); -hiveContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); +sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); +sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); // Queries are expressed in HiveQL. -Row[] results = hiveContext.sql("FROM src SELECT key, value").collect(); +Row[] results = sqlContext.sql("FROM src SELECT key, value").collect(); {% endhighlight %} @@ -535,44 +781,97 @@ expressed in HiveQL. {% highlight python %} # sc is an existing SparkContext. from pyspark.sql import HiveContext -hiveContext = HiveContext(sc) +sqlContext = HiveContext(sc) -hiveContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -hiveContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") +sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") # Queries can be expressed in HiveQL. -results = hiveContext.sql("FROM src SELECT key, value").collect() +results = sqlContext.sql("FROM src SELECT key, value").collect() {% endhighlight %}
-# Writing Language-Integrated Relational Queries +# Performance Tuning -**Language-Integrated queries are currently only supported in Scala.** - -Spark SQL also supports a domain specific language for writing queries. Once again, -using the data from the above examples: +For some workloads it is possible to improve performance by either caching data in memory, or by +turning on some experimental options. -{% highlight scala %} -// sc is an existing SparkContext. -val sqlContext = new org.apache.spark.sql.SQLContext(sc) -// Importing the SQL context gives access to all the public SQL functions and implicit conversions. -import sqlContext._ -val people: RDD[Person] = ... // An RDD of case class objects, from the first example. +## Caching Data In Memory -// The following is the same as 'SELECT name FROM people WHERE age >= 10 AND age <= 19' -val teenagers = people.where('age >= 10).where('age <= 19).select('name) -teenagers.map(t => "Name: " + t(0)).collect().foreach(println) -{% endhighlight %} +Spark SQL can cache tables using an in-memory columnar format by calling `cacheTable("tableName")`. +Then Spark SQL will scan only required columns and will automatically tune compression to minimize +memory usage and GC pressure. You can call `uncacheTable("tableName")` to remove the table from memory. -The DSL uses Scala symbols to represent columns in the underlying table, which are identifiers -prefixed with a tick (`'`). Implicit conversions turn these symbols into expressions that are -evaluated by the SQL execution engine. A full list of the functions supported can be found in the -[ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD). +Note that if you just call `cache` rather than `cacheTable`, tables will _not_ be cached in +in-memory columnar format. So we strongly recommend using `cacheTable` whenever you want to +cache tables. - +Configuration of in-memory caching can be done using the `setConf` method on SQLContext or by running +`SET key=value` commands using SQL. + + + + + + + + + + + + + + +
Property NameDefaultMeaning
spark.sql.inMemoryColumnarStorage.compressedfalse + When set to true Spark SQL will automatically select a compression codec for each column based + on statistics of the data. +
spark.sql.inMemoryColumnarStorage.batchSize1000 + Controls the size of batches for columnar caching. Larger batch sizes can improve memory utilization + and compression, but risk OOMs when caching data. +
+ +## Other Configuration + +The following options can also be used to tune the performance of query execution. It is possible +that these options will be deprecated in future release as more optimizations are performed automatically. + + + + + + + + + + + + + + + + + + +
Property NameDefaultMeaning
spark.sql.autoBroadcastJoinThresholdfalse + Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when + performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently + statistics are only supported for Hive Metastore tables where the command + `ANALYZE TABLE <tableName> COMPUTE STATISTICS noscan` has been run. +
spark.sql.codegenfalse + When true, code will be dynamically generated at runtime for expression evaluation in a specific + query. For some queries with complicated expression this option can lead to significant speed-ups. + However, for simple queries this can actually slow down query execution. +
spark.sql.shuffle.partitions200 + Configures the number of partitions to use when shuffling data for joins or aggregations. +
+ +# Other SQL Interfaces + +Spark SQL also supports interfaces for running SQL queries directly without the need to write any +code. ## Running the Thrift JDBC server @@ -602,14 +901,28 @@ Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. You may also use the beeline script comes with Hive. +## Running the Spark SQL CLI + +The Spark SQL CLI is a convenient tool to run the Hive metastore service in local mode and execute +queries input from command line. Note: the Spark SQL CLI cannot talk to the Thrift JDBC server. + +To start the Spark SQL CLI, run the following in the Spark directory: + + ./bin/spark-sql + +Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +You may run `./bin/spark-sql --help` for a complete list of all available +options. + +# Compatibility with Other Systems + +## Migration Guide for Shark Users To set a [Fair Scheduler](job-scheduling.html#fair-scheduler-pools) pool for a JDBC client session, users can set the `spark.sql.thriftserver.scheduler.pool` variable: SET spark.sql.thriftserver.scheduler.pool=accounting; -### Migration Guide for Shark Users - -#### Reducer number +### Reducer number In Shark, default reducer number is 1 and is controlled by the property `mapred.reduce.tasks`. Spark SQL deprecates this property by a new property `spark.sql.shuffle.partitions`, whose default value @@ -625,7 +938,7 @@ You may also put this property in `hive-site.xml` to override the default value. For now, the `mapred.reduce.tasks` property is still recognized, and is converted to `spark.sql.shuffle.partitions` automatically. -#### Caching +### Caching The `shark.cache` table property no longer exists, and tables whose name end with `_cached` are no longer automatically cached. Instead, we provide `CACHE TABLE` and `UNCACHE TABLE` statements to @@ -634,9 +947,9 @@ let user control table caching explicitly: CACHE TABLE logs_last_month; UNCACHE TABLE logs_last_month; -**NOTE:** `CACHE TABLE tbl` is lazy, it only marks table `tbl` as "need to by cached if necessary", -but doesn't actually cache it until a query that touches `tbl` is executed. To force the table to be -cached, you may simply count the table immediately after executing `CACHE TABLE`: +**NOTE:** `CACHE TABLE tbl` is lazy, similar to `.cache` on an RDD. This command only marks `tbl` to ensure that +partitions are cached when calculated but doesn't actually cache it until a query that touches `tbl` is executed. +To force the table to be cached, you may simply count the table immediately after executing `CACHE TABLE`: CACHE TABLE logs_last_month; SELECT COUNT(1) FROM logs_last_month; @@ -647,15 +960,18 @@ Several caching related features are not supported yet: * RDD reloading * In-memory cache write through policy -### Compatibility with Apache Hive +## Compatibility with Apache Hive + +Spark SQL is designed to be compatible with the Hive Metastore, SerDes and UDFs. Currently Spark +SQL is based on Hive 0.12.0. #### Deploying in Existing Hive Warehouses -Spark SQL Thrift JDBC server is designed to be "out of the box" compatible with existing Hive +The Spark SQL Thrift JDBC server is designed to be "out of the box" compatible with existing Hive installations. You do not need to modify your existing Hive Metastore or change the data placement or partitioning of your tables. -#### Supported Hive Features +### Supported Hive Features Spark SQL supports the vast majority of Hive features, such as: @@ -705,13 +1021,14 @@ Spark SQL supports the vast majority of Hive features, such as: * `MAP<>` * `STRUCT<>` -#### Unsupported Hive Functionality +### Unsupported Hive Functionality Below is a list of Hive features that we don't support yet. Most of these features are rarely used in Hive deployments. **Major Hive Features** +* Spark SQL does not currently support inserting to tables using dynamic partitioning. * Tables with buckets: bucket is the hash partitioning within a Hive table partition. Spark SQL doesn't support buckets yet. @@ -721,11 +1038,11 @@ in Hive deployments. have the same input format. * Non-equi outer join: For the uncommon use case of using outer joins with non-equi join conditions (e.g. condition "`key < 10`"), Spark SQL will output wrong result for the `NULL` tuple. -* `UNIONTYPE` +* `UNION` type and `DATE` type * Unique join * Single query multi insert * Column statistics collecting: Spark SQL does not piggyback scans to collect column statistics at - the moment. + the moment and only supports populating the sizeInBytes field of the hive metastore. **Hive Input/Output Formats** @@ -735,7 +1052,7 @@ in Hive deployments. **Hive Optimizations** A handful of Hive optimizations are not yet included in Spark. Some of these (such as indexes) are -not necessary due to Spark SQL's in-memory computational model. Others are slotted for future +less important due to Spark SQL's in-memory computational model. Others are slotted for future releases of Spark SQL. * Block level bitmap indexes and virtual columns (used to build indexes) @@ -743,8 +1060,7 @@ releases of Spark SQL. Hive automatically converts the join into a map join. We are adding this auto conversion in the next release. * Automatically determine the number of reducers for joins and groupbys: Currently in Spark SQL, you - need to control the degree of parallelism post-shuffle using "`SET spark.sql.shuffle.partitions=[num_tasks];`". We are going to add auto-setting of parallelism in the - next release. + need to control the degree of parallelism post-shuffle using "`SET spark.sql.shuffle.partitions=[num_tasks];`". * Meta-data only query: For queries that can be answered by using only meta data, Spark SQL still launches tasks to compute the result. * Skew data flag: Spark SQL does not follow the skew data flags in Hive. @@ -753,25 +1069,471 @@ releases of Spark SQL. Hive can optionally merge the small files into fewer large files to avoid overflowing the HDFS metadata. Spark SQL does not support that. -## Running the Spark SQL CLI +# Writing Language-Integrated Relational Queries -The Spark SQL CLI is a convenient tool to run the Hive metastore service in local mode and execute -queries input from command line. Note: the Spark SQL CLI cannot talk to the Thrift JDBC server. +**Language-Integrated queries are experimental and currently only supported in Scala.** -To start the Spark SQL CLI, run the following in the Spark directory: +Spark SQL also supports a domain specific language for writing queries. Once again, +using the data from the above examples: - ./bin/spark-sql +{% highlight scala %} +// sc is an existing SparkContext. +val sqlContext = new org.apache.spark.sql.SQLContext(sc) +// Importing the SQL context gives access to all the public SQL functions and implicit conversions. +import sqlContext._ +val people: RDD[Person] = ... // An RDD of case class objects, from the first example. -Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. -You may run `./bin/spark-sql --help` for a complete list of all available -options. +// The following is the same as 'SELECT name FROM people WHERE age >= 10 AND age <= 19' +val teenagers = people.where('age >= 10).where('age <= 19).select('name) +teenagers.map(t => "Name: " + t(0)).collect().foreach(println) +{% endhighlight %} -# Cached tables +The DSL uses Scala symbols to represent columns in the underlying table, which are identifiers +prefixed with a tick (`'`). Implicit conversions turn these symbols into expressions that are +evaluated by the SQL execution engine. A full list of the functions supported can be found in the +[ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD). -Spark SQL can cache tables using an in-memory columnar format by calling `cacheTable("tableName")`. -Then Spark SQL will scan only required columns and will automatically tune compression to minimize -memory usage and GC pressure. You can call `uncacheTable("tableName")` to remove the table from memory. + + +# Spark SQL DataType Reference + +* Numeric types + - `ByteType`: Represents 1-byte signed integer numbers. + The range of numbers is from `-128` to `127`. + - `ShortType`: Represents 2-byte signed integer numbers. + The range of numbers is from `-32768` to `32767`. + - `IntegerType`: Represents 4-byte signed integer numbers. + The range of numbers is from `-2147483648` to `2147483647`. + - `LongType`: Represents 8-byte signed integer numbers. + The range of numbers is from `-9223372036854775808` to `9223372036854775807`. + - `FloatType`: Represents 4-byte single-precision floating point numbers. + - `DoubleType`: Represents 8-byte double-precision floating point numbers. + - `DecimalType`: +* String type + - `StringType`: Represents character string values. +* Binary type + - `BinaryType`: Represents byte sequence values. +* Boolean type + - `BooleanType`: Represents boolean values. +* Datetime type + - `TimestampType`: Represents values comprising values of fields year, month, day, + hour, minute, and second. +* Complex types + - `ArrayType(elementType, containsNull)`: Represents values comprising a sequence of + elements with the type of `elementType`. `containsNull` is used to indicate if + elements in a `ArrayType` value can have `null` values. + - `MapType(keyType, valueType, valueContainsNull)`: + Represents values comprising a set of key-value pairs. The data type of keys are + described by `keyType` and the data type of values are described by `valueType`. + For a `MapType` value, keys are not allowed to have `null` values. `valueContainsNull` + is used to indicate if values of a `MapType` value can have `null` values. + - `StructType(fields)`: Represents values with the structure described by + a sequence of `StructField`s (`fields`). + * `StructField(name, dataType, nullable)`: Represents a field in a `StructType`. + The name of a field is indicated by `name`. The data type of a field is indicated + by `dataType`. `nullable` is used to indicate if values of this fields can have + `null` values. + +
+
+ +All data types of Spark SQL are located in the package `org.apache.spark.sql`. +You can access them by doing +{% highlight scala %} +import org.apache.spark.sql._ +{% endhighlight %} + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Data typeValue type in ScalaAPI to access or create a data type
ByteType Byte + ByteType +
ShortType Short + ShortType +
IntegerType Int + IntegerType +
LongType Long + LongType +
FloatType Float + FloatType +
DoubleType Double + DoubleType +
DecimalType scala.math.sql.BigDecimal + DecimalType +
StringType String + StringType +
BinaryType Array[Byte] + BinaryType +
BooleanType Boolean + BooleanType +
TimestampType java.sql.Timestamp + TimestampType +
ArrayType scala.collection.Seq + ArrayType(elementType, [containsNull])
+ Note: The default value of containsNull is false. +
MapType scala.collection.Map + MapType(keyType, valueType, [valueContainsNull])
+ Note: The default value of valueContainsNull is true. +
StructType org.apache.spark.sql.Row + StructType(fields)
+ Note: fields is a Seq of StructFields. Also, two fields with the same + name are not allowed. +
StructField The value type in Scala of the data type of this field + (For example, Int for a StructField with the data type IntegerType) + StructField(name, dataType, nullable) +
+ +
+ +
+ +All data types of Spark SQL are located in the package of +`org.apache.spark.sql.api.java`. To access or create a data type, +please use factory methods provided in +`org.apache.spark.sql.api.java.DataType`. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Data typeValue type in JavaAPI to access or create a data type
ByteType byte or Byte + DataType.ByteType +
ShortType short or Short + DataType.ShortType +
IntegerType int or Integer + DataType.IntegerType +
LongType long or Long + DataType.LongType +
FloatType float or Float + DataType.FloatType +
DoubleType double or Double + DataType.DoubleType +
DecimalType java.math.BigDecimal + DataType.DecimalType +
StringType String + DataType.StringType +
BinaryType byte[] + DataType.BinaryType +
BooleanType boolean or Boolean + DataType.BooleanType +
TimestampType java.sql.Timestamp + DataType.TimestampType +
ArrayType java.util.List + DataType.createArrayType(elementType)
+ Note: The value of containsNull will be false
+ DataType.createArrayType(elementType, containsNull). +
MapType java.util.Map + DataType.createMapType(keyType, valueType)
+ Note: The value of valueContainsNull will be true.
+ DataType.createMapType(keyType, valueType, valueContainsNull)
+
StructType org.apache.spark.sql.api.java + DataType.createStructType(fields)
+ Note: fields is a List or an array of StructFields. + Also, two fields with the same name are not allowed. +
StructField The value type in Java of the data type of this field + (For example, int for a StructField with the data type IntegerType) + DataType.createStructField(name, dataType, nullable) +
+ +
+ +
+ +All data types of Spark SQL are located in the package of `pyspark.sql`. +You can access them by doing +{% highlight python %} +from pyspark.sql import * +{% endhighlight %} + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Data typeValue type in PythonAPI to access or create a data type
ByteType + int or long
+ Note: Numbers will be converted to 1-byte signed integer numbers at runtime. + Please make sure that numbers are within the range of -128 to 127. +
+ ByteType() +
ShortType + int or long
+ Note: Numbers will be converted to 2-byte signed integer numbers at runtime. + Please make sure that numbers are within the range of -32768 to 32767. +
+ ShortType() +
IntegerType int or long + IntegerType() +
LongType + long
+ Note: Numbers will be converted to 8-byte signed integer numbers at runtime. + Please make sure that numbers are within the range of + -9223372036854775808 to 9223372036854775807. + Otherwise, please convert data to decimal.Decimal and use DecimalType. +
+ LongType() +
FloatType + float
+ Note: Numbers will be converted to 4-byte single-precision floating + point numbers at runtime. +
+ FloatType() +
DoubleType float + DoubleType() +
DecimalType decimal.Decimal + DecimalType() +
StringType string + StringType() +
BinaryType bytearray + BinaryType() +
BooleanType bool + BooleanType() +
TimestampType datetime.datetime + TimestampType() +
ArrayType list, tuple, or array + ArrayType(elementType, [containsNull])
+ Note: The default value of containsNull is False. +
MapType dict + MapType(keyType, valueType, [valueContainsNull])
+ Note: The default value of valueContainsNull is True. +
StructType list or tuple + StructType(fields)
+ Note: fields is a Seq of StructFields. Also, two fields with the same + name are not allowed. +
StructField The value type in Python of the data type of this field + (For example, Int for a StructField with the data type IntegerType) + StructField(name, dataType, nullable) +
+ +
+ +
-Note that if you just call `cache` rather than `cacheTable`, tables will _not_ be cached in -in-memory columnar format. So we strongly recommend using `cacheTable` whenever you want to -cache tables. diff --git a/docs/storage-openstack-swift.md b/docs/storage-openstack-swift.md new file mode 100644 index 0000000000000..c39ef1ce59e1c --- /dev/null +++ b/docs/storage-openstack-swift.md @@ -0,0 +1,152 @@ +--- +layout: global +title: Accessing OpenStack Swift from Spark +--- + +Spark's support for Hadoop InputFormat allows it to process data in OpenStack Swift using the +same URI formats as in Hadoop. You can specify a path in Swift as input through a +URI of the form swift://container.PROVIDER/path. You will also need to set your +Swift security credentials, through core-site.xml or via +SparkContext.hadoopConfiguration. +Current Swift driver requires Swift to use Keystone authentication method. + +# Configuring Swift for Better Data Locality + +Although not mandatory, it is recommended to configure the proxy server of Swift with +list_endpoints to have better data locality. More information is +[available here](https://github.com/openstack/swift/blob/master/swift/common/middleware/list_endpoints.py). + + +# Dependencies + +The Spark application should include hadoop-openstack dependency. +For example, for Maven support, add the following to the pom.xml file: + +{% highlight xml %} + + ... + + org.apache.hadoop + hadoop-openstack + 2.3.0 + + ... + +{% endhighlight %} + + +# Configuration Parameters + +Create core-site.xml and place it inside Spark's conf directory. +There are two main categories of parameters that should to be configured: declaration of the +Swift driver and the parameters that are required by Keystone. + +Configuration of Hadoop to use Swift File system achieved via + + + + + + + +
Property NameValue
fs.swift.implorg.apache.hadoop.fs.swift.snative.SwiftNativeFileSystem
+ +Additional parameters required by Keystone (v2.0) and should be provided to the Swift driver. Those +parameters will be used to perform authentication in Keystone to access Swift. The following table +contains a list of Keystone mandatory parameters. PROVIDER can be any name. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Property NameMeaningRequired
fs.swift.service.PROVIDER.auth.urlKeystone Authentication URLMandatory
fs.swift.service.PROVIDER.auth.endpoint.prefixKeystone endpoints prefixOptional
fs.swift.service.PROVIDER.tenantTenantMandatory
fs.swift.service.PROVIDER.usernameUsernameMandatory
fs.swift.service.PROVIDER.passwordPasswordMandatory
fs.swift.service.PROVIDER.http.portHTTP portMandatory
fs.swift.service.PROVIDER.regionKeystone regionMandatory
fs.swift.service.PROVIDER.publicIndicates if all URLs are publicMandatory
+ +For example, assume PROVIDER=SparkTest and Keystone contains user tester with password testing +defined for tenant test. Then core-site.xml should include: + +{% highlight xml %} + + + fs.swift.impl + org.apache.hadoop.fs.swift.snative.SwiftNativeFileSystem + + + fs.swift.service.SparkTest.auth.url + http://127.0.0.1:5000/v2.0/tokens + + + fs.swift.service.SparkTest.auth.endpoint.prefix + endpoints + + fs.swift.service.SparkTest.http.port + 8080 + + + fs.swift.service.SparkTest.region + RegionOne + + + fs.swift.service.SparkTest.public + true + + + fs.swift.service.SparkTest.tenant + test + + + fs.swift.service.SparkTest.username + tester + + + fs.swift.service.SparkTest.password + testing + + +{% endhighlight %} + +Notice that +fs.swift.service.PROVIDER.tenant, +fs.swift.service.PROVIDER.username, +fs.swift.service.PROVIDER.password contains sensitive information and keeping them in +core-site.xml is not always a good approach. +We suggest to keep those parameters in core-site.xml for testing purposes when running Spark +via spark-shell. +For job submissions they should be provided via sparkContext.hadoopConfiguration. diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md new file mode 100644 index 0000000000000..d57c3e0ef9ba0 --- /dev/null +++ b/docs/streaming-flume-integration.md @@ -0,0 +1,132 @@ +--- +layout: global +title: Spark Streaming + Flume Integration Guide +--- + +[Apache Flume](https://flume.apache.org/) is a distributed, reliable, and available service for efficiently collecting, aggregating, and moving large amounts of log data. Here we explain how to configure Flume and Spark Streaming to receive data from Flume. There are two approaches to this. + +## Approach 1: Flume-style Push-based Approach +Flume is designed to push data between Flume agents. In this approach, Spark Streaming essentially sets up a receiver that acts an Avro agent for Flume, to which Flume can push the data. Here are the configuration steps. + +#### General Requirements +Choose a machine in your cluster such that + +- When your Flume + Spark Streaming application is launched, one of the Spark workers must run on that machine. + +- Flume can be configured to push data to a port on that machine. + +Due to the push model, the streaming application needs to be up, with the receiver scheduled and listening on the chosen port, for Flume to be able push data. + +#### Configuring Flume +Configure Flume agent to send data to an Avro sink by having the following in the configuration file. + + agent.sinks = avroSink + agent.sinks.avroSink.type = avro + agent.sinks.avroSink.channel = memoryChannel + agent.sinks.avroSink.hostname = + agent.sinks.avroSink.port = + +See the [Flume's documentation](https://flume.apache.org/documentation.html) for more information about +configuring Flume agents. + +#### Configuring Spark Streaming Application +1. **Linking:** In your SBT/Maven projrect definition, link your streaming application against the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). + + groupId = org.apache.spark + artifactId = spark-streaming-flume_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} + +2. **Programming:** In the streaming application code, import `FlumeUtils` and create input DStream as follows. + +
+
+ import org.apache.spark.streaming.flume._ + + val flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]) + + See the [API docs](api/scala/index.html#org.apache.spark.streaming.flume.FlumeUtils$) + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala). +
+
+ import org.apache.spark.streaming.flume.*; + + JavaReceiverInputDStream flumeStream = + FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]); + + See the [API docs](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html) + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java). +
+
+ + Note that the hostname should be the same as the one used by the resource manager in the + cluster (Mesos, YARN or Spark Standalone), so that resource allocation can match the names and launch + the receiver in the right machine. + +3. **Deploying:** Package `spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). + +## Approach 2 (Experimental): Pull-based Approach using a Custom Sink +Instead of Flume pushing data directly to Spark Streaming, this approach runs a custom Flume sink that allows the following. +- Flume pushes data into the sink, and the data stays buffered. +- Spark Streaming uses transactions to pull data from the sink. Transactions succeed only after data is received and replicated by Spark Streaming. +This ensures that better reliability and fault-tolerance than the previous approach. However, this requires configuring Flume to run a custom sink. Here are the configuration steps. + +#### General Requirements +Choose a machine that will run the custom sink in a Flume agent. The rest of the Flume pipeline is configured to send data to that agent. Machines in the Spark cluster should have access to the chosen machine running the custom sink. + +#### Configuring Flume +Configuring Flume on the chosen machine requires the following two steps. + +1. **Sink JARs**: Add the following JARs to Flume's classpath (see [Flume's documentation](https://flume.apache.org/documentation.html) to see how) in the machine designated to run the custom sink . + + (i) *Custom sink JAR*: Download the JAR corresponding to the following artifact (or [direct link](http://search.maven.org/remotecontent?filepath=org/apache/spark/spark-streaming-flume-sink_{{site.SCALA_BINARY_VERSION}}/{{site.SPARK_VERSION_SHORT}}/spark-streaming-flume-sink_{{site.SCALA_BINARY_VERSION}}-{{site.SPARK_VERSION_SHORT}}.jar)). + + groupId = org.apache.spark + artifactId = spark-streaming-flume-sink_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} + + (ii) *Scala library JAR*: Download the Scala library JAR for Scala {{site.SCALA_VERSION}}. It can be found with the following artifact detail (or, [direct link](http://search.maven.org/remotecontent?filepath=org/scala-lang/scala-library/{{site.SCALA_VERSION}}/scala-library-{{site.SCALA_VERSION}}.jar)). + + groupId = org.scala-lang + artifactId = scala-library + version = {{site.SCALA_VERSION}} + +2. **Configuration file**: On that machine, configure Flume agent to send data to an Avro sink by having the following in the configuration file. + + agent.sinks = spark + agent.sinks.spark.type = org.apache.spark.streaming.flume.sink.SparkSink + agent.sinks.spark.hostname = + agent.sinks.spark.port = + agent.sinks.spark.channel = memoryChannel + + Also make sure that the upstream Flume pipeline is configured to send the data to the Flume agent running this sink. + +See the [Flume's documentation](https://flume.apache.org/documentation.html) for more information about +configuring Flume agents. + +#### Configuring Spark Streaming Application +1. **Linking:** In your SBT/Maven projrect definition, link your streaming application against the `spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}` (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide). + +2. **Programming:** In the streaming application code, import `FlumeUtils` and create input DStream as follows. + +
+
+ import org.apache.spark.streaming.flume._ + + val flumeStream = FlumeUtils.createPollingStream(streamingContext, [sink machine hostname], [sink port]) +
+
+ import org.apache.spark.streaming.flume.*; + + JavaReceiverInputDStreamflumeStream = + FlumeUtils.createPollingStream(streamingContext, [sink machine hostname], [sink port]); +
+
+ + See the Scala example [FlumePollingEventCount]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala). + + Note that each input DStream can be configured to receive data from multiple sinks. + +3. **Deploying:** Package `spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). + + + diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md new file mode 100644 index 0000000000000..a3b705d4c31d0 --- /dev/null +++ b/docs/streaming-kafka-integration.md @@ -0,0 +1,42 @@ +--- +layout: global +title: Spark Streaming + Kafka Integration Guide +--- +[Apache Kafka](http://kafka.apache.org/) is publish-subscribe messaging rethought as a distributed, partitioned, replicated commit log service. Here we explain how to configure Spark Streaming to receive data from Kafka. + +1. **Linking:** In your SBT/Maven projrect definition, link your streaming application against the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). + + groupId = org.apache.spark + artifactId = spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} + +2. **Programming:** In the streaming application code, import `KafkaUtils` and create input DStream as follows. + +
+
+ import org.apache.spark.streaming.kafka._ + + val kafkaStream = KafkaUtils.createStream( + streamingContext, [zookeeperQuorum], [group id of the consumer], [per-topic number of Kafka partitions to consume]) + + See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala). +
+
+ import org.apache.spark.streaming.kafka.*; + + JavaPairReceiverInputDStream kafkaStream = KafkaUtils.createStream( + streamingContext, [zookeeperQuorum], [group id of the consumer], [per-topic number of Kafka partitions to consume]); + + See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java). +
+
+ + *Points to remember:* + + - Topic partitions in Kafka does not correlate to partitions of RDDs generated in Spark Streaming. So increasing the number of topic-specific partitions in the `KafkaUtils.createStream()` only increases the number of threads using which topics that are consumed within a single receiver. It does not increase the parallelism of Spark in processing the data. Refer to the main document for more information on that. + + - Multiple Kafka input DStreams can be created with different groups and topics for parallel receiving of data using multiple receivers. + +3. **Deploying:** Package `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md new file mode 100644 index 0000000000000..c6090d9ec30c7 --- /dev/null +++ b/docs/streaming-kinesis-integration.md @@ -0,0 +1,150 @@ +--- +layout: global +title: Spark Streaming + Kinesis Integration +--- +[Amazon Kinesis](http://aws.amazon.com/kinesis/) is a fully managed service for real-time processing of streaming data at massive scale. +The Kinesis receiver creates an input DStream using the Kinesis Client Library (KCL) provided by Amazon under the Amazon Software License (ASL). +The KCL builds on top of the Apache 2.0 licensed AWS Java SDK and provides load-balancing, fault-tolerance, checkpointing through the concepts of Workers, Checkpoints, and Shard Leases. +Here we explain how to configure Spark Streaming to receive data from Kinesis. + +#### Configuring Kinesis + +A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or more shards per the following +[guide](http://docs.aws.amazon.com/kinesis/latest/dev/step-one-create-stream.html). + + +#### Configuring Spark Streaming Application + +1. **Linking:** In your SBT/Maven project definition, link your streaming application against the following artifact (see [Linking section](streaming-programming-guide.html#linking) in the main programming guide for further information). + + groupId = org.apache.spark + artifactId = spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}} + version = {{site.SPARK_VERSION_SHORT}} + + **Note that by linking to this library, you will include [ASL](https://aws.amazon.com/asl/)-licensed code in your application.** + +2. **Programming:** In the streaming application code, import `KinesisUtils` and create the input DStream as follows: + +
+
+ import org.apache.spark.streaming.Duration + import org.apache.spark.streaming.kinesis._ + import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream + + val kinesisStream = KinesisUtils.createStream( + streamingContext, [Kinesis stream name], [endpoint URL], [checkpoint interval], [initial position]) + + See the [API docs](api/scala/index.html#org.apache.spark.streaming.kinesis.KinesisUtils$) + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala). Refer to the Running the Example section for instructions on how to run the example. + +
+
+ import org.apache.spark.streaming.Duration; + import org.apache.spark.streaming.kinesis.*; + import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; + + JavaReceiverInputDStream kinesisStream = KinesisUtils.createStream( + streamingContext, [Kinesis stream name], [endpoint URL], [checkpoint interval], [initial position]); + + See the [API docs](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html) + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java). Refer to the next subsection for instructions to run the example. + +
+
+ + - `streamingContext`: StreamingContext containg an application name used by Kinesis to tie this Kinesis application to the Kinesis stream + + - `[Kinesis stream name]`: The Kinesis stream that this streaming application receives from + - The application name used in the streaming context becomes the Kinesis application name + - The application name must be unique for a given account and region. + - The Kinesis backend automatically associates the application name to the Kinesis stream using a DynamoDB table (always in the us-east-1 region) created during Kinesis Client Library initialization. + - Changing the application name or stream name can lead to Kinesis errors in some cases. If you see errors, you may need to manually delete the DynamoDB table. + + + - `[endpoint URL]`: Valid Kinesis endpoints URL can be found [here](http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region). + + - `[checkpoint interval]`: The interval (e.g., Duration(2000) = 2 seconds) at which the Kinesis Client Library saves its position in the stream. For starters, set it to the same as the batch interval of the streaming application. + + - `[initial position]`: Can be either `InitialPositionInStream.TRIM_HORIZON` or `InitialPositionInStream.LATEST` (see Kinesis Checkpointing section and Amazon Kinesis API documentation for more details). + + +3. **Deploying:** Package `spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). + + *Points to remember at runtime:* + + - Kinesis data processing is ordered per partition and occurs at-least once per message. + + - Multiple applications can read from the same Kinesis stream. Kinesis will maintain the application-specific shard and checkpoint info in DynamodDB. + + - A single Kinesis stream shard is processed by one input DStream at a time. + +

+ Spark Streaming Kinesis Architecture + +

+ + - A single Kinesis input DStream can read from multiple shards of a Kinesis stream by creating multiple KinesisRecordProcessor threads. + + - Multiple input DStreams running in separate processes/instances can read from a Kinesis stream. + + - You never need more Kinesis input DStreams than the number of Kinesis stream shards as each input DStream will create at least one KinesisRecordProcessor thread that handles a single shard. + + - Horizontal scaling is achieved by adding/removing Kinesis input DStreams (within a single process or across multiple processes/instances) - up to the total number of Kinesis stream shards per the previous point. + + - The Kinesis input DStream will balance the load between all DStreams - even across processes/instances. + + - The Kinesis input DStream will balance the load during re-shard events (merging and splitting) due to changes in load. + + - As a best practice, it's recommended that you avoid re-shard jitter by over-provisioning when possible. + + - Each Kinesis input DStream maintains its own checkpoint info. See the Kinesis Checkpointing section for more details. + + - There is no correlation between the number of Kinesis stream shards and the number of RDD partitions/shards created across the Spark cluster during input DStream processing. These are 2 independent partitioning schemes. + +#### Running the Example +To run the example, + +- Download Spark source and follow the [instructions](building-with-maven.html) to build Spark with profile *-Pkinesis-asl*. + + mvn -Pkinesis-asl -DskipTests clean package + + +- Set up Kinesis stream (see earlier section) within AWS. Note the name of the Kinesis stream and the endpoint URL corresponding to the region where the stream was created. + +- Set up the environment variables AWS_ACCESS_KEY_ID and AWS_SECRET_KEY with your AWS credentials. + +- In the Spark root directory, run the example as + +
+
+ + bin/run-example streaming.KinesisWordCountASL [Kinesis stream name] [endpoint URL] + +
+
+ + bin/run-example streaming.JavaKinesisWordCountASL [Kinesis stream name] [endpoint URL] + +
+
+ + This will wait for data to be received from the Kinesis stream. + +- To generate random string data to put onto the Kinesis stream, in another terminal, run the associated Kinesis data producer. + + bin/run-example streaming.KinesisWordCountProducerASL [Kinesis stream name] [endpoint URL] 1000 10 + + This will push 1000 lines per second of 10 random numbers per line to the Kinesis stream. This data should then be received and processed by the running example. + +#### Kinesis Checkpointing +- Each Kinesis input DStream periodically stores the current position of the stream in the backing DynamoDB table. This allows the system to recover from failures and continue processing where the DStream left off. + +- Checkpointing too frequently will cause excess load on the AWS checkpoint storage layer and may lead to AWS throttling. The provided example handles this throttling with a random-backoff-retry strategy. + +- If no Kinesis checkpoint info exists when the input DStream starts, it will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) or from the latest tip (InitialPostitionInStream.LATEST). This is configurable. +- InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no input DStreams are running (and no checkpoint info is being stored). +- InitialPositionInStream.TRIM_HORIZON may lead to duplicate processing of records where the impact is dependent on checkpoint frequency and processing idempotency. diff --git a/docs/streaming-kinesis.md b/docs/streaming-kinesis.md deleted file mode 100644 index 16ad3222105a2..0000000000000 --- a/docs/streaming-kinesis.md +++ /dev/null @@ -1,59 +0,0 @@ ---- -layout: global -title: Spark Streaming Kinesis Receiver ---- - -## Kinesis -###Design -
  • The KinesisReceiver uses the Kinesis Client Library (KCL) provided by Amazon under the Amazon Software License.
  • -
  • The KCL builds on top of the Apache 2.0 licensed AWS Java SDK and provides load-balancing, fault-tolerance, checkpointing through the concept of Workers, Checkpoints, and Shard Leases.
  • -
  • The KCL uses DynamoDB to maintain all state. A DynamoDB table is created in the us-east-1 region (regardless of Kinesis stream region) during KCL initialization for each Kinesis application name.
  • -
  • A single KinesisReceiver can process many shards of a stream by spinning up multiple KinesisRecordProcessor threads.
  • -
  • You never need more KinesisReceivers than the number of shards in your stream as each will spin up at least one KinesisRecordProcessor thread.
  • -
  • Horizontal scaling is achieved by autoscaling additional KinesisReceiver (separate processes) or spinning up new KinesisRecordProcessor threads within each KinesisReceiver - up to the number of current shards for a given stream, of course. Don't forget to autoscale back down!
  • - -### Build -
  • Spark supports a Streaming KinesisReceiver, but it is not included in the default build due to Amazon Software Licensing (ASL) restrictions.
  • -
  • To build with the Kinesis Streaming Receiver and supporting ASL-licensed code, you must run the maven or sbt builds with the **-Pkinesis-asl** profile.
  • -
  • All KinesisReceiver-related code, examples, tests, and artifacts live in **$SPARK_HOME/extras/kinesis-asl/**.
  • -
  • Kinesis-based Spark Applications will need to link to the **spark-streaming-kinesis-asl** artifact that is built when **-Pkinesis-asl** is specified.
  • -
  • _**Note that by linking to this library, you will include [ASL](https://aws.amazon.com/asl/)-licensed code in your Spark package**_.
  • - -###Example -
  • To build the Kinesis example, you must run the maven or sbt builds with the **-Pkinesis-asl** profile.
  • -
  • You need to setup a Kinesis stream at one of the valid Kinesis endpoints with 1 or more shards per the following: http://docs.aws.amazon.com/kinesis/latest/dev/step-one-create-stream.html
  • -
  • Valid Kinesis endpoints can be found here: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region
  • -
  • When running **locally**, the example automatically determines the number of threads and KinesisReceivers to spin up based on the number of shards configured for the stream. Therefore, **local[n]** is not needed when starting the example as with other streaming examples.
  • -
  • While this example could use a single KinesisReceiver which spins up multiple KinesisRecordProcessor threads to process multiple shards, I wanted to demonstrate unioning multiple KinesisReceivers as a single DStream. (It's a bit confusing in local mode.)
  • -
  • **KinesisWordCountProducerASL** is provided to generate random records into the Kinesis stream for testing.
  • -
  • The example has been configured to immediately replicate incoming stream data to another node by using (StorageLevel.MEMORY_AND_DISK_2) -
  • Spark checkpointing is disabled because the example does not use any stateful or window-based DStream operations such as updateStateByKey and reduceByWindow. If those operations are introduced, you would need to enable checkpointing or risk losing data in the case of a failure.
  • -
  • Kinesis checkpointing is enabled. This means that the example will recover from a Kinesis failure.
  • -
  • The example uses InitialPositionInStream.LATEST strategy to pull from the latest tip of the stream if no Kinesis checkpoint info exists.
  • -
  • In our example, **KinesisWordCount** is the Kinesis application name for both the Scala and Java versions. The use of this application name is described next.
  • - -###Deployment and Runtime -
  • A Kinesis application name must be unique for a given account and region.
  • -
  • A DynamoDB table and CloudWatch namespace are created during KCL initialization using this Kinesis application name. http://docs.aws.amazon.com/kinesis/latest/dev/kinesis-record-processor-implementation-app.html#kinesis-record-processor-initialization
  • -
  • This DynamoDB table lives in the us-east-1 region regardless of the Kinesis endpoint URL.
  • -
  • Changing the app name or stream name could lead to Kinesis errors as only a single logical application can process a single stream.
  • -
  • If you are seeing errors after changing the app name or stream name, it may be necessary to manually delete the DynamoDB table and start from scratch.
  • -
  • The Kinesis libraries must be present on all worker nodes, as they will need access to the KCL.
  • -
  • The KinesisReceiver uses the DefaultAWSCredentialsProviderChain for AWS credentials which searches for credentials in the following order of precedence:
    -1) Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY
    -2) Java System Properties - aws.accessKeyId and aws.secretKey
    -3) Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs
    -4) Instance profile credentials - delivered through the Amazon EC2 metadata service -
  • - -###Fault-Tolerance -
  • The combination of Spark Streaming and Kinesis creates 2 different checkpoints that may occur at different intervals.
  • -
  • Checkpointing too frequently against Kinesis will cause excess load on the AWS checkpoint storage layer and may lead to AWS throttling. The provided example handles this throttling with a random backoff retry strategy.
  • -
  • Upon startup, a KinesisReceiver will begin processing records with sequence numbers greater than the last Kinesis checkpoint sequence number recorded per shard (stored in the DynamoDB table).
  • -
  • If no Kinesis checkpoint info exists, the KinesisReceiver will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) or from the latest tip (InitialPostitionInStream.LATEST). This is configurable.
  • -
  • InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no KinesisReceivers are running (and no checkpoint info is being stored.)
  • -
  • In production, you'll want to switch to InitialPositionInStream.TRIM_HORIZON which will read up to 24 hours (Kinesis limit) of previous stream data.
  • -
  • InitialPositionInStream.TRIM_HORIZON may lead to duplicate processing of records where the impact is dependent on checkpoint frequency.
  • -
  • Record processing should be idempotent when possible.
  • -
  • A failed or latent KinesisRecordProcessor within the KinesisReceiver will be detected and automatically restarted by the KCL.
  • -
  • If possible, the KinesisReceiver should be shutdown cleanly in order to trigger a final checkpoint of all KinesisRecordProcessors to avoid duplicate record processing.
  • \ No newline at end of file diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 9f331ed50d2a4..41f170580f452 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -7,12 +7,12 @@ title: Spark Streaming Programming Guide {:toc} # Overview -Spark Streaming is an extension of the core Spark API that allows enables high-throughput, +Spark Streaming is an extension of the core Spark API that allows enables scalable, high-throughput, fault-tolerant stream processing of live data streams. Data can be ingested from many sources like Kafka, Flume, Twitter, ZeroMQ, Kinesis or plain old TCP sockets and be processed using complex algorithms expressed with high-level functions like `map`, `reduce`, `join` and `window`. Finally, processed data can be pushed out to filesystems, databases, -and live dashboards. In fact, you can apply Spark's in-built +and live dashboards. In fact, you can apply Spark's [machine learning](mllib-guide.html) algorithms, and [graph processing](graphx-programming-guide.html) algorithms on data streams. @@ -60,35 +60,24 @@ do is as follows.
    First, we import the names of the Spark Streaming classes, and some implicit conversions from StreamingContext into our environment, to add useful methods to -other classes we need (like DStream). - -[StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) is the -main entry point for all streaming functionality. +other classes we need (like DStream). [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) is the +main entry point for all streaming functionality. We create a local StreamingContext with two execution threads, and batch interval of 1 second. {% highlight scala %} +import org.apache.spark._ import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext._ -{% endhighlight %} - -Then we create a -[StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) object. -Besides Spark's configuration, we specify that any DStream will be processed -in 1 second batches. -{% highlight scala %} -import org.apache.spark.api.java.function._ -import org.apache.spark.streaming._ -import org.apache.spark.streaming.api._ -// Create a StreamingContext with a local master -// Spark Streaming needs at least two working thread -val ssc = new StreamingContext("local[2]", "NetworkWordCount", Seconds(1)) +// Create a local StreamingContext with two working thread and batch interval of 1 second +val conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount") +val ssc = new StreamingContext(conf, Seconds(1)) {% endhighlight %} -Using this context, we then create a new DStream -by specifying the IP address and port of the data server. +Using this context, we can create a DStream that represents streaming data from a TCP +source hostname, e.g. `localhost`, and port, e.g. `9999` {% highlight scala %} -// Create a DStream that will connect to serverIP:serverPort, like localhost:9999 +// Create a DStream that will connect to hostname:port, like localhost:9999 val lines = ssc.socketTextStream("localhost", 9999) {% endhighlight %} @@ -112,7 +101,7 @@ import org.apache.spark.streaming.StreamingContext._ val pairs = words.map(word => (word, 1)) val wordCounts = pairs.reduceByKey(_ + _) -// Print a few of the counts to the console +// Print the first ten elements of each RDD generated in this DStream to the console wordCounts.print() {% endhighlight %} @@ -139,23 +128,25 @@ The complete code can be found in the Spark Streaming example First, we create a [JavaStreamingContext](api/java/index.html?org/apache/spark/streaming/api/java/JavaStreamingContext.html) object, which is the main entry point for all streaming -functionality. Besides Spark's configuration, we specify that any DStream would be processed -in 1 second batches. +functionality. We create a local StreamingContext with two execution threads, and a batch interval of 1 second. {% highlight java %} +import org.apache.spark.*; import org.apache.spark.api.java.function.*; import org.apache.spark.streaming.*; import org.apache.spark.streaming.api.java.*; import scala.Tuple2; -// Create a StreamingContext with a local master -JavaStreamingContext jssc = new JavaStreamingContext("local[2]", "JavaNetworkWordCount", new Duration(1000)) + +// Create a local StreamingContext with two working thread and batch interval of 1 second +val conf = new SparkConf().setMaster("local[2]").setAppName("NetworkWordCount") +JavaStreamingContext jssc = new JavaStreamingContext(conf, new Duration(1000)) {% endhighlight %} -Using this context, we then create a new DStream -by specifying the IP address and port of the data server. +Using this context, we can create a DStream that represents streaming data from a TCP +source hostname, e.g. `localhost`, and port, e.g. `9999` {% highlight java %} -// Create a DStream that will connect to serverIP:serverPort, like localhost:9999 +// Create a DStream that will connect to hostname:port, like localhost:9999 JavaReceiverInputDStream lines = jssc.socketTextStream("localhost", 9999); {% endhighlight %} @@ -197,7 +188,9 @@ JavaPairDStream wordCounts = pairs.reduceByKey( return i1 + i2; } }); -wordCounts.print(); // Print a few of the counts to the console + +// Print the first ten elements of each RDD generated in this DStream to the console +wordCounts.print(); {% endhighlight %} The `words` DStream is further mapped (one-to-one transformation) to a DStream of `(word, @@ -207,8 +200,8 @@ using a [Function2](api/scala/index.html#org.apache.spark.api.java.function.Func Finally, `wordCounts.print()` will print a few of the counts generated every second. Note that when these lines are executed, Spark Streaming only sets up the computation it -will perform when it is started, and no real processing has started yet. To start the processing -after all the transformations have been setup, we finally call +will perform after it is started, and no real processing has started yet. To start the processing +after all the transformations have been setup, we finally call `start` method. {% highlight java %} jssc.start(); // Start the computation @@ -235,12 +228,12 @@ Then, in a different terminal, you can start the example by using
    {% highlight bash %} -$ ./bin/run-example org.apache.spark.examples.streaming.NetworkWordCount localhost 9999 +$ ./bin/run-example streaming.NetworkWordCount localhost 9999 {% endhighlight %}
    {% highlight bash %} -$ ./bin/run-example org.apache.spark.examples.streaming.JavaNetworkWordCount localhost 9999 +$ ./bin/run-example streaming.JavaNetworkWordCount localhost 9999 {% endhighlight %}
    @@ -269,7 +262,7 @@ hello world {% highlight bash %} # TERMINAL 2: RUNNING NetworkWordCount or JavaNetworkWordCount -$ ./bin/run-example org.apache.spark.examples.streaming.NetworkWordCount localhost 9999 +$ ./bin/run-example streaming.NetworkWordCount localhost 9999 ... ------------------------------------------- Time: 1357008430000 ms @@ -281,37 +274,33 @@ Time: 1357008430000 ms -You can also use Spark Streaming directly from the Spark shell: - -{% highlight bash %} -$ bin/spark-shell -{% endhighlight %} - -... and create your StreamingContext by wrapping the existing interactive shell -SparkContext object, `sc`: - -{% highlight scala %} -val ssc = new StreamingContext(sc, Seconds(1)) -{% endhighlight %} - -When working with the shell, you may also need to send a `^D` to your netcat session -to force the pipeline to print the word counts to the console at the sink. -*************************************************************************************************** +*************************************************************************************************** +*************************************************************************************************** -# Basics +# Basic Concepts Next, we move beyond the simple example and elaborate on the basics of Spark Streaming that you need to know to write your streaming applications. ## Linking -To write your own Spark Streaming program, you will have to add the following dependency to your - SBT or Maven project: +Similar to Spark, Spark Streaming is available through Maven Central. To write your own Spark Streaming program, you will have to add the following dependency to your SBT or Maven project. + +
    +
    + + + org.apache.spark + spark-streaming_{{site.SCALA_BINARY_VERSION}} + {{site.SPARK_VERSION}} + +
    +
    - groupId = org.apache.spark - artifactId = spark-streaming_{{site.SCALA_BINARY_VERSION}} - version = {{site.SPARK_VERSION}} + libraryDependencies += "org.apache.spark" % "spark-streaming_{{site.SCALA_BINARY_VERSION}}" % "{{site.SPARK_VERSION}}" +
    +
    For ingesting data from sources like Kafka, Flume, and Kinesis that are not present in the Spark Streaming core @@ -319,68 +308,120 @@ Streaming core artifact `spark-streaming-xyz_{{site.SCALA_BINARY_VERSION}}` to the dependencies. For example, some of the common ones are as follows. - + - - +
    SourceArtifact
    Kafka spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}
    Flume spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}
    Kinesis
    spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}} [Apache Software License]
    Twitter spark-streaming-twitter_{{site.SCALA_BINARY_VERSION}}
    ZeroMQ spark-streaming-zeromq_{{site.SCALA_BINARY_VERSION}}
    MQTT spark-streaming-mqtt_{{site.SCALA_BINARY_VERSION}}
    Kinesis
    (built separately)
    kinesis-asl_{{site.SCALA_BINARY_VERSION}}
    For an up-to-date list, please refer to the -[Apache repository](http://search.maven.org/#search%7Cga%7C1%7Cg%3A%22org.apache.spark%22%20AND%20v%3A%22{{site.SPARK_VERSION}}%22) +[Apache repository](http://search.maven.org/#search%7Cga%7C1%7Cg%3A%22org.apache.spark%22%20AND%20v%3A%22{{site.SPARK_VERSION_SHORT}}%22) for the full list of supported sources and artifacts. -## Initializing +*** + +## Initializing StreamingContext + +To initialize a Spark Streaming program, a **StreamingContext** object has to be created which is the main entry point of all Spark Streaming functionality.
    -To initialize a Spark Streaming program in Scala, a -[`StreamingContext`](api/scala/index.html#org.apache.spark.streaming.StreamingContext) -object has to be created, which is the main entry point of all Spark Streaming functionality. -A `StreamingContext` object can be created by using +A [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) object can be created from a [SparkConf](api/scala/index.html#org.apache.spark.SparkConf) object. {% highlight scala %} -new StreamingContext(master, appName, batchDuration, [sparkHome], [jars]) +import org.apache.spark._ +import org.apache.spark.streaming._ + +val conf = new SparkConf().setAppName(appName).setMaster(master) +val ssc = new StreamingContext(conf, Seconds(1)) {% endhighlight %} -
    -
    -To initialize a Spark Streaming program in Java, a -[`JavaStreamingContext`](api/scala/index.html#org.apache.spark.streaming.api.java.JavaStreamingContext) -object has to be created, which is the main entry point of all Spark Streaming functionality. -A `JavaStreamingContext` object can be created by using +The `appName` parameter is a name for your application to show on the cluster UI. +`master` is a [Spark, Mesos or YARN cluster URL](submitting-applications.html#master-urls), +or a special __"local[\*]"__ string to run in local mode. In practice, when running on a cluster, +you will not want to hardcode `master` in the program, +but rather [launch the application with `spark-submit`](submitting-applications.html) and +receive it there. However, for local testing and unit tests, you can pass "local[\*]" to run Spark Streaming +in-process (detects the number of cores in the local system). Note that this internally creates a [SparkContext](api/scala/index.html#org.apache.spark.SparkContext) (starting point of all Spark functionality) which can be accessed as `ssc.sparkContext`. + +The batch interval must be set based on the latency requirements of your application +and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size) +section for more details. + +A `StreamingContext` object can also be created from an existing `SparkContext` object. {% highlight scala %} -new JavaStreamingContext(master, appName, batchInterval, [sparkHome], [jars]) +import org.apache.spark.streaming._ + +val sc = ... // existing SparkContext +val ssc = new StreamingContext(sc, Seconds(1)) {% endhighlight %} + +
    -
    +
    + +A [JavaStreamingContext](api/java/index.html?org/apache/spark/streaming/api/java/JavaStreamingContext.html) object can be created from a [SparkConf](api/java/index.html?org/apache/spark/SparkConf.html) object. + +{% highlight java %} +import org.apache.spark.*; +import org.apache.spark.streaming.api.java.*; -The `master` parameter is a standard [Spark cluster URL](programming-guide.html#master-urls) -and can be "local" for local testing. The `appName` is a name of your program, -which will be shown on your cluster's web UI. The `batchInterval` is the size of the batches, -as explained earlier. Finally, the last two parameters are needed to deploy your code to a cluster - if running in distributed mode, as described in the - [Spark programming guide](programming-guide.html#deploying-code-on-a-cluster). - Additionally, the underlying SparkContext can be accessed as -`ssc.sparkContext`. +SparkConf conf = new SparkConf().setAppName(appName).setMaster(master); +JavaStreamingContext ssc = new JavaStreamingContext(conf, Duration(1000)); +{% endhighlight %} + +The `appName` parameter is a name for your application to show on the cluster UI. +`master` is a [Spark, Mesos or YARN cluster URL](submitting-applications.html#master-urls), +or a special __"local[\*]"__ string to run in local mode. In practice, when running on a cluster, +you will not want to hardcode `master` in the program, +but rather [launch the application with `spark-submit`](submitting-applications.html) and +receive it there. However, for local testing and unit tests, you can pass "local[*]" to run Spark Streaming +in-process. Note that this internally creates a [JavaSparkContext](api/java/index.html?org/apache/spark/api/java/JavaSparkContext.html) (starting point of all Spark functionality) which can be accessed as `ssc.sparkContext`. The batch interval must be set based on the latency requirements of your application and available cluster resources. See the [Performance Tuning](#setting-the-right-batch-size) section for more details. -## DStreams -*Discretized Stream* or *DStream* is the basic abstraction provided by Spark Streaming. +A `JavaStreamingContext` object can also be created from an existing `JavaSparkContext`. + +{% highlight java %} +import org.apache.spark.streaming.api.java.*; + +JavaSparkContext sc = ... //existing JavaSparkContext +JavaStreamingContext ssc = new JavaStreamingContext(sc, new Duration(1000)); +{% endhighlight %} +
    +
    + +After a context is defined, you have to do the follow steps. +1. Define the input sources. +1. Setup the streaming computations. +1. Start the receiving and procesing of data using `streamingContext.start()`. +1. The processing will continue until `streamingContext.stop()` is called. + +##### Points to remember: +{:.no_toc} +- Once a context has been started, no new streaming computations can be setup or added to it. +- Once a context has been stopped, it cannot be started (that is, re-used) again. +- Only one StreamingContext can be active in a JVM at the same time. +- stop() on StreamingContext also stops the SparkContext. To stop only the StreamingContext, set optional parameter of `stop()` called `stopSparkContext` to false. +- A SparkContext can be re-used to create multiple StreamingContexts, as long as the previous StreamingContext is stopped (without stopping the SparkContext) before the next StreamingContext is created. + +*** + +## Discretized Streams (DStreams) +**Discretized Stream** or **DStream** is the basic abstraction provided by Spark Streaming. It represents a continuous stream of data, either the input data stream received from source, or the processed data stream generated by transforming the input stream. Internally, -it is represented by a continuous sequence of RDDs, which is Spark's abstraction of an immutable, -distributed dataset. Each RDD in a DStream contains data from a certain interval, +a DStream is represented by a continuous series of RDDs, which is Spark's abstraction of an immutable, +distributed dataset (see [Spark Programming Guide](programming-guide.html#resilient-distributed-datasets-rdds) for more details). Each RDD in a DStream contains data from a certain interval, as shown in the following figure.

    @@ -392,8 +433,8 @@ as shown in the following figure. Any operation applied on a DStream translates to operations on the underlying RDDs. For example, in the [earlier example](#a-quick-example) of converting a stream of lines to words, -the `flatmap` operation is applied on each RDD in the `lines` DStream to generate the RDDs of the - `words` DStream. This is shown the following figure. +the `flatMap` operation is applied on each RDD in the `lines` DStream to generate the RDDs of the + `words` DStream. This is shown in the following figure.

    -

    -{% highlight scala %} -ssc.fileStream(dataDirectory) -{% endhighlight %} -
    -
    -{% highlight java %} -jssc.fileStream(dataDirectory); -{% endhighlight %} -
    - +
    +
    + streamingContext.fileStream[keyClass, valueClass, inputFormatClass](dataDirectory) +
    +
    + streamingContext.fileStream(dataDirectory); +
    +
    + + Spark Streaming will monitor the directory `dataDirectory` and process any files created in that directory (files written in nested directories not supported). Note that + + + The files must have the same data format. + + The files must be created in the `dataDirectory` by atomically *moving* or *renaming* them into + the data directory. + + Once moved, the files must not be changed. So if the files are being continuously appended, the new data will not be read. -Spark Streaming will monitor the directory `dataDirectory` for any Hadoop-compatible filesystem -and process any files created in that directory. Note that + For simple text files, there is an easier method `streamingContext.textFileStream(dataDirectory)`. And file streams do not require running a receiver, hence does not require allocating cores. - * The files must have the same data format. - * The files must be created in the `dataDirectory` by atomically *moving* or *renaming* them into - the data directory. - * Once moved the files must not be changed. +- **Streams based on Custom Actors:** DStreams can be created with data streams received through Akka actors by using `streamingContext.actorStream(actorProps, actor-name)`. See the [Custom Receiver Guide](#implementing-and-using-a-custom-actor-based-receiver) for more details. -For more details on streams from files, Akka actors and sockets, +- **Queue of RDDs as a Stream:** For testing a Spark Streaming application with test data, one can also create a DStream based on a queue of RDDs, using `streamingContext.queueStream(queueOfRDDs)`. Each RDD pushed into the queue will be treated as a batch of data in the DStream, and processed like a stream. + +For more details on streams from sockets, files, and actors, see the API documentations of the relevant functions in [StreamingContext](api/scala/index.html#org.apache.spark.streaming.StreamingContext) for -Scala and [JavaStreamingContext](api/scala/index.html#org.apache.spark.streaming.api.java.JavaStreamingContext) - for Java. +Scala and [JavaStreamingContext](api/java/index.html?org/apache/spark/streaming/api/java/JavaStreamingContext.html) for Java. + +### Advanced Sources +{:.no_toc} +This category of sources require interfacing with external non-Spark libraries, some of them with complex dependencies (e.g., Kafka and Flume). Hence, to minimize issues related to version conflicts of dependencies, the functionality to create DStreams from these sources have been moved to separate libraries, that can be [linked to](#linking) explicitly as necessary. For example, if you want to create a DStream using data from Twitter's stream of tweets, you have to do the following. -Additional functionality for creating DStreams from sources such as Kafka, Flume, Kinesis, and Twitter -can be imported by adding the right dependencies as explained in an -[earlier](#linking) section. To take the -case of Kafka, after adding the artifact `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` to the -project dependencies, you can create a DStream from Kafka as +1. *Linking*: Add the artifact `spark-streaming-twitter_{{site.SCALA_BINARY_VERSION}}` to the SBT/Maven project dependencies. +1. *Programming*: Import the `TwitterUtils` class and create a DStream with `TwitterUtils.createStream` as shown below. +1. *Deploying*: Generate an uber JAR with all the dependencies (including the dependency `spark-streaming-twitter_{{site.SCALA_BINARY_VERSION}}` and its transitive dependencies) and then deploy the application. This is further explained in the [Deploying section](#deploying-applications).
    {% highlight scala %} -import org.apache.spark.streaming.kafka._ -KafkaUtils.createStream(ssc, kafkaParams, ...) +import org.apache.spark.streaming.twitter._ + +TwitterUtils.createStream(ssc) {% endhighlight %}
    {% highlight java %} -import org.apache.spark.streaming.kafka.*; -KafkaUtils.createStream(jssc, kafkaParams, ...); +import org.apache.spark.streaming.twitter.*; + +TwitterUtils.createStream(jssc); {% endhighlight %}
    -For more details on these additional sources, see the corresponding [API documentation](#where-to-go-from-here). -Furthermore, you can also implement your own custom receiver for your sources. See the -[Custom Receiver Guide](streaming-custom-receivers.html). +Note that these advanced sources are not available in the `spark-shell`, hence applications based on these +advanced sources cannot be tested in the shell. + +Some of these advanced sources are as follows. + +- **Twitter:** Spark Streaming's TwitterUtils uses Twitter4j 3.0.3 to get the public stream of tweets using + [Twitter's Streaming API](https://dev.twitter.com/docs/streaming-apis). Authentication information + can be provided by any of the [methods](http://twitter4j.org/en/configuration.html) supported by + Twitter4J library. You can either get the public stream, or get the filtered stream based on a + keywords. See the API documentation ([Scala](api/scala/index.html#org.apache.spark.streaming.twitter.TwitterUtils$), [Java](api/java/index.html?org/apache/spark/streaming/twitter/TwitterUtils.html)) and examples ([TwitterPopularTags]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterPopularTags.scala) and + [TwitterAlgebirdCMS]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdCMS.scala)). + +- **Flume:** Spark Streaming {{site.SPARK_VERSION_SHORT}} can received data from Flume 1.4.0. See the [Flume Integration Guide](streaming-flume-integration.html) for more details. -### Kinesis -[Kinesis](streaming-kinesis.html) +- **Kafka:** Spark Streaming {{site.SPARK_VERSION_SHORT}} can receive data from Kafka 0.8.0. See the [Kafka Integration Guide](streaming-kafka-integration.html) for more details. -## Operations -There are two kinds of DStream operations - _transformations_ and _output operations_. Similar to -RDD transformations, DStream transformations operate on one or more DStreams to create new DStreams -with transformed data. After applying a sequence of transformations to the input streams, output -operations need to called, which write data out to an external data sink, such as a filesystem or a -database. +- **Kinesis:** See the [Kinesis Integration Guide](streaming-kinesis-integration.html) for more details. -### Transformations -DStreams support many of the transformations available on normal Spark RDD's. Some of the -common ones are as follows. +### Custom Sources +{:.no_toc} +Input DStreams can also be created out of custom data sources. All you have to do is implement an user-defined **receiver** (see next section to understand what that is) that can receive data from the custom sources and push it into Spark. See the +[Custom Receiver Guide](streaming-custom-receivers.html) for details. + +*** + +## Transformations on DStreams +Similar to that of RDDs, transformations allow the data from the input DStream to be modified. +DStreams support many of the transformations available on normal Spark RDD's. +Some of the common ones are as follows. @@ -557,8 +633,8 @@ common ones are as follows. The last two transformations are worth highlighting again. -

    UpdateStateByKey Operation

    - +#### UpdateStateByKey Operation +{:.no_toc} The `updateStateByKey` operation allows you to maintain arbitrary state while continuously updating it with new information. To use this, you will have to do two steps. @@ -616,8 +692,8 @@ the `(word, 1)` pairs) and the `runningCount` having the previous count. For the Scala code, take a look at the example [StatefulNetworkWordCount]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala). -

    Transform Operation

    - +#### Transform Operation +{:.no_toc} The `transform` operation (along with its variations like `transformWith`) allows arbitrary RDD-to-RDD functions to be applied on a DStream. It can be used to apply any RDD operation that is not exposed in the DStream API. @@ -662,8 +738,8 @@ JavaPairDStream cleanedDStream = wordCounts.transform( In fact, you can also use [machine learning](mllib-guide.html) and [graph computation](graphx-programming-guide.html) algorithms in the `transform` method. -

    Window Operations

    - +#### Window Operations +{:.no_toc} Finally, Spark Streaming also provides *windowed computations*, which allow you to apply transformations over a sliding window of data. This following figure illustrates this sliding window. @@ -678,11 +754,11 @@ window. As shown in the figure, every time the window *slides* over a source DStream, the source RDDs that fall within the window are combined and operated upon to produce the RDDs of the windowed DStream. In this specific case, the operation is applied over last 3 time -units of data, and slides by 2 time units. This shows that any window-based operation needs to +units of data, and slides by 2 time units. This shows that any window operation needs to specify two parameters. * window length - The duration of the window (3 in the figure) - * slide interval - The interval at which the window-based operation is performed (2 in + * sliding interval - The interval at which the window operation is performed (2 in the figure). These two parameters must be multiples of the batch interval of the source DStream (1 in the @@ -720,7 +796,7 @@ JavaPairDStream windowedWordCounts = pairs.reduceByKeyAndWindow -Some of the common window-based operations are as follows. All of these operations take the +Some of the common window operations are as follows. All of these operations take the said two parameters - windowLength and slideInterval.
    TransformationMeaning
    @@ -778,21 +854,27 @@ said two parameters - windowLength and slideInterval.
    -### Output Operations -When an output operator is called, it triggers the computation of a stream. Currently the following -output operators are defined: + +The complete list of DStream transformations is available in the API documentation. For the Scala API, +see [DStream](api/scala/index.html#org.apache.spark.streaming.dstream.DStream) +and [PairDStreamFunctions](api/scala/index.html#org.apache.spark.streaming.dstream.PairDStreamFunctions). +For the Java API, see [JavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaDStream.html) +and [JavaPairDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaPairDStream.html). + +*** + +## Output Operations on DStreams +Output operations allow DStream's data to be pushed out external systems like a database or a file systems. +Since the output operations actually allow the transformed data to be consumed by external systems, +they trigger the actual execution of all the DStream transformations (similar to actions for RDDs). +Currently, the following output operations are defined: - - - - - + @@ -811,17 +893,84 @@ output operators are defined: + + + +
    Output OperationMeaning
    print() Prints first ten elements of every batch of data in a DStream on the driver.
    foreachRDD(func) The fundamental output operator. Applies a function, func, to each RDD generated from - the stream. This function should have side effects, such as printing output, saving the RDD to - external files, or writing it over the network to an external system. Prints first ten elements of every batch of data in a DStream on the driver. + This is useful for development and debugging.
    saveAsObjectFiles(prefix, [suffix]) Save this DStream's contents as a Hadoop file. The file name at each batch interval is generated based on prefix and suffix: "prefix-TIME_IN_MS[.suffix]".
    foreachRDD(func) The most generic output operator that applies a function, func, to each RDD generated from + the stream. This function should push the data in each RDD to a external system, like saving the RDD to + files, or writing it over the network to a database. Note that the function func is executed + at the driver, and will usually have RDD actions in it that will force the computation of the streaming RDDs.
    +### Design Patterns for using foreachRDD +{:.no_toc} +`dstream.foreachRDD` is a powerful primitive that allows data to sent out to external systems. +However, it is important to understand how to use this primitive correctly and efficiently. +Some of the common mistakes to avoid are as follows. -The complete list of DStream operations is available in the API documentation. For the Scala API, -see [DStream](api/scala/index.html#org.apache.spark.streaming.dstream.DStream) -and [PairDStreamFunctions](api/scala/index.html#org.apache.spark.streaming.dstream.PairDStreamFunctions). -For the Java API, see [JavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaDStream.html) -and [JavaPairDStream](api/java/index.html?org/apache/spark/streaming/api/java/JavaPairDStream.html). +- Often writing data to external system requires creating a connection object +(e.g. TCP connection to a remote server) and using it to send data to a remote system. +For this purpose, a developer may inadvertantly try creating a connection object at +the Spark driver, but try to use it in a Spark worker to save records in the RDDs. +For example (in Scala), + + dstream.foreachRDD(rdd => { + val connection = createNewConnection() // executed at the driver + rdd.foreach(record => { + connection.send(record) // executed at the worker + }) + }) + + This is incorrect as this requires the connection object to be serialized and sent from the driver to the worker. Such connection objects are rarely transferrable across machines. This error may manifest as serialization errors (connection object not serializable), initialization errors (connection object needs to be initialized at the workers), etc. The correct solution is to create the connection object at the worker. + +- However, this can lead to another common mistake - creating a new connection for every record. For example, + + dstream.foreachRDD(rdd => { + rdd.foreach(record => { + val connection = createNewConnection() + connection.send(record) + connection.close() + }) + }) + + Typically, creating a connection object has time and resource overheads. Therefore, creating and destroying a connection object for each record can incur unnecessarily high overheads and can significantly reduce the overall throughput of the system. A better solution is to use `rdd.foreachPartition` - create a single connection object and send all the records in a RDD partition using that connection. + + dstream.foreachRDD(rdd => { + rdd.foreachPartition(partitionOfRecords => { + val connection = createNewConnection() + partitionOfRecords.foreach(record => connection.send(record)) + connection.close() + }) + }) + + This amortizes the connection creation overheads over many records. -## Persistence +- Finally, this can be further optimized by reusing connection objects across multiple RDDs/batches. + One can maintain a static pool of connection objects than can be reused as + RDDs of multiple batches are pushed to the external system, thus further reducing the overheads. + + dstream.foreachRDD(rdd => { + rdd.foreachPartition(partitionOfRecords => { + // ConnectionPool is a static, lazily initialized pool of connections + val connection = ConnectionPool.getConnection() + partitionOfRecords.foreach(record => connection.send(record)) + ConnectionPool.returnConnection(connection) // return to the pool for future reuse + }) + }) + + Note that the connections in the pool should be lazily created on demand and timed out if not used for a while. This achieves the most efficient sending of data to external systems. + + +##### Other points to remember: +{:.no_toc} +- DStreams are executed lazily by the output operations, just like RDDs are lazily executed by RDD actions. Specifically, RDD actions inside the DStream output operations force the processing of the received data. Hence, if your application does not have any output operation, or has output operations like `dstream.foreachRDD()` without any RDD action inside them, then nothing will get executed. The system will simply receive the data and discard it. + +- By default, output operations are executed one-at-a-time. And they are executed in the order they are defined in the application. + +*** + +## Caching / Persistence Similar to RDDs, DStreams also allow developers to persist the stream's data in memory. That is, using `persist()` method on a DStream would automatically persist every RDD of that DStream in memory. This is useful if the data in the DStream will be computed multiple times (e.g., multiple @@ -838,7 +987,9 @@ memory. This is further discussed in the [Performance Tuning](#memory-tuning) se information on different persistence levels can be found in [Spark Programming Guide](programming-guide.html#rdd-persistence). -## RDD Checkpointing +*** + +## Checkpointing A _stateful operation_ is one which operates over multiple batches of data. This includes all window-based operations and the `updateStateByKey` operation. Since stateful operations have a dependency on previous batches of data, they continuously accumulate metadata over time. @@ -867,10 +1018,19 @@ For DStreams that must be checkpointed (that is, DStreams created by `updateStat `reduceByKeyAndWindow` with inverse function), the checkpoint interval of the DStream is by default set to a multiple of the DStream's sliding interval such that its at least 10 seconds. -## Deployment +*** + +## Deploying Applications A Spark Streaming application is deployed on a cluster in the same way as any other Spark application. Please refer to the [deployment guide](cluster-overview.html) for more details. +Note that the applications +that use [advanced sources](#advanced-sources) (e.g. Kafka, Flume, Twitter) are also required to package the +extra artifact they link to, along with their dependencies, in the JAR that is used to deploy the application. +For example, an application using `TwitterUtils` will have to include +`spark-streaming-twitter_{{site.SCALA_BINARY_VERSION}}` and all its transitive +dependencies in the application JAR. + If a running Spark Streaming application needs to be upgraded (with new application code), then there are two possible mechanism. @@ -889,7 +1049,9 @@ application left off. Note that this can be done only with input sources that su (like Kafka, and Flume) as data needs to be buffered while the previous application down and the upgraded application is not yet up. -## Monitoring +*** + +## Monitoring Applications Beyond Spark's [monitoring capabilities](monitoring.html), there are additional capabilities specific to Spark Streaming. When a StreamingContext is used, the [Spark web UI](monitoring.html#web-interfaces) shows @@ -912,22 +1074,18 @@ The progress of a Spark Streaming program can also be monitored using the which allows you to get receiver status and processing times. Note that this is a developer API and it is likely to be improved upon (i.e., more information reported) in the future. -*************************************************************************************************** +*************************************************************************************************** +*************************************************************************************************** # Performance Tuning Getting the best performance of a Spark Streaming application on a cluster requires a bit of tuning. This section explains a number of the parameters and configurations that can tuned to improve the performance of you application. At a high level, you need to consider two things: -
      -
    1. - Reducing the processing time of each batch of data by efficiently using cluster resources. -
    2. -
    3. - Setting the right batch size such that the batches of data can be processed as fast as they - are received (that is, data processing keeps up with the data ingestion). -
    4. -
    +1. Reducing the processing time of each batch of data by efficiently using cluster resources. + +2. Setting the right batch size such that the batches of data can be processed as fast as they + are received (that is, data processing keeps up with the data ingestion). ## Reducing the Processing Time of each Batch There are a number of optimizations that can be done in Spark to minimize the processing time of @@ -935,15 +1093,41 @@ each batch. These have been discussed in detail in [Tuning Guide](tuning.html). highlights some of the most important ones. ### Level of Parallelism in Data Receiving +{:.no_toc} Receiving data over the network (like Kafka, Flume, socket, etc.) requires the data to deserialized and stored in Spark. If the data receiving becomes a bottleneck in the system, then consider parallelizing the data receiving. Note that each input DStream creates a single receiver (running on a worker machine) that receives a single stream of data. Receiving multiple data streams can therefore be achieved by creating multiple input DStreams and configuring them to receive different partitions of the data stream from the source(s). -For example, a single Kafka input stream receiving two topics of data can be split into two +For example, a single Kafka input DStream receiving two topics of data can be split into two Kafka input streams, each receiving only one topic. This would run two receivers on two workers, -thus allowing data to be received in parallel, and increasing overall throughput. +thus allowing data to be received in parallel, and increasing overall throughput. These multiple +DStream can be unioned together to create a single DStream. Then the transformations that was +being applied on the single input DStream can applied on the unified stream. This is done as follows. + +
    +
    +{% highlight scala %} +val numStreams = 5 +val kafkaStreams = (1 to numStreams).map { i => KafkaUtils.createStream(...) } +val unifiedStream = streamingContext.union(kafkaStreams) +unifiedStream.print() +{% endhighlight %} +
    +
    +{% highlight java %} +int numStreams = 5; +List> kafkaStreams = new ArrayList>(numStreams); +for (int i = 0; i < numStreams; i++) { + kafkaStreams.add(KafkaUtils.createStream(...)); +} +JavaPairDStream unifiedStream = streamingContext.union(kafkaStreams.get(0), kafkaStreams.subList(1, kafkaStreams.size())); +unifiedStream.print(); +{% endhighlight %} +
    +
    + Another parameter that should be considered is the receiver's blocking interval. For most receivers, the received data is coalesced together into large blocks of data before storing inside Spark's memory. @@ -958,7 +1142,8 @@ This distributes the received batches of data across specified number of machine before further processing. ### Level of Parallelism in Data Processing -Cluster resources maybe under-utilized if the number of parallel tasks used in any stage of the +{:.no_toc} +Cluster resources can be under-utilized if the number of parallel tasks used in any stage of the computation is not high enough. For example, for distributed reduce operations like `reduceByKey` and `reduceByKeyAndWindow`, the default number of parallel tasks is decided by the [config property] (configuration.html#spark-properties) `spark.default.parallelism`. You can pass the level of @@ -968,6 +1153,7 @@ documentation), or set the [config property](configuration.html#spark-properties `spark.default.parallelism` to change the default. ### Data Serialization +{:.no_toc} The overhead of data serialization can be significant, especially when sub-second batch sizes are to be achieved. There are two aspects to it. @@ -980,6 +1166,7 @@ The overhead of data serialization can be significant, especially when sub-secon serialization format. Hence, the deserialization overhead of input data may be a bottleneck. ### Task Launching Overheads +{:.no_toc} If the number of tasks launched per second is high (say, 50 or more per second), then the overhead of sending out tasks to the slaves maybe significant and will make it hard to achieve sub-second latencies. The overhead can be reduced by the following changes: @@ -994,6 +1181,8 @@ latencies. The overhead can be reduced by the following changes: These changes may reduce batch processing time by 100s of milliseconds, thus allowing sub-second batch size to be viable. +*** + ## Setting the Right Batch Size For a Spark Streaming application running on a cluster to be stable, the system should be able to process data as fast as it is being received. In other words, batches of data should be processed @@ -1022,6 +1211,8 @@ data rate and/or reducing the batch size. Note that momentary increase in the de temporary data rate increases maybe fine as long as the delay reduces back to a low value (i.e., less than batch size). +*** + ## Memory Tuning Tuning the memory usage and GC behavior of Spark applications have been discussed in great detail in the [Tuning Guide](tuning.html). It is recommended that you read that. In this section, @@ -1037,7 +1228,7 @@ Even though keeping the data serialized incurs higher serialization/deserializat it significantly reduces GC pauses. * **Clearing persistent RDDs**: By default, all persistent RDDs generated by Spark Streaming will - be cleared from memory based on Spark's in-built policy (LRU). If `spark.cleaner.ttl` is set, + be cleared from memory based on Spark's built-in policy (LRU). If `spark.cleaner.ttl` is set, then persistent RDDs that are older than that value are periodically cleared. As mentioned [earlier](#operation), this needs to be careful set based on operations used in the Spark Streaming program. However, a smarter unpersisting of RDDs can be enabled by setting the @@ -1051,7 +1242,8 @@ minimizes the variability of GC pauses. Even though concurrent GC is known to re overall processing throughput of the system, its use is still recommended to achieve more consistent batch processing times. -*************************************************************************************************** +*************************************************************************************************** +*************************************************************************************************** # Fault-tolerance Properties In this section, we are going to discuss the behavior of Spark Streaming application in the event @@ -1124,7 +1316,7 @@ def functionToCreateContext(): StreamingContext = { ssc } -// Get StreaminContext from checkpoint data or create a new one +// Get StreamingContext from checkpoint data or create a new one val context = StreamingContext.getOrCreate(checkpointDirectory, functionToCreateContext _) // Do additional setup on context that needs to be done, @@ -1178,10 +1370,7 @@ context.awaitTermination(); If the `checkpointDirectory` exists, then the context will be recreated from the checkpoint data. If the directory does not exist (i.e., running for the first time), then the function `contextFactory` will be called to create a new -context and set up the DStreams. See the Scala example -[JavaRecoverableWordCount]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/JavaRecoverableWordCount.scala) -(note that this example is missing in the 0.9 release, so you can test it using the master branch). -This example appends the word counts of network data into a file. +context and set up the DStreams. You can also explicitly create a `JavaStreamingContext` from the checkpoint data and start the computation by using `new JavaStreamingContext(checkpointDirectory)`. @@ -1208,7 +1397,8 @@ automatically restarted, and the word counts will cont For other deployment environments like Mesos and Yarn, you have to restart the driver through other mechanisms. -

    Recovery Semantics

    +#### Recovery Semantics +{:.no_toc} There are two different failure behaviors based on which input sources are used. @@ -1306,7 +1496,8 @@ in the file. This is what the sequence of outputs would be with and without a dr If the driver had crashed in the middle of the processing of time 3, then it will process time 3 and output 30 after recovery. -*************************************************************************************************** +*************************************************************************************************** +*************************************************************************************************** # Migration Guide from 0.9.1 or below to 1.x Between Spark 0.9.1 and Spark 1.0, there were a few API changes made to ensure future API stability. @@ -1332,7 +1523,7 @@ replaced by [Receiver](api/scala/index.html#org.apache.spark.streaming.receiver. the following advantages. * Methods like `stop` and `restart` have been added to for better control of the lifecycle of a receiver. See -the [custom receiver guide](streaming-custom-receiver.html) for more details. +the [custom receiver guide](streaming-custom-receivers.html) for more details. * Custom receivers can be implemented using both Scala and Java. To migrate your existing custom receivers from the earlier NetworkReceiver to the new Receiver, you have @@ -1356,6 +1547,7 @@ the `org.apache.spark.streaming.receivers` package were also moved to [`org.apache.spark.streaming.receiver`](api/scala/index.html#org.apache.spark.streaming.receiver.package) package and renamed for better clarity. +*************************************************************************************************** *************************************************************************************************** # Where to Go from Here @@ -1366,6 +1558,7 @@ package and renamed for better clarity. [DStream](api/scala/index.html#org.apache.spark.streaming.dstream.DStream) * [KafkaUtils](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$), [FlumeUtils](api/scala/index.html#org.apache.spark.streaming.flume.FlumeUtils$), + [KinesisUtils](api/scala/index.html#org.apache.spark.streaming.kinesis.KinesisUtils$), [TwitterUtils](api/scala/index.html#org.apache.spark.streaming.twitter.TwitterUtils$), [ZeroMQUtils](api/scala/index.html#org.apache.spark.streaming.zeromq.ZeroMQUtils$), and [MQTTUtils](api/scala/index.html#org.apache.spark.streaming.mqtt.MQTTUtils$) @@ -1375,6 +1568,7 @@ package and renamed for better clarity. [PairJavaDStream](api/java/index.html?org/apache/spark/streaming/api/java/PairJavaDStream.html) * [KafkaUtils](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html), [FlumeUtils](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html), + [KinesisUtils](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html) [TwitterUtils](api/java/index.html?org/apache/spark/streaming/twitter/TwitterUtils.html), [ZeroMQUtils](api/java/index.html?org/apache/spark/streaming/zeromq/ZeroMQUtils.html), and [MQTTUtils](api/java/index.html?org/apache/spark/streaming/mqtt/MQTTUtils.html) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 7e25df57ee45b..bfd07593b92ed 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -38,9 +38,12 @@ from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType, EBSBlockDeviceType from boto import ec2 +DEFAULT_SPARK_VERSION = "1.0.0" + # A URL prefix from which to fetch AMI information AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/v2/ami-list" + class UsageError(Exception): pass @@ -56,10 +59,10 @@ def parse_args(): help="Show this help message and exit") parser.add_option( "-s", "--slaves", type="int", default=1, - help="Number of slaves to launch (default: 1)") + help="Number of slaves to launch (default: %default)") parser.add_option( "-w", "--wait", type="int", default=120, - help="Seconds to wait for nodes to start (default: 120)") + help="Seconds to wait for nodes to start (default: %default)") parser.add_option( "-k", "--key-pair", help="Key pair to use on instances") @@ -68,7 +71,7 @@ def parse_args(): help="SSH private key file to use for logging into instances") parser.add_option( "-t", "--instance-type", default="m1.large", - help="Type of instance to launch (default: m1.large). " + + help="Type of instance to launch (default: %default). " + "WARNING: must be 64-bit; small instances won't work") parser.add_option( "-m", "--master-instance-type", default="", @@ -83,15 +86,15 @@ def parse_args(): "between zones applies)") parser.add_option("-a", "--ami", help="Amazon Machine Image ID to use") parser.add_option( - "-v", "--spark-version", default="1.0.0", - help="Version of Spark to use: 'X.Y.Z' or a specific git hash") + "-v", "--spark-version", default=DEFAULT_SPARK_VERSION, + help="Version of Spark to use: 'X.Y.Z' or a specific git hash (default: %default)") parser.add_option( "--spark-git-repo", default="https://github.com/apache/spark", help="Github repo from which to checkout supplied commit hash") parser.add_option( "--hadoop-major-version", default="1", - help="Major version of Hadoop (default: 1)") + help="Major version of Hadoop (default: %default)") parser.add_option( "-D", metavar="[ADDRESS:]PORT", dest="proxy_port", help="Use SSH dynamic port forwarding to create a SOCKS proxy at " + @@ -102,26 +105,34 @@ def parse_args(): "(for debugging)") parser.add_option( "--ebs-vol-size", metavar="SIZE", type="int", default=0, - help="Attach a new EBS volume of size SIZE (in GB) to each node as " + - "/vol. The volumes will be deleted when the instances terminate. " + - "Only possible on EBS-backed AMIs.") + help="Size (in GB) of each EBS volume.") + parser.add_option( + "--ebs-vol-type", default="standard", + help="EBS volume type (e.g. 'gp2', 'standard').") + parser.add_option( + "--ebs-vol-num", type="int", default=1, + help="Number of EBS volumes to attach to each node as /vol[x]. " + + "The volumes will be deleted when the instances terminate. " + + "Only possible on EBS-backed AMIs. " + + "EBS volumes are only attached if --ebs-vol-size > 0." + + "Only support up to 8 EBS volumes.") parser.add_option( "--swap", metavar="SWAP", type="int", default=1024, - help="Swap space to set up per node, in MB (default: 1024)") + help="Swap space to set up per node, in MB (default: %default)") parser.add_option( "--spot-price", metavar="PRICE", type="float", help="If specified, launch slaves as spot instances with the given " + "maximum price (in dollars)") parser.add_option( "--ganglia", action="store_true", default=True, - help="Setup Ganglia monitoring on cluster (default: on). NOTE: " + + help="Setup Ganglia monitoring on cluster (default: %default). NOTE: " + "the Ganglia page will be publicly accessible") parser.add_option( "--no-ganglia", action="store_false", dest="ganglia", help="Disable Ganglia monitoring for the cluster") parser.add_option( "-u", "--user", default="root", - help="The SSH user you want to connect as (default: root)") + help="The SSH user you want to connect as (default: %default)") parser.add_option( "--delete-groups", action="store_true", default=False, help="When destroying a cluster, delete the security groups that were created.") @@ -130,7 +141,7 @@ def parse_args(): help="Launch fresh slaves, but use an existing stopped master if possible") parser.add_option( "--worker-instances", type="int", default=1, - help="Number of instances per worker: variable SPARK_WORKER_INSTANCES (default: 1)") + help="Number of instances per worker: variable SPARK_WORKER_INSTANCES (default: %default)") parser.add_option( "--master-opts", type="string", default="", help="Extra options to give to master through SPARK_MASTER_OPTS variable " + @@ -143,7 +154,7 @@ def parse_args(): help="Use this prefix for the security group rather than the cluster name.") parser.add_option( "--authorized-address", type="string", default="0.0.0.0/0", - help="Address to authorize on created security groups (default: 0.0.0.0/0)") + help="Address to authorize on created security groups (default: %default)") parser.add_option( "--additional-security-group", type="string", default="", help="Additional security group to place the machines in") @@ -234,10 +245,10 @@ def get_spark_ami(opts): "cg1.4xlarge": "hvm", "hs1.8xlarge": "pvm", "hi1.4xlarge": "pvm", - "m3.medium": "pvm", - "m3.large": "pvm", - "m3.xlarge": "pvm", - "m3.2xlarge": "pvm", + "m3.medium": "hvm", + "m3.large": "hvm", + "m3.xlarge": "hvm", + "m3.2xlarge": "hvm", "cr1.8xlarge": "hvm", "i2.xlarge": "hvm", "i2.2xlarge": "hvm", @@ -334,7 +345,6 @@ def launch_cluster(conn, opts, cluster_name): if opts.ami is None: opts.ami = get_spark_ami(opts) - additional_groups = [] if opts.additional_security_group: additional_groups = [sg @@ -348,13 +358,16 @@ def launch_cluster(conn, opts, cluster_name): print >> stderr, "Could not find AMI " + opts.ami sys.exit(1) - # Create block device mapping so that we can add an EBS volume if asked to + # Create block device mapping so that we can add EBS volumes if asked to. + # The first drive is attached as /dev/sds, 2nd as /dev/sdt, ... /dev/sdz block_map = BlockDeviceMapping() if opts.ebs_vol_size > 0: - device = EBSBlockDeviceType() - device.size = opts.ebs_vol_size - device.delete_on_termination = True - block_map["/dev/sdv"] = device + for i in range(opts.ebs_vol_num): + device = EBSBlockDeviceType() + device.size = opts.ebs_vol_size + device.volume_type = opts.ebs_vol_type + device.delete_on_termination = True + block_map["/dev/sd" + chr(ord('s') + i)] = device # AWS ignores the AMI-specified block device mapping for M3 (see SPARK-3342). if opts.instance_type.startswith('m3.'): @@ -484,6 +497,7 @@ def launch_cluster(conn, opts, cluster_name): # Return all the instances return (master_nodes, slave_nodes) + def tag_instance(instance, name): for i in range(0, 5): try: @@ -496,9 +510,12 @@ def tag_instance(instance, name): # Get the EC2 instances in an existing cluster if available. # Returns a tuple of lists of EC2 instance objects for the masters and slaves + + def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): print "Searching for existing cluster " + cluster_name + "..." - # Search all the spot instance requests, and copy any tags from the spot instance request to the cluster. + # Search all the spot instance requests, and copy any tags from the spot + # instance request to the cluster. spot_instance_requests = conn.get_all_spot_instance_requests() for req in spot_instance_requests: if req.state != u'active': @@ -509,7 +526,7 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): for res in reservations: active = [i for i in res.instances if is_active(i)] for instance in active: - if (instance.tags.get(u'Name') == None): + if (instance.tags.get(u'Name') is None): tag_instance(instance, name) # Now proceed to detect master and slaves instances. reservations = conn.get_all_instances() @@ -529,13 +546,16 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): return (master_nodes, slave_nodes) else: if master_nodes == [] and slave_nodes != []: - print >> sys.stderr, "ERROR: Could not find master in with name " + cluster_name + "-master" + print >> sys.stderr, "ERROR: Could not find master in with name " + \ + cluster_name + "-master" else: print >> sys.stderr, "ERROR: Could not find any existing cluster" sys.exit(1) # Deploy configuration files and run setup scripts on a newly launched # or started EC2 cluster. + + def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): master = master_nodes[0].public_dns_name if deploy_ssh_key: @@ -828,6 +848,12 @@ def get_partition(total, num_partitions, current_partitions): def real_main(): (opts, action, cluster_name) = parse_args() + + # Input parameter validation + if opts.ebs_vol_num > 8: + print >> stderr, "ebs-vol-num cannot be greater than 8" + sys.exit(1) + try: conn = ec2.connect_to_region(opts.region) except Exception as e: @@ -873,7 +899,8 @@ def real_main(): if opts.security_group_prefix is None: group_names = [cluster_name + "-master", cluster_name + "-slaves"] else: - group_names = [opts.security_group_prefix + "-master", opts.security_group_prefix + "-slaves"] + group_names = [opts.security_group_prefix + "-master", + opts.security_group_prefix + "-slaves"] attempt = 1 while attempt <= 3: diff --git a/examples/pom.xml b/examples/pom.xml index 9b12cb0c29c9f..3f46c40464d3b 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml diff --git a/examples/src/main/python/avro_inputformat.py b/examples/src/main/python/avro_inputformat.py index e902ae29753c0..cfda8d8327aa3 100644 --- a/examples/src/main/python/avro_inputformat.py +++ b/examples/src/main/python/avro_inputformat.py @@ -23,7 +23,8 @@ Read data file users.avro in local Spark distro: $ cd $SPARK_HOME -$ ./bin/spark-submit --driver-class-path /path/to/example/jar ./examples/src/main/python/avro_inputformat.py \ +$ ./bin/spark-submit --driver-class-path /path/to/example/jar \ +> ./examples/src/main/python/avro_inputformat.py \ > examples/src/main/resources/users.avro {u'favorite_color': None, u'name': u'Alyssa', u'favorite_numbers': [3, 9, 15, 20]} {u'favorite_color': u'red', u'name': u'Ben', u'favorite_numbers': []} @@ -40,7 +41,8 @@ ] } -$ ./bin/spark-submit --driver-class-path /path/to/example/jar ./examples/src/main/python/avro_inputformat.py \ +$ ./bin/spark-submit --driver-class-path /path/to/example/jar \ +> ./examples/src/main/python/avro_inputformat.py \ > examples/src/main/resources/users.avro examples/src/main/resources/user.avsc {u'favorite_color': None, u'name': u'Alyssa'} {u'favorite_color': u'red', u'name': u'Ben'} @@ -51,8 +53,10 @@ Usage: avro_inputformat [reader_schema_file] Run with example jar: - ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/avro_inputformat.py [reader_schema_file] - Assumes you have Avro data stored in . Reader schema can be optionally specified in [reader_schema_file]. + ./bin/spark-submit --driver-class-path /path/to/example/jar \ + /path/to/examples/avro_inputformat.py [reader_schema_file] + Assumes you have Avro data stored in . Reader schema can be optionally specified + in [reader_schema_file]. """ exit(-1) @@ -62,9 +66,10 @@ conf = None if len(sys.argv) == 3: schema_rdd = sc.textFile(sys.argv[2], 1).collect() - conf = {"avro.schema.input.key" : reduce(lambda x, y: x+y, schema_rdd)} + conf = {"avro.schema.input.key": reduce(lambda x, y: x + y, schema_rdd)} - avro_rdd = sc.newAPIHadoopFile(path, + avro_rdd = sc.newAPIHadoopFile( + path, "org.apache.avro.mapreduce.AvroKeyInputFormat", "org.apache.avro.mapred.AvroKey", "org.apache.hadoop.io.NullWritable", diff --git a/examples/src/main/python/cassandra_inputformat.py b/examples/src/main/python/cassandra_inputformat.py index e4a897f61e39d..05f34b74df45a 100644 --- a/examples/src/main/python/cassandra_inputformat.py +++ b/examples/src/main/python/cassandra_inputformat.py @@ -51,7 +51,8 @@ Usage: cassandra_inputformat Run with example jar: - ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/cassandra_inputformat.py + ./bin/spark-submit --driver-class-path /path/to/example/jar \ + /path/to/examples/cassandra_inputformat.py Assumes you have some data in Cassandra already, running on , in and """ exit(-1) @@ -61,12 +62,12 @@ cf = sys.argv[3] sc = SparkContext(appName="CassandraInputFormat") - conf = {"cassandra.input.thrift.address":host, - "cassandra.input.thrift.port":"9160", - "cassandra.input.keyspace":keyspace, - "cassandra.input.columnfamily":cf, - "cassandra.input.partitioner.class":"Murmur3Partitioner", - "cassandra.input.page.row.size":"3"} + conf = {"cassandra.input.thrift.address": host, + "cassandra.input.thrift.port": "9160", + "cassandra.input.keyspace": keyspace, + "cassandra.input.columnfamily": cf, + "cassandra.input.partitioner.class": "Murmur3Partitioner", + "cassandra.input.page.row.size": "3"} cass_rdd = sc.newAPIHadoopRDD( "org.apache.cassandra.hadoop.cql3.CqlPagingInputFormat", "java.util.Map", diff --git a/examples/src/main/python/cassandra_outputformat.py b/examples/src/main/python/cassandra_outputformat.py index 836c35b5c6794..d144539e58b8f 100644 --- a/examples/src/main/python/cassandra_outputformat.py +++ b/examples/src/main/python/cassandra_outputformat.py @@ -50,7 +50,8 @@ Usage: cassandra_outputformat Run with example jar: - ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/cassandra_outputformat.py + ./bin/spark-submit --driver-class-path /path/to/example/jar \ + /path/to/examples/cassandra_outputformat.py Assumes you have created the following table in Cassandra already, running on , in . @@ -67,16 +68,16 @@ cf = sys.argv[3] sc = SparkContext(appName="CassandraOutputFormat") - conf = {"cassandra.output.thrift.address":host, - "cassandra.output.thrift.port":"9160", - "cassandra.output.keyspace":keyspace, - "cassandra.output.partitioner.class":"Murmur3Partitioner", - "cassandra.output.cql":"UPDATE " + keyspace + "." + cf + " SET fname = ?, lname = ?", - "mapreduce.output.basename":cf, - "mapreduce.outputformat.class":"org.apache.cassandra.hadoop.cql3.CqlOutputFormat", - "mapreduce.job.output.key.class":"java.util.Map", - "mapreduce.job.output.value.class":"java.util.List"} - key = {"user_id" : int(sys.argv[4])} + conf = {"cassandra.output.thrift.address": host, + "cassandra.output.thrift.port": "9160", + "cassandra.output.keyspace": keyspace, + "cassandra.output.partitioner.class": "Murmur3Partitioner", + "cassandra.output.cql": "UPDATE " + keyspace + "." + cf + " SET fname = ?, lname = ?", + "mapreduce.output.basename": cf, + "mapreduce.outputformat.class": "org.apache.cassandra.hadoop.cql3.CqlOutputFormat", + "mapreduce.job.output.key.class": "java.util.Map", + "mapreduce.job.output.value.class": "java.util.List"} + key = {"user_id": int(sys.argv[4])} sc.parallelize([(key, sys.argv[5:])]).saveAsNewAPIHadoopDataset( conf=conf, keyConverter="org.apache.spark.examples.pythonconverters.ToCassandraCQLKeyConverter", diff --git a/examples/src/main/python/hbase_inputformat.py b/examples/src/main/python/hbase_inputformat.py index befacee0dea56..3b16010f1cb97 100644 --- a/examples/src/main/python/hbase_inputformat.py +++ b/examples/src/main/python/hbase_inputformat.py @@ -51,7 +51,8 @@ Usage: hbase_inputformat Run with example jar: - ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/hbase_inputformat.py
    + ./bin/spark-submit --driver-class-path /path/to/example/jar \ + /path/to/examples/hbase_inputformat.py
    Assumes you have some data in HBase already, running on , in
    """ exit(-1) @@ -61,12 +62,15 @@ sc = SparkContext(appName="HBaseInputFormat") conf = {"hbase.zookeeper.quorum": host, "hbase.mapreduce.inputtable": table} + keyConv = "org.apache.spark.examples.pythonconverters.ImmutableBytesWritableToStringConverter" + valueConv = "org.apache.spark.examples.pythonconverters.HBaseResultToStringConverter" + hbase_rdd = sc.newAPIHadoopRDD( "org.apache.hadoop.hbase.mapreduce.TableInputFormat", "org.apache.hadoop.hbase.io.ImmutableBytesWritable", "org.apache.hadoop.hbase.client.Result", - keyConverter="org.apache.spark.examples.pythonconverters.ImmutableBytesWritableToStringConverter", - valueConverter="org.apache.spark.examples.pythonconverters.HBaseResultToStringConverter", + keyConverter=keyConv, + valueConverter=valueConv, conf=conf) output = hbase_rdd.collect() for (k, v) in output: diff --git a/examples/src/main/python/hbase_outputformat.py b/examples/src/main/python/hbase_outputformat.py index 49bbc5aebdb0b..abb425b1f886a 100644 --- a/examples/src/main/python/hbase_outputformat.py +++ b/examples/src/main/python/hbase_outputformat.py @@ -44,8 +44,10 @@ Usage: hbase_outputformat
    Run with example jar: - ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/hbase_outputformat.py - Assumes you have created
    with column family in HBase running on already + ./bin/spark-submit --driver-class-path /path/to/example/jar \ + /path/to/examples/hbase_outputformat.py + Assumes you have created
    with column family in HBase + running on already """ exit(-1) @@ -55,13 +57,15 @@ conf = {"hbase.zookeeper.quorum": host, "hbase.mapred.outputtable": table, - "mapreduce.outputformat.class" : "org.apache.hadoop.hbase.mapreduce.TableOutputFormat", - "mapreduce.job.output.key.class" : "org.apache.hadoop.hbase.io.ImmutableBytesWritable", - "mapreduce.job.output.value.class" : "org.apache.hadoop.io.Writable"} + "mapreduce.outputformat.class": "org.apache.hadoop.hbase.mapreduce.TableOutputFormat", + "mapreduce.job.output.key.class": "org.apache.hadoop.hbase.io.ImmutableBytesWritable", + "mapreduce.job.output.value.class": "org.apache.hadoop.io.Writable"} + keyConv = "org.apache.spark.examples.pythonconverters.StringToImmutableBytesWritableConverter" + valueConv = "org.apache.spark.examples.pythonconverters.StringListToPutConverter" sc.parallelize([sys.argv[3:]]).map(lambda x: (x[0], x)).saveAsNewAPIHadoopDataset( conf=conf, - keyConverter="org.apache.spark.examples.pythonconverters.StringToImmutableBytesWritableConverter", - valueConverter="org.apache.spark.examples.pythonconverters.StringListToPutConverter") + keyConverter=keyConv, + valueConverter=valueConv) sc.stop() diff --git a/examples/src/main/python/mllib/correlations.py b/examples/src/main/python/mllib/correlations.py index 6b16a56e44af7..4218eca822a99 100755 --- a/examples/src/main/python/mllib/correlations.py +++ b/examples/src/main/python/mllib/correlations.py @@ -28,7 +28,7 @@ if __name__ == "__main__": - if len(sys.argv) not in [1,2]: + if len(sys.argv) not in [1, 2]: print >> sys.stderr, "Usage: correlations ()" exit(-1) sc = SparkContext(appName="PythonCorrelations") diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py index 6e4a4a0cb6be0..61ea4e06ecf3a 100755 --- a/examples/src/main/python/mllib/decision_tree_runner.py +++ b/examples/src/main/python/mllib/decision_tree_runner.py @@ -21,7 +21,9 @@ This example requires NumPy (http://www.numpy.org/). """ -import numpy, os, sys +import numpy +import os +import sys from operator import add @@ -127,7 +129,7 @@ def usage(): (reindexedData, origToNewLabels) = reindexClassLabels(points) # Train a classifier. - categoricalFeaturesInfo={} # no categorical features + categoricalFeaturesInfo = {} # no categorical features model = DecisionTree.trainClassifier(reindexedData, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo) # Print learned tree and stats. diff --git a/examples/src/main/python/mllib/random_rdd_generation.py b/examples/src/main/python/mllib/random_rdd_generation.py index b388d8d83fb86..1e8892741e714 100755 --- a/examples/src/main/python/mllib/random_rdd_generation.py +++ b/examples/src/main/python/mllib/random_rdd_generation.py @@ -32,8 +32,8 @@ sc = SparkContext(appName="PythonRandomRDDGeneration") - numExamples = 10000 # number of examples to generate - fraction = 0.1 # fraction of data to sample + numExamples = 10000 # number of examples to generate + fraction = 0.1 # fraction of data to sample # Example: RandomRDDs.normalRDD normalRDD = RandomRDDs.normalRDD(sc, numExamples) @@ -45,7 +45,7 @@ print # Example: RandomRDDs.normalVectorRDD - normalVectorRDD = RandomRDDs.normalVectorRDD(sc, numRows = numExamples, numCols = 2) + normalVectorRDD = RandomRDDs.normalVectorRDD(sc, numRows=numExamples, numCols=2) print 'Generated RDD of %d examples of length-2 vectors.' % normalVectorRDD.count() print ' First 5 samples:' for sample in normalVectorRDD.take(5): diff --git a/examples/src/main/python/mllib/sampled_rdds.py b/examples/src/main/python/mllib/sampled_rdds.py index ec64a5978c672..92af3af5ebd1e 100755 --- a/examples/src/main/python/mllib/sampled_rdds.py +++ b/examples/src/main/python/mllib/sampled_rdds.py @@ -36,7 +36,7 @@ sc = SparkContext(appName="PythonSampledRDDs") - fraction = 0.1 # fraction of data to sample + fraction = 0.1 # fraction of data to sample examples = MLUtils.loadLibSVMFile(sc, datapath) numExamples = examples.count() @@ -49,9 +49,9 @@ expectedSampleSize = int(numExamples * fraction) print 'Sampling RDD using fraction %g. Expected sample size = %d.' \ % (fraction, expectedSampleSize) - sampledRDD = examples.sample(withReplacement = True, fraction = fraction) + sampledRDD = examples.sample(withReplacement=True, fraction=fraction) print ' RDD.sample(): sample has %d examples' % sampledRDD.count() - sampledArray = examples.takeSample(withReplacement = True, num = expectedSampleSize) + sampledArray = examples.takeSample(withReplacement=True, num=expectedSampleSize) print ' RDD.takeSample(): sample has %d examples' % len(sampledArray) print @@ -66,7 +66,7 @@ fractions = {} for k in keyCountsA.keys(): fractions[k] = fraction - sampledByKeyRDD = keyedRDD.sampleByKey(withReplacement = True, fractions = fractions) + sampledByKeyRDD = keyedRDD.sampleByKey(withReplacement=True, fractions=fractions) keyCountsB = sampledByKeyRDD.countByKey() sizeB = sum(keyCountsB.values()) print ' Sampled %d examples using approximate stratified sampling (by label). ==> Sample' \ diff --git a/examples/src/main/python/pi.py b/examples/src/main/python/pi.py index fc37459dc74aa..ee9036adfa281 100755 --- a/examples/src/main/python/pi.py +++ b/examples/src/main/python/pi.py @@ -35,7 +35,7 @@ def f(_): y = random() * 2 - 1 return 1 if x ** 2 + y ** 2 < 1 else 0 - count = sc.parallelize(xrange(1, n+1), slices).map(f).reduce(add) + count = sc.parallelize(xrange(1, n + 1), slices).map(f).reduce(add) print "Pi is roughly %f" % (4.0 * count / n) sc.stop() diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala similarity index 98% rename from graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala rename to examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index c1513a00453cf..c4317a6aec798 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -15,12 +15,13 @@ * limitations under the License. */ -package org.apache.spark.graphx.lib +package org.apache.spark.examples.graphx import scala.collection.mutable import org.apache.spark._ import org.apache.spark.storage.StorageLevel import org.apache.spark.graphx._ +import org.apache.spark.graphx.lib._ import org.apache.spark.graphx.PartitionStrategy._ /** diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala index 6ef3b62dcbedc..bdc8fa7f99f2e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/LiveJournalPageRank.scala @@ -20,7 +20,7 @@ package org.apache.spark.examples.graphx import org.apache.spark.SparkContext._ import org.apache.spark._ import org.apache.spark.graphx._ -import org.apache.spark.graphx.lib.Analytics +import org.apache.spark.examples.graphx.Analytics /** * Uses GraphX to run PageRank on a LiveJournal social network graph. Download the dataset from diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala index 551c339b19523..5f35a5836462e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala @@ -38,12 +38,13 @@ object SynthBenchmark { * Options: * -app "pagerank" or "cc" for pagerank or connected components. (Default: pagerank) * -niters the number of iterations of pagerank to use (Default: 10) - * -numVertices the number of vertices in the graph (Default: 1000000) + * -nverts the number of vertices in the graph (Default: 1000000) * -numEPart the number of edge partitions in the graph (Default: number of cores) * -partStrategy the graph partitioning strategy to use * -mu the mean parameter for the log-normal graph (Default: 4.0) * -sigma the stdev parameter for the log-normal graph (Default: 1.3) * -degFile the local file to save the degree information (Default: Empty) + * -seed seed to use for RNGs (Default: -1, picks seed randomly) */ def main(args: Array[String]) { val options = args.map { @@ -62,6 +63,7 @@ object SynthBenchmark { var mu: Double = 4.0 var sigma: Double = 1.3 var degFile: String = "" + var seed: Int = -1 options.foreach { case ("app", v) => app = v @@ -72,6 +74,7 @@ object SynthBenchmark { case ("mu", v) => mu = v.toDouble case ("sigma", v) => sigma = v.toDouble case ("degFile", v) => degFile = v + case ("seed", v) => seed = v.toInt case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt) } @@ -85,7 +88,7 @@ object SynthBenchmark { // Create the graph println(s"Creating graph...") val unpartitionedGraph = GraphGenerators.logNormalGraph(sc, numVertices, - numEPart.getOrElse(sc.defaultParallelism), mu, sigma) + numEPart.getOrElse(sc.defaultParallelism), mu, sigma, seed) // Repartition the graph val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_)).cache() @@ -113,7 +116,7 @@ object SynthBenchmark { println(s"Total PageRank = $totalPR") } else if (app == "cc") { println("Running Connected Components") - val numComponents = graph.connectedComponents.vertices.map(_._2).distinct() + val numComponents = graph.connectedComponents.vertices.map(_._2).distinct().count() println(s"Number of components = $numComponents") } val runTime = System.currentTimeMillis() - startTime diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index b345276b08ba3..ac291bd4fde20 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index f71f6b6c4f931..7d31e32283d88 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala index 73dffef953309..6ee7ac974b4a0 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumeStreamSuite.scala @@ -109,11 +109,11 @@ class FlumeStreamSuite extends TestSuiteBase { } class CompressionChannelFactory(compressionLevel: Int) extends NioClientSocketChannelFactory { - override def newChannel(pipeline:ChannelPipeline) : SocketChannel = { - var encoder : ZlibEncoder = new ZlibEncoder(compressionLevel); - pipeline.addFirst("deflater", encoder); - pipeline.addFirst("inflater", new ZlibDecoder()); - super.newChannel(pipeline); + override def newChannel(pipeline: ChannelPipeline): SocketChannel = { + val encoder = new ZlibEncoder(compressionLevel) + pipeline.addFirst("deflater", encoder) + pipeline.addFirst("inflater", new ZlibDecoder()) + super.newChannel(pipeline) } } } diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 4e2275ab238f7..2067c473f0e3f 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index dc48a08c93de2..371f1f1e9d39a 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index b93ad016f84f0..1d7dd49d15c22 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 22c1fff23d9a2..7e48968feb3bc 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 5308bb4e440ea..8658ecf5abfab 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index a54b34235dfb4..560244ad93369 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java index 1a710d7b18c6f..aa917d0575c4c 100644 --- a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java +++ b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java @@ -75,7 +75,7 @@ * onto the Kinesis stream. * Usage instructions for KinesisWordCountProducerASL are provided in the class definition. */ -public final class JavaKinesisWordCountASL { +public final class JavaKinesisWordCountASL { // needs to be public for access from run-example private static final Pattern WORD_SEPARATOR = Pattern.compile(" "); private static final Logger logger = Logger.getLogger(JavaKinesisWordCountASL.class); @@ -87,10 +87,10 @@ public static void main(String[] args) { /* Check that all required args were passed in. */ if (args.length < 2) { System.err.println( - "|Usage: KinesisWordCount \n" + - "| is the name of the Kinesis stream\n" + - "| is the endpoint of the Kinesis service\n" + - "| (e.g. https://kinesis.us-east-1.amazonaws.com)\n"); + "Usage: JavaKinesisWordCountASL \n" + + " is the name of the Kinesis stream\n" + + " is the endpoint of the Kinesis service\n" + + " (e.g. https://kinesis.us-east-1.amazonaws.com)\n"); System.exit(1); } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index d03edf8b30a9f..fffd90de08240 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -69,7 +69,7 @@ import org.apache.log4j.Level * dummy data onto the Kinesis stream. * Usage instructions for KinesisWordCountProducerASL are provided in that class definition. */ -object KinesisWordCountASL extends Logging { +private object KinesisWordCountASL extends Logging { def main(args: Array[String]) { /* Check that all required args were passed in. */ if (args.length < 2) { @@ -154,7 +154,7 @@ object KinesisWordCountASL extends Logging { * org.apache.spark.examples.streaming.KinesisWordCountProducerASL mySparkStream \ * https://kinesis.us-east-1.amazonaws.com 10 5 */ -object KinesisWordCountProducerASL { +private object KinesisWordCountProducerASL { def main(args: Array[String]) { if (args.length < 4) { System.err.println("Usage: KinesisWordCountProducerASL " + @@ -235,7 +235,7 @@ object KinesisWordCountProducerASL { * Utility functions for Spark Streaming examples. * This has been lifted from the examples/ project to remove the circular dependency. */ -object StreamingExamples extends Logging { +private[streaming] object StreamingExamples extends Logging { /** Set reasonable logging levels for streaming if the user has not configured log4j. */ def setStreamingLogLevels() { diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala index 713cac0e293c0..96f4399accd3a 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -35,7 +35,7 @@ import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionIn object KinesisUtils { /** * Create an InputDStream that pulls messages from a Kinesis stream. - * + * :: Experimental :: * @param ssc StreamingContext object * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) @@ -52,6 +52,7 @@ object KinesisUtils { * * @return ReceiverInputDStream[Array[Byte]] */ + @Experimental def createStream( ssc: StreamingContext, streamName: String, @@ -65,9 +66,8 @@ object KinesisUtils { /** * Create a Java-friendly InputDStream that pulls messages from a Kinesis stream. - * + * :: Experimental :: * @param jssc Java StreamingContext object - * @param ssc StreamingContext object * @param streamName Kinesis stream name * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. @@ -83,6 +83,7 @@ object KinesisUtils { * * @return JavaReceiverInputDStream[Array[Byte]] */ + @Experimental def createStream( jssc: JavaStreamingContext, streamName: String, diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index a5b162a0482e4..71a078d58a8d8 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 6dd52fc618b1e..3f49b1d63b6e1 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index 899a3cbd62b60..5bcb96b136ed7 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -37,7 +37,15 @@ class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag]( val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) extends RDD[Edge[ED]](partitionsRDD.context, List(new OneToOneDependency(partitionsRDD))) { - partitionsRDD.setName("EdgeRDD") + override def setName(_name: String): this.type = { + if (partitionsRDD.name != null) { + partitionsRDD.setName(partitionsRDD.name + ", " + _name) + } else { + partitionsRDD.setName(_name) + } + this + } + setName("EdgeRDD") override protected def getPartitions: Array[Partition] = partitionsRDD.partitions diff --git a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala index 5e7e72a764cc8..13033fee0e6b5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/PartitionStrategy.scala @@ -91,7 +91,7 @@ object PartitionStrategy { case object EdgePartition1D extends PartitionStrategy { override def getPartition(src: VertexId, dst: VertexId, numParts: PartitionID): PartitionID = { val mixingPrime: VertexId = 1125899906842597L - (math.abs(src) * mixingPrime).toInt % numParts + (math.abs(src * mixingPrime) % numParts).toInt } } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 60149548ab852..b8309289fe475 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -40,7 +40,7 @@ object GraphGenerators { val RMATd = 0.25 /** - * Generate a graph whose vertex out degree is log normal. + * Generate a graph whose vertex out degree distribution is log normal. * * The default values for mu and sigma are taken from the Pregel paper: * @@ -48,33 +48,36 @@ object GraphGenerators { * Ilan Horn, Naty Leiser, and Grzegorz Czajkowski. 2010. * Pregel: a system for large-scale graph processing. SIGMOD '10. * - * @param sc - * @param numVertices - * @param mu - * @param sigma - * @return + * If the seed is -1 (default), a random seed is chosen. Otherwise, use + * the user-specified seed. + * + * @param sc Spark Context + * @param numVertices number of vertices in generated graph + * @param numEParts (optional) number of partitions + * @param mu (optional, default: 4.0) mean of out-degree distribution + * @param sigma (optional, default: 1.3) standard deviation of out-degree distribution + * @param seed (optional, default: -1) seed for RNGs, -1 causes a random seed to be chosen + * @return Graph object */ - def logNormalGraph(sc: SparkContext, numVertices: Int, numEParts: Int, - mu: Double = 4.0, sigma: Double = 1.3): Graph[Long, Int] = { - val vertices = sc.parallelize(0 until numVertices, numEParts).map { src => - // Initialize the random number generator with the source vertex id - val rand = new Random(src) - val degree = math.min(numVertices.toLong, math.exp(rand.nextGaussian() * sigma + mu).toLong) - (src.toLong, degree) + def logNormalGraph( + sc: SparkContext, numVertices: Int, numEParts: Int = 0, mu: Double = 4.0, + sigma: Double = 1.3, seed: Long = -1): Graph[Long, Int] = { + + val evalNumEParts = if (numEParts == 0) sc.defaultParallelism else numEParts + + // Enable deterministic seeding + val seedRand = if (seed == -1) new Random() else new Random(seed) + val seed1 = seedRand.nextInt() + val seed2 = seedRand.nextInt() + + val vertices: RDD[(VertexId, Long)] = sc.parallelize(0 until numVertices, evalNumEParts).map { + src => (src, sampleLogNormal(mu, sigma, numVertices, seed = (seed1 ^ src))) } + val edges = vertices.flatMap { case (src, degree) => - new Iterator[Edge[Int]] { - // Initialize the random number generator with the source vertex id - val rand = new Random(src) - var i = 0 - override def hasNext(): Boolean = { i < degree } - override def next(): Edge[Int] = { - val nextEdge = Edge[Int](src, rand.nextInt(numVertices), i) - i += 1 - nextEdge - } - } + generateRandomEdges(src.toInt, degree.toInt, numVertices, seed = (seed2 ^ src)) } + Graph(vertices, edges, 0) } @@ -82,9 +85,10 @@ object GraphGenerators { // the edge data is the weight (default 1) val RMATc = 0.15 - def generateRandomEdges(src: Int, numEdges: Int, maxVertexId: Int): Array[Edge[Int]] = { - val rand = new Random() - Array.fill(maxVertexId) { Edge[Int](src, rand.nextInt(maxVertexId), 1) } + def generateRandomEdges( + src: Int, numEdges: Int, maxVertexId: Int, seed: Long = -1): Array[Edge[Int]] = { + val rand = if (seed == -1) new Random() else new Random(seed) + Array.fill(numEdges) { Edge[Int](src, rand.nextInt(maxVertexId), 1) } } /** @@ -97,9 +101,12 @@ object GraphGenerators { * @param mu the mean of the normal distribution * @param sigma the standard deviation of the normal distribution * @param maxVal exclusive upper bound on the value of the sample + * @param seed optional seed */ - private def sampleLogNormal(mu: Double, sigma: Double, maxVal: Int): Int = { - val rand = new Random() + private[spark] def sampleLogNormal( + mu: Double, sigma: Double, maxVal: Int, seed: Long = -1): Int = { + val rand = if (seed == -1) new Random() else new Random(seed) + val sigmaSq = sigma * sigma val m = math.exp(mu + sigmaSq / 2.0) // expm1 is exp(m)-1 with better accuracy for tiny m diff --git a/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala new file mode 100644 index 0000000000000..b346d4db2ef96 --- /dev/null +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/GraphGeneratorsSuite.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx.util + +import org.scalatest.FunSuite + +import org.apache.spark.graphx.LocalSparkContext + +class GraphGeneratorsSuite extends FunSuite with LocalSparkContext { + + test("GraphGenerators.generateRandomEdges") { + val src = 5 + val numEdges10 = 10 + val numEdges20 = 20 + val maxVertexId = 100 + + val edges10 = GraphGenerators.generateRandomEdges(src, numEdges10, maxVertexId) + assert(edges10.length == numEdges10) + + val correctSrc = edges10.forall(e => e.srcId == src) + assert(correctSrc) + + val correctWeight = edges10.forall(e => e.attr == 1) + assert(correctWeight) + + val correctRange = edges10.forall(e => e.dstId >= 0 && e.dstId <= maxVertexId) + assert(correctRange) + + val edges20 = GraphGenerators.generateRandomEdges(src, numEdges20, maxVertexId) + assert(edges20.length == numEdges20) + + val edges10_round1 = + GraphGenerators.generateRandomEdges(src, numEdges10, maxVertexId, seed = 12345) + val edges10_round2 = + GraphGenerators.generateRandomEdges(src, numEdges10, maxVertexId, seed = 12345) + assert(edges10_round1.zip(edges10_round2).forall { case (e1, e2) => + e1.srcId == e2.srcId && e1.dstId == e2.dstId && e1.attr == e2.attr + }) + + val edges10_round3 = + GraphGenerators.generateRandomEdges(src, numEdges10, maxVertexId, seed = 3467) + assert(!edges10_round1.zip(edges10_round3).forall { case (e1, e2) => + e1.srcId == e2.srcId && e1.dstId == e2.dstId && e1.attr == e2.attr + }) + } + + test("GraphGenerators.sampleLogNormal") { + val mu = 4.0 + val sigma = 1.3 + val maxVal = 100 + + val dstId = GraphGenerators.sampleLogNormal(mu, sigma, maxVal) + assert(dstId < maxVal) + + val dstId_round1 = GraphGenerators.sampleLogNormal(mu, sigma, maxVal, 12345) + val dstId_round2 = GraphGenerators.sampleLogNormal(mu, sigma, maxVal, 12345) + assert(dstId_round1 == dstId_round2) + + val dstId_round3 = GraphGenerators.sampleLogNormal(mu, sigma, maxVal, 789) + assert(dstId_round1 != dstId_round3) + } + + test("GraphGenerators.logNormalGraph") { + withSpark { sc => + val mu = 4.0 + val sigma = 1.3 + val numVertices100 = 100 + + val graph = GraphGenerators.logNormalGraph(sc, numVertices100, mu = mu, sigma = sigma) + assert(graph.vertices.count() == numVertices100) + + val graph_round1 = + GraphGenerators.logNormalGraph(sc, numVertices100, mu = mu, sigma = sigma, seed = 12345) + val graph_round2 = + GraphGenerators.logNormalGraph(sc, numVertices100, mu = mu, sigma = sigma, seed = 12345) + + val graph_round1_edges = graph_round1.edges.collect() + val graph_round2_edges = graph_round2.edges.collect() + + assert(graph_round1_edges.zip(graph_round2_edges).forall { case (e1, e2) => + e1.srcId == e2.srcId && e1.dstId == e2.dstId && e1.attr == e2.attr + }) + + val graph_round3 = + GraphGenerators.logNormalGraph(sc, numVertices100, mu = mu, sigma = sigma, seed = 567) + + val graph_round3_edges = graph_round3.edges.collect() + + assert(!graph_round1_edges.zip(graph_round3_edges).forall { case (e1, e2) => + e1.srcId == e2.srcId && e1.dstId == e2.dstId && e1.attr == e2.attr + }) + } + } + +} diff --git a/make-distribution.sh b/make-distribution.sh index ee1399071112d..9b012b9222db4 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -28,7 +28,7 @@ set -o pipefail set -e # Figure out where the Spark framework is installed -FWDIR="$(cd `dirname $0`; pwd)" +FWDIR="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$FWDIR/dist" SPARK_TACHYON=false @@ -50,7 +50,8 @@ while (( "$#" )); do case $1 in --hadoop) echo "Error: '--hadoop' is no longer supported:" - echo "Error: use Maven options -Phadoop.version and -Pyarn.version" + echo "Error: use Maven profiles and options -Dhadoop.version and -Dyarn.version instead." + echo "Error: Related profiles include hadoop-0.23, hdaoop-2.2, hadoop-2.3 and hadoop-2.4." exit_with_usage ;; --with-yarn) @@ -219,10 +220,10 @@ if [ "$SPARK_TACHYON" == "true" ]; then wget "$TACHYON_URL" tar xf "tachyon-${TACHYON_VERSION}-bin.tar.gz" - cp "tachyon-${TACHYON_VERSION}/target/tachyon-${TACHYON_VERSION}-jar-with-dependencies.jar" "$DISTDIR/lib" + cp "tachyon-${TACHYON_VERSION}/core/target/tachyon-${TACHYON_VERSION}-jar-with-dependencies.jar" "$DISTDIR/lib" mkdir -p "$DISTDIR/tachyon/src/main/java/tachyon/web" cp -r "tachyon-${TACHYON_VERSION}"/{bin,conf,libexec} "$DISTDIR/tachyon" - cp -r "tachyon-${TACHYON_VERSION}"/src/main/java/tachyon/web/resources "$DISTDIR/tachyon/src/main/java/tachyon/web" + cp -r "tachyon-${TACHYON_VERSION}"/core/src/main/java/tachyon/web "$DISTDIR/tachyon/src/main/java/tachyon/web" if [[ `uname -a` == Darwin* ]]; then # need to run sed differently on osx diff --git a/mllib/pom.xml b/mllib/pom.xml index c7a1e2ae75c84..a5eeef88e9d62 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index fdd67160114ca..45dbf6044fcc5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -128,7 +128,7 @@ class LeastSquaresGradient extends Gradient { class HingeGradient extends Gradient { override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { val dotProduct = dot(data, weights) - // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x))) + // Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x))) // Therefore the gradient is -(2y - 1)*x val labelScaled = 2 * label - 1.0 if (1.0 > labelScaled * dotProduct) { @@ -146,7 +146,7 @@ class HingeGradient extends Gradient { weights: Vector, cumGradient: Vector): Double = { val dotProduct = dot(data, weights) - // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x))) + // Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x))) // Therefore the gradient is -(2y - 1)*x val labelScaled = 2 * label - 1.0 if (1.0 > labelScaled * dotProduct) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 5cdd258f6c20b..dd766c12d28a4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -28,8 +28,9 @@ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TimeTracker, TreePoint} +import org.apache.spark.mllib.tree.impl._ import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity} +import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -65,36 +66,41 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo val retaggedInput = input.retag(classOf[LabeledPoint]) val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy) logDebug("algo = " + strategy.algo) + logDebug("maxBins = " + metadata.maxBins) // Find the splits and the corresponding bins (interval between the splits) using a sample // of the input data. timer.start("findSplitsBins") val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata) - val numBins = bins(0).length timer.stop("findSplitsBins") - logDebug("numBins = " + numBins) + logDebug("numBins: feature: number of bins") + logDebug(Range(0, metadata.numFeatures).map { featureIndex => + s"\t$featureIndex\t${metadata.numBins(featureIndex)}" + }.mkString("\n")) // Bin feature values (TreePoint representation). // Cache input RDD for speedup during multiple passes. val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) .persist(StorageLevel.MEMORY_AND_DISK) - val numFeatures = metadata.numFeatures // depth of the decision tree val maxDepth = strategy.maxDepth - // the max number of nodes possible given the depth of the tree - val maxNumNodes = (2 << maxDepth) - 1 + require(maxDepth <= 30, + s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.") + // Number of nodes to allocate: max number of nodes possible given the depth of the tree, plus 1 + val maxNumNodesPlus1 = Node.startIndexInLevel(maxDepth + 1) // Initialize an array to hold parent impurity calculations for each node. - val parentImpurities = new Array[Double](maxNumNodes) + val parentImpurities = new Array[Double](maxNumNodesPlus1) // dummy value for top node (updated during first split calculation) - val nodes = new Array[Node](maxNumNodes) + val nodes = new Array[Node](maxNumNodesPlus1) // Calculate level for single group construction // Max memory usage for aggregates val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024 logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") - val numElementsPerNode = DecisionTree.getElementsPerNode(metadata, numBins) + // TODO: Calculate memory usage more precisely. + val numElementsPerNode = DecisionTree.getElementsPerNode(metadata) logDebug("numElementsPerNode = " + numElementsPerNode) val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array @@ -124,26 +130,29 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo // Find best split for all nodes at a level. timer.start("findBestSplits") - val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities, - metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer) + val splitsStatsForLevel: Array[(Split, InformationGainStats)] = + DecisionTree.findBestSplits(treeInput, parentImpurities, + metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer) timer.stop("findBestSplits") - val levelNodeIndexOffset = (1 << level) - 1 + val levelNodeIndexOffset = Node.startIndexInLevel(level) for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { val nodeIndex = levelNodeIndexOffset + index - val isLeftChild = level != 0 && nodeIndex % 2 == 1 - val parentNodeIndex = if (isLeftChild) { // -1 for root node - (nodeIndex - 1) / 2 - } else { - (nodeIndex - 2) / 2 - } + // Extract info for this node (index) at the current level. timer.start("extractNodeInfo") - extractNodeInfo(nodeSplitStats, level, index, nodes) + val split = nodeSplitStats._1 + val stats = nodeSplitStats._2 + val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth) + val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) + logDebug("Node = " + node) + nodes(nodeIndex) = node timer.stop("extractNodeInfo") + if (level != 0) { // Set parent. - if (isLeftChild) { + val parentNodeIndex = Node.parentIndex(nodeIndex) + if (Node.isLeftChild(nodeIndex)) { nodes(parentNodeIndex).leftNode = Some(nodes(nodeIndex)) } else { nodes(parentNodeIndex).rightNode = Some(nodes(nodeIndex)) @@ -151,11 +160,21 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo } // Extract info for nodes at the next lower level. timer.start("extractInfoForLowerLevels") - extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities) + if (level < maxDepth) { + val leftChildIndex = Node.leftChildIndex(nodeIndex) + val leftImpurity = stats.leftImpurity + logDebug("leftChildIndex = " + leftChildIndex + ", impurity = " + leftImpurity) + parentImpurities(leftChildIndex) = leftImpurity + + val rightChildIndex = Node.rightChildIndex(nodeIndex) + val rightImpurity = stats.rightImpurity + logDebug("rightChildIndex = " + rightChildIndex + ", impurity = " + rightImpurity) + parentImpurities(rightChildIndex) = rightImpurity + } timer.stop("extractInfoForLowerLevels") - logDebug("final best split = " + nodeSplitStats._1) + logDebug("final best split = " + split) } - require((1 << level) == splitsStatsForLevel.length) + require(Node.maxNodesInLevel(level) == splitsStatsForLevel.length) // Check whether all the nodes at the current level at leaves. val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) logDebug("all leaf = " + allLeaf) @@ -171,7 +190,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logDebug("#####################################") // Initialize the top or root node of the tree. - val topNode = nodes(0) + val topNode = nodes(1) // Build the full tree using the node info calculated in the level-wise best split calculations. topNode.build(nodes) @@ -183,47 +202,6 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo new DecisionTreeModel(topNode, strategy.algo) } - /** - * Extract the decision tree node information for the given tree level and node index - */ - private def extractNodeInfo( - nodeSplitStats: (Split, InformationGainStats), - level: Int, - index: Int, - nodes: Array[Node]): Unit = { - val split = nodeSplitStats._1 - val stats = nodeSplitStats._2 - val nodeIndex = (1 << level) - 1 + index - val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth) - val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) - logDebug("Node = " + node) - nodes(nodeIndex) = node - } - - /** - * Extract the decision tree node information for the children of the node - */ - private def extractInfoForLowerLevels( - level: Int, - index: Int, - maxDepth: Int, - nodeSplitStats: (Split, InformationGainStats), - parentImpurities: Array[Double]): Unit = { - - if (level >= maxDepth) { - return - } - - val leftNodeIndex = (2 << level) - 1 + 2 * index - val leftImpurity = nodeSplitStats._2.leftImpurity - logDebug("leftNodeIndex = " + leftNodeIndex + ", impurity = " + leftImpurity) - parentImpurities(leftNodeIndex) = leftImpurity - - val rightNodeIndex = leftNodeIndex + 1 - val rightImpurity = nodeSplitStats._2.rightImpurity - logDebug("rightNodeIndex = " + rightNodeIndex + ", impurity = " + rightImpurity) - parentImpurities(rightNodeIndex) = rightImpurity - } } object DecisionTree extends Serializable with Logging { @@ -425,9 +403,6 @@ object DecisionTree extends Serializable with Logging { impurity, maxDepth, maxBins) } - - private val InvalidBinIndex = -1 - /** * Returns an array of optimal splits for all nodes at a given level. Splits the task into * multiple groups if the level-wise training task could lead to memory overflow. @@ -436,12 +411,12 @@ object DecisionTree extends Serializable with Logging { * @param parentImpurities Impurities for all parent nodes for the current level * @param metadata Learning and dataset metadata * @param level Level of the tree - * @param splits possible splits for all features - * @param bins possible bins for all features + * @param splits possible splits for all features, indexed (numFeatures)(numSplits) + * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. * @return array (over nodes) of splits with best split for each node at a given level. */ - protected[tree] def findBestSplits( + private[tree] def findBestSplits( input: RDD[TreePoint], parentImpurities: Array[Double], metadata: DecisionTreeMetadata, @@ -474,6 +449,138 @@ object DecisionTree extends Serializable with Logging { } } + /** + * Get the node index corresponding to this data point. + * This function mimics prediction, passing an example from the root node down to a node + * at the current level being trained; that node's index is returned. + * + * @param node Node in tree from which to classify the given data point. + * @param binnedFeatures Binned feature vector for data point. + * @param bins possible bins for all features, indexed (numFeatures)(numBins) + * @param unorderedFeatures Set of indices of unordered features. + * @return Leaf index if the data point reaches a leaf. + * Otherwise, last node reachable in tree matching this example. + * Note: This is the global node index, i.e., the index used in the tree. + * This index is different from the index used during training a particular + * set of nodes in a (level, group). + */ + private def predictNodeIndex( + node: Node, + binnedFeatures: Array[Int], + bins: Array[Array[Bin]], + unorderedFeatures: Set[Int]): Int = { + if (node.isLeaf) { + node.id + } else { + val featureIndex = node.split.get.feature + val splitLeft = node.split.get.featureType match { + case Continuous => { + val binIndex = binnedFeatures(featureIndex) + val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold + // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold] + // We do not need to check lowSplit since bins are separated by splits. + featureValueUpperBound <= node.split.get.threshold + } + case Categorical => { + val featureValue = binnedFeatures(featureIndex) + node.split.get.categories.contains(featureValue) + } + case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.") + } + if (node.leftNode.isEmpty || node.rightNode.isEmpty) { + // Return index from next layer of nodes to train + if (splitLeft) { + Node.leftChildIndex(node.id) + } else { + Node.rightChildIndex(node.id) + } + } else { + if (splitLeft) { + predictNodeIndex(node.leftNode.get, binnedFeatures, bins, unorderedFeatures) + } else { + predictNodeIndex(node.rightNode.get, binnedFeatures, bins, unorderedFeatures) + } + } + } + } + + /** + * Helper for binSeqOp, for data which can contain a mix of ordered and unordered features. + * + * For ordered features, a single bin is updated. + * For unordered features, bins correspond to subsets of categories; either the left or right bin + * for each subset is updated. + * + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (node, feature, bin). + * @param treePoint Data point being aggregated. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). + * @param bins possible bins for all features, indexed (numFeatures)(numBins) + * @param unorderedFeatures Set of indices of unordered features. + */ + private def mixedBinSeqOp( + agg: DTStatsAggregator, + treePoint: TreePoint, + nodeIndex: Int, + bins: Array[Array[Bin]], + unorderedFeatures: Set[Int]): Unit = { + // Iterate over all features. + val numFeatures = treePoint.binnedFeatures.size + val nodeOffset = agg.getNodeOffset(nodeIndex) + var featureIndex = 0 + while (featureIndex < numFeatures) { + if (unorderedFeatures.contains(featureIndex)) { + // Unordered feature + val featureValue = treePoint.binnedFeatures(featureIndex) + val (leftNodeFeatureOffset, rightNodeFeatureOffset) = + agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex) + // Update the left or right bin for each split. + val numSplits = agg.numSplits(featureIndex) + var splitIndex = 0 + while (splitIndex < numSplits) { + if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) { + agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label) + } else { + agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label) + } + splitIndex += 1 + } + } else { + // Ordered feature + val binIndex = treePoint.binnedFeatures(featureIndex) + agg.nodeUpdate(nodeOffset, featureIndex, binIndex, treePoint.label) + } + featureIndex += 1 + } + } + + /** + * Helper for binSeqOp, for regression and for classification with only ordered features. + * + * For each feature, the sufficient statistics of one bin are updated. + * + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (node, feature, bin). + * @param treePoint Data point being aggregated. + * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). + * @return agg + */ + private def orderedBinSeqOp( + agg: DTStatsAggregator, + treePoint: TreePoint, + nodeIndex: Int): Unit = { + val label = treePoint.label + val nodeOffset = agg.getNodeOffset(nodeIndex) + // Iterate over all features. + val numFeatures = agg.numFeatures + var featureIndex = 0 + while (featureIndex < numFeatures) { + val binIndex = treePoint.binnedFeatures(featureIndex) + agg.nodeUpdate(nodeOffset, featureIndex, binIndex, label) + featureIndex += 1 + } + } + /** * Returns an array of optimal splits for a group of nodes at a given level * @@ -481,8 +588,9 @@ object DecisionTree extends Serializable with Logging { * @param parentImpurities Impurities for all parent nodes for the current level * @param metadata Learning and dataset metadata * @param level Level of the tree - * @param splits possible splits for all features - * @param bins possible bins for all features, indexed as (numFeatures)(numBins) + * @param nodes Array of all nodes in the tree. Used for matching data points to nodes. + * @param splits possible splits for all features, indexed (numFeatures)(numSplits) + * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param numGroups total number of node groups at the current level. Default value is set to 1. * @param groupIndex index of the node group being processed. Default value is set to 0. * @return array of splits with best splits for all nodes at a given level. @@ -527,88 +635,22 @@ object DecisionTree extends Serializable with Logging { // numNodes: Number of nodes in this (level of tree, group), // where nodes at deeper (larger) levels may be divided into groups. - val numNodes = (1 << level) / numGroups + val numNodes = Node.maxNodesInLevel(level) / numGroups logDebug("numNodes = " + numNodes) - // Find the number of features by looking at the first sample. - val numFeatures = metadata.numFeatures - logDebug("numFeatures = " + numFeatures) - - // numBins: Number of bins = 1 + number of possible splits - val numBins = bins(0).length - logDebug("numBins = " + numBins) - - val numClasses = metadata.numClasses - logDebug("numClasses = " + numClasses) - - val isMulticlass = metadata.isMulticlass - logDebug("isMulticlass = " + isMulticlass) - - val isMulticlassWithCategoricalFeatures = metadata.isMulticlassWithCategoricalFeatures - logDebug("isMultiClassWithCategoricalFeatures = " + isMulticlassWithCategoricalFeatures) + logDebug("numFeatures = " + metadata.numFeatures) + logDebug("numClasses = " + metadata.numClasses) + logDebug("isMulticlass = " + metadata.isMulticlass) + logDebug("isMulticlassWithCategoricalFeatures = " + + metadata.isMulticlassWithCategoricalFeatures) // shift when more than one group is used at deep tree level val groupShift = numNodes * groupIndex - /** - * Get the node index corresponding to this data point. - * This function mimics prediction, passing an example from the root node down to a node - * at the current level being trained; that node's index is returned. - * - * @return Leaf index if the data point reaches a leaf. - * Otherwise, last node reachable in tree matching this example. - */ - def predictNodeIndex(node: Node, binnedFeatures: Array[Int]): Int = { - if (node.isLeaf) { - node.id - } else { - val featureIndex = node.split.get.feature - val splitLeft = node.split.get.featureType match { - case Continuous => { - val binIndex = binnedFeatures(featureIndex) - val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold - // bin binIndex has range (bin.lowSplit.threshold, bin.highSplit.threshold] - // We do not need to check lowSplit since bins are separated by splits. - featureValueUpperBound <= node.split.get.threshold - } - case Categorical => { - val featureValue = if (metadata.isUnordered(featureIndex)) { - binnedFeatures(featureIndex) - } else { - val binIndex = binnedFeatures(featureIndex) - bins(featureIndex)(binIndex).category - } - node.split.get.categories.contains(featureValue) - } - case _ => throw new RuntimeException(s"predictNodeIndex failed for unknown reason.") - } - if (node.leftNode.isEmpty || node.rightNode.isEmpty) { - // Return index from next layer of nodes to train - if (splitLeft) { - node.id * 2 + 1 // left - } else { - node.id * 2 + 2 // right - } - } else { - if (splitLeft) { - predictNodeIndex(node.leftNode.get, binnedFeatures) - } else { - predictNodeIndex(node.rightNode.get, binnedFeatures) - } - } - } - } - - def nodeIndexToLevel(idx: Int): Int = { - if (idx == 0) { - 0 - } else { - math.floor(math.log(idx) / math.log(2)).toInt - } - } - - // Used for treePointToNodeIndex - val levelOffset = (1 << level) - 1 + // Used for treePointToNodeIndex to get an index for this (level, group). + // - Node.startIndexInLevel(level) gives the global index offset for nodes at this level. + // - groupShift corrects for groups in this level before the current group. + val globalNodeIndexOffset = Node.startIndexInLevel(level) + groupShift /** * Find the node index for the given example. @@ -619,661 +661,254 @@ object DecisionTree extends Serializable with Logging { if (level == 0) { 0 } else { - val globalNodeIndex = predictNodeIndex(nodes(0), treePoint.binnedFeatures) - // Get index for this (level, group). - globalNodeIndex - levelOffset - groupShift - } - } - - /** - * Increment aggregate in location for (node, feature, bin, label). - * - * @param treePoint Data point being aggregated. - * @param agg Array storing aggregate calculation, of size: - * numClasses * numBins * numFeatures * numNodes. - * Indexed by (node, feature, bin, label) where label is the least significant bit. - * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). - */ - def updateBinForOrderedFeature( - treePoint: TreePoint, - agg: Array[Double], - nodeIndex: Int, - featureIndex: Int): Unit = { - // Update the left or right count for one bin. - val aggIndex = - numClasses * numBins * numFeatures * nodeIndex + - numClasses * numBins * featureIndex + - numClasses * treePoint.binnedFeatures(featureIndex) + - treePoint.label.toInt - agg(aggIndex) += 1 - } - - /** - * Increment aggregate in location for (nodeIndex, featureIndex, [bins], label), - * where [bins] ranges over all bins. - * Updates left or right side of aggregate depending on split. - * - * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). - * @param treePoint Data point being aggregated. - * @param agg Indexed by (left/right, node, feature, bin, label) - * where label is the least significant bit. - * The left/right specifier is a 0/1 index indicating left/right child info. - * @param rightChildShift Offset for right side of agg. - */ - def updateBinForUnorderedFeature( - nodeIndex: Int, - featureIndex: Int, - treePoint: TreePoint, - agg: Array[Double], - rightChildShift: Int): Unit = { - val featureValue = treePoint.binnedFeatures(featureIndex) - // Update the left or right count for one bin. - val aggShift = - numClasses * numBins * numFeatures * nodeIndex + - numClasses * numBins * featureIndex + - treePoint.label.toInt - // Find all matching bins and increment their values - val featureCategories = metadata.featureArity(featureIndex) - val numCategoricalBins = (1 << featureCategories - 1) - 1 - var binIndex = 0 - while (binIndex < numCategoricalBins) { - val aggIndex = aggShift + binIndex * numClasses - if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) { - agg(aggIndex) += 1 - } else { - agg(rightChildShift + aggIndex) += 1 - } - binIndex += 1 - } - } - - /** - * Helper for binSeqOp. - * - * @param agg Array storing aggregate calculation, of size: - * numClasses * numBins * numFeatures * numNodes. - * Indexed by (node, feature, bin, label) where label is the least significant bit. - * @param treePoint Data point being aggregated. - * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). - */ - def binaryOrNotCategoricalBinSeqOp( - agg: Array[Double], - treePoint: TreePoint, - nodeIndex: Int): Unit = { - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex) - featureIndex += 1 - } - } - - val rightChildShift = numClasses * numBins * numFeatures * numNodes - - /** - * Helper for binSeqOp. - * - * @param agg Array storing aggregate calculation. - * For ordered features, this is of size: - * numClasses * numBins * numFeatures * numNodes. - * For unordered features, this is of size: - * 2 * numClasses * numBins * numFeatures * numNodes. - * @param treePoint Data point being aggregated. - * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). - */ - def multiclassWithCategoricalBinSeqOp( - agg: Array[Double], - treePoint: TreePoint, - nodeIndex: Int): Unit = { - val label = treePoint.label - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - if (metadata.isUnordered(featureIndex)) { - updateBinForUnorderedFeature(nodeIndex, featureIndex, treePoint, agg, rightChildShift) - } else { - updateBinForOrderedFeature(treePoint, agg, nodeIndex, featureIndex) - } - featureIndex += 1 - } - } - - /** - * Performs a sequential aggregation over a partition for regression. - * For l nodes, k features, - * the count, sum, sum of squares of one of the p bins is incremented. - * - * @param agg Array storing aggregate calculation, updated by this function. - * Size: 3 * numBins * numFeatures * numNodes - * @param treePoint Data point being aggregated. - * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). - * @return agg - */ - def regressionBinSeqOp(agg: Array[Double], treePoint: TreePoint, nodeIndex: Int): Unit = { - val label = treePoint.label - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - // Update count, sum, and sum^2 for one bin. - val binIndex = treePoint.binnedFeatures(featureIndex) - val aggIndex = - 3 * numBins * numFeatures * nodeIndex + - 3 * numBins * featureIndex + - 3 * binIndex - agg(aggIndex) += 1 - agg(aggIndex + 1) += label - agg(aggIndex + 2) += label * label - featureIndex += 1 + val globalNodeIndex = + predictNodeIndex(nodes(1), treePoint.binnedFeatures, bins, metadata.unorderedFeatures) + globalNodeIndex - globalNodeIndexOffset } } /** * Performs a sequential aggregation over a partition. - * For l nodes, k features, - * For classification: - * Either the left count or the right count of one of the bins is - * incremented based upon whether the feature is classified as 0 or 1. - * For regression: - * The count, sum, sum of squares of one of the bins is incremented. * - * @param agg Array storing aggregate calculation, updated by this function. - * Size for classification: - * numClasses * numBins * numFeatures * numNodes for ordered features, or - * 2 * numClasses * numBins * numFeatures * numNodes for unordered features. - * Size for regression: - * 3 * numBins * numFeatures * numNodes. + * Each data point contributes to one node. For each feature, + * the aggregate sufficient statistics are updated for the relevant bins. + * + * @param agg Array storing aggregate calculation, with a set of sufficient statistics for + * each (node, feature, bin). * @param treePoint Data point being aggregated. * @return agg */ - def binSeqOp(agg: Array[Double], treePoint: TreePoint): Array[Double] = { + def binSeqOp( + agg: DTStatsAggregator, + treePoint: TreePoint): DTStatsAggregator = { val nodeIndex = treePointToNodeIndex(treePoint) // If the example does not reach this level, then nodeIndex < 0. // If the example reaches this level but is handled in a different group, // then either nodeIndex < 0 (previous group) or nodeIndex >= numNodes (later group). if (nodeIndex >= 0 && nodeIndex < numNodes) { - if (metadata.isClassification) { - if (isMulticlassWithCategoricalFeatures) { - multiclassWithCategoricalBinSeqOp(agg, treePoint, nodeIndex) - } else { - binaryOrNotCategoricalBinSeqOp(agg, treePoint, nodeIndex) - } + if (metadata.unorderedFeatures.isEmpty) { + orderedBinSeqOp(agg, treePoint, nodeIndex) } else { - regressionBinSeqOp(agg, treePoint, nodeIndex) + mixedBinSeqOp(agg, treePoint, nodeIndex, bins, metadata.unorderedFeatures) } } agg } - // Calculate bin aggregate length for classification or regression. - val binAggregateLength = numNodes * getElementsPerNode(metadata, numBins) - logDebug("binAggregateLength = " + binAggregateLength) - - /** - * Combines the aggregates from partitions. - * @param agg1 Array containing aggregates from one or more partitions - * @param agg2 Array containing aggregates from one or more partitions - * @return Combined aggregate from agg1 and agg2 - */ - def binCombOp(agg1: Array[Double], agg2: Array[Double]): Array[Double] = { - var index = 0 - val combinedAggregate = new Array[Double](binAggregateLength) - while (index < binAggregateLength) { - combinedAggregate(index) = agg1(index) + agg2(index) - index += 1 - } - combinedAggregate - } - // Calculate bin aggregates. timer.start("aggregation") - val binAggregates = { - input.treeAggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp) + val binAggregates: DTStatsAggregator = { + val initAgg = new DTStatsAggregator(metadata, numNodes) + input.treeAggregate(initAgg)(binSeqOp, DTStatsAggregator.binCombOp) } timer.stop("aggregation") - logDebug("binAggregates.length = " + binAggregates.length) - /** - * Calculate the information gain for a given (feature, split) based upon left/right aggregates. - * @param leftNodeAgg left node aggregates for this (feature, split) - * @param rightNodeAgg right node aggregate for this (feature, split) - * @param topImpurity impurity of the parent node - * @return information gain and statistics for all splits - */ - def calculateGainForSplit( - leftNodeAgg: Array[Double], - rightNodeAgg: Array[Double], - topImpurity: Double): InformationGainStats = { - if (metadata.isClassification) { - val leftTotalCount = leftNodeAgg.sum - val rightTotalCount = rightNodeAgg.sum - - val impurity = { - if (level > 0) { - topImpurity - } else { - // Calculate impurity for root node. - val rootNodeCounts = new Array[Double](numClasses) - var classIndex = 0 - while (classIndex < numClasses) { - rootNodeCounts(classIndex) = leftNodeAgg(classIndex) + rightNodeAgg(classIndex) - classIndex += 1 - } - metadata.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount) - } - } - - val totalCount = leftTotalCount + rightTotalCount - if (totalCount == 0) { - // Return arbitrary prediction. - return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) - } - - // Sum of count for each label - val leftrightNodeAgg: Array[Double] = - leftNodeAgg.zip(rightNodeAgg).map { case (leftCount, rightCount) => - leftCount + rightCount - } - - def indexOfLargestArrayElement(array: Array[Double]): Int = { - val result = array.foldLeft(-1, Double.MinValue, 0) { - case ((maxIndex, maxValue, currentIndex), currentValue) => - if (currentValue > maxValue) { - (currentIndex, currentValue, currentIndex + 1) - } else { - (maxIndex, maxValue, currentIndex + 1) - } - } - if (result._1 < 0) { - throw new RuntimeException("DecisionTree internal error:" + - " calculateGainForSplit failed in indexOfLargestArrayElement") - } - result._1 - } - - val predict = indexOfLargestArrayElement(leftrightNodeAgg) - val prob = leftrightNodeAgg(predict) / totalCount - - val leftImpurity = if (leftTotalCount == 0) { - topImpurity - } else { - metadata.impurity.calculate(leftNodeAgg, leftTotalCount) - } - val rightImpurity = if (rightTotalCount == 0) { - topImpurity - } else { - metadata.impurity.calculate(rightNodeAgg, rightTotalCount) - } - - val leftWeight = leftTotalCount / totalCount - val rightWeight = rightTotalCount / totalCount - - val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) - - } else { - // Regression - - val leftCount = leftNodeAgg(0) - val leftSum = leftNodeAgg(1) - val leftSumSquares = leftNodeAgg(2) + // Calculate best splits for all nodes at a given level + timer.start("chooseSplits") + val bestSplits = new Array[(Split, InformationGainStats)](numNodes) + // Iterating over all nodes at this level + var nodeIndex = 0 + while (nodeIndex < numNodes) { + val nodeImpurity = parentImpurities(globalNodeIndexOffset + nodeIndex) + logDebug("node impurity = " + nodeImpurity) + bestSplits(nodeIndex) = + binsToBestSplit(binAggregates, nodeIndex, nodeImpurity, level, metadata, splits) + logDebug("best split = " + bestSplits(nodeIndex)._1) + nodeIndex += 1 + } + timer.stop("chooseSplits") - val rightCount = rightNodeAgg(0) - val rightSum = rightNodeAgg(1) - val rightSumSquares = rightNodeAgg(2) + bestSplits + } - val impurity = { - if (level > 0) { - topImpurity - } else { - // Calculate impurity for root node. - val count = leftCount + rightCount - val sum = leftSum + rightSum - val sumSquares = leftSumSquares + rightSumSquares - metadata.impurity.calculate(count, sum, sumSquares) - } - } + /** + * Calculate the information gain for a given (feature, split) based upon left/right aggregates. + * @param leftImpurityCalculator left node aggregates for this (feature, split) + * @param rightImpurityCalculator right node aggregate for this (feature, split) + * @param topImpurity impurity of the parent node + * @return information gain and statistics for all splits + */ + private def calculateGainForSplit( + leftImpurityCalculator: ImpurityCalculator, + rightImpurityCalculator: ImpurityCalculator, + topImpurity: Double, + level: Int, + metadata: DecisionTreeMetadata): InformationGainStats = { - if (leftCount == 0) { - return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, - rightSum / rightCount) - } - if (rightCount == 0) { - return new InformationGainStats(0, topImpurity, topImpurity, - Double.MinValue, leftSum / leftCount) - } + val leftCount = leftImpurityCalculator.count + val rightCount = rightImpurityCalculator.count - val leftImpurity = metadata.impurity.calculate(leftCount, leftSum, leftSumSquares) - val rightImpurity = metadata.impurity.calculate(rightCount, rightSum, rightSumSquares) + val totalCount = leftCount + rightCount + if (totalCount == 0) { + // Return arbitrary prediction. + return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0) + } - val leftWeight = leftCount.toDouble / (leftCount + rightCount) - val rightWeight = rightCount.toDouble / (leftCount + rightCount) + val parentNodeAgg = leftImpurityCalculator.copy + parentNodeAgg.add(rightImpurityCalculator) + // impurity of parent node + val impurity = if (level > 0) { + topImpurity + } else { + parentNodeAgg.calculate() + } - val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + val predict = parentNodeAgg.predict + val prob = parentNodeAgg.prob(predict) - val predict = (leftSum + rightSum) / (leftCount + rightCount) - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) - } - } + val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 + val rightImpurity = rightImpurityCalculator.calculate() - /** - * Extracts left and right split aggregates. - * @param binData Aggregate array slice from getBinDataForNode. - * For classification: - * For unordered features, this is leftChildData ++ rightChildData, - * each of which is indexed by (feature, split/bin, class), - * with class being the least significant bit. - * For ordered features, this is of size numClasses * numBins * numFeatures. - * For regression: - * This is of size 2 * numFeatures * numBins. - * @return (leftNodeAgg, rightNodeAgg) pair of arrays. - * For classification, each array is of size (numFeatures, (numBins - 1), numClasses). - * For regression, each array is of size (numFeatures, (numBins - 1), 3). - * - */ - def extractLeftRightNodeAggregates( - binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = { - - - /** - * The input binData is indexed as (feature, bin, class). - * This computes cumulative sums over splits. - * Each (feature, class) pair is handled separately. - * Note: numSplits = numBins - 1. - * @param leftNodeAgg Each (feature, class) slice is an array over splits. - * Element i (i = 0, ..., numSplits - 2) is set to be - * the cumulative sum (from left) over binData for bins 0, ..., i. - * @param rightNodeAgg Each (feature, class) slice is an array over splits. - * Element i (i = 1, ..., numSplits - 1) is set to be - * the cumulative sum (from right) over binData for bins - * numBins - 1, ..., numBins - 1 - i. - */ - def findAggForOrderedFeatureClassification( - leftNodeAgg: Array[Array[Array[Double]]], - rightNodeAgg: Array[Array[Array[Double]]], - featureIndex: Int) { - - // shift for this featureIndex - val shift = numClasses * featureIndex * numBins - - var classIndex = 0 - while (classIndex < numClasses) { - // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0)(classIndex) = binData(shift + classIndex) - // right node aggregate for the highest split - rightNodeAgg(featureIndex)(numBins - 2)(classIndex) - = binData(shift + (numClasses * (numBins - 1)) + classIndex) - classIndex += 1 - } + val leftWeight = leftCount / totalCount.toDouble + val rightWeight = rightCount / totalCount.toDouble - // Iterate over all splits. - var splitIndex = 1 - while (splitIndex < numBins - 1) { - // calculating left node aggregate for a split as a sum of left node aggregate of a - // lower split and the left bin aggregate of a bin where the split is a high split - var innerClassIndex = 0 - while (innerClassIndex < numClasses) { - leftNodeAgg(featureIndex)(splitIndex)(innerClassIndex) - = binData(shift + numClasses * splitIndex + innerClassIndex) + - leftNodeAgg(featureIndex)(splitIndex - 1)(innerClassIndex) - rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(innerClassIndex) = - binData(shift + (numClasses * (numBins - 1 - splitIndex) + innerClassIndex)) + - rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(innerClassIndex) - innerClassIndex += 1 - } - splitIndex += 1 - } - } + val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity - /** - * Reshape binData for this feature. - * Indexes binData as (feature, split, class) with class as the least significant bit. - * @param leftNodeAgg leftNodeAgg(featureIndex)(splitIndex)(classIndex) = aggregate value - */ - def findAggForUnorderedFeatureClassification( - leftNodeAgg: Array[Array[Array[Double]]], - rightNodeAgg: Array[Array[Array[Double]]], - featureIndex: Int) { - - val rightChildShift = numClasses * numBins * numFeatures - var splitIndex = 0 - while (splitIndex < numBins - 1) { - var classIndex = 0 - while (classIndex < numClasses) { - // shift for this featureIndex - val shift = numClasses * featureIndex * numBins + splitIndex * numClasses - val leftBinValue = binData(shift + classIndex) - val rightBinValue = binData(rightChildShift + shift + classIndex) - leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue - rightNodeAgg(featureIndex)(splitIndex)(classIndex) = rightBinValue - classIndex += 1 - } - splitIndex += 1 - } - } + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob) + } - def findAggForRegression( - leftNodeAgg: Array[Array[Array[Double]]], - rightNodeAgg: Array[Array[Array[Double]]], - featureIndex: Int) { - - // shift for this featureIndex - val shift = 3 * featureIndex * numBins - // left node aggregate for the lowest split - leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0) - leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1) - leftNodeAgg(featureIndex)(0)(2) = binData(shift + 2) - - // right node aggregate for the highest split - rightNodeAgg(featureIndex)(numBins - 2)(0) = - binData(shift + (3 * (numBins - 1))) - rightNodeAgg(featureIndex)(numBins - 2)(1) = - binData(shift + (3 * (numBins - 1)) + 1) - rightNodeAgg(featureIndex)(numBins - 2)(2) = - binData(shift + (3 * (numBins - 1)) + 2) - - // Iterate over all splits. - var splitIndex = 1 - while (splitIndex < numBins - 1) { - var i = 0 // index for regression histograms - while (i < 3) { // count, sum, sum^2 - // calculating left node aggregate for a split as a sum of left node aggregate of a - // lower split and the left bin aggregate of a bin where the split is a high split - leftNodeAgg(featureIndex)(splitIndex)(i) = binData(shift + 3 * splitIndex + i) + - leftNodeAgg(featureIndex)(splitIndex - 1)(i) - // calculating right node aggregate for a split as a sum of right node aggregate of a - // higher split and the right bin aggregate of a bin where the split is a low split - rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(i) = - binData(shift + (3 * (numBins - 1 - splitIndex) + i)) + - rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(i) - i += 1 - } - splitIndex += 1 - } - } + /** + * Find the best split for a node. + * @param binAggregates Bin statistics. + * @param nodeIndex Index for node to split in this (level, group). + * @param nodeImpurity Impurity of the node (nodeIndex). + * @return tuple for best split: (Split, information gain) + */ + private def binsToBestSplit( + binAggregates: DTStatsAggregator, + nodeIndex: Int, + nodeImpurity: Double, + level: Int, + metadata: DecisionTreeMetadata, + splits: Array[Array[Split]]): (Split, InformationGainStats) = { - if (metadata.isClassification) { - // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) - val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses) - var featureIndex = 0 - while (featureIndex < numFeatures) { - if (metadata.isUnordered(featureIndex)) { - findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } else { - findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex) - } - featureIndex += 1 - } - (leftNodeAgg, rightNodeAgg) - } else { - // Regression - // Initialize left and right split aggregates. - val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) - val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3) - // Iterate over all features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex) - featureIndex += 1 - } - (leftNodeAgg, rightNodeAgg) - } - } + logDebug("node impurity = " + nodeImpurity) - /** - * Calculates information gain for all nodes splits. - */ - def calculateGainsForAllNodeSplits( - leftNodeAgg: Array[Array[Array[Double]]], - rightNodeAgg: Array[Array[Array[Double]]], - nodeImpurity: Double): Array[Array[InformationGainStats]] = { - val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) - - var featureIndex = 0 - while (featureIndex < numFeatures) { - val numSplitsForFeature = getNumSplitsForFeature(featureIndex) + // For each (feature, split), calculate the gain, and select the best (feature, split). + Range(0, metadata.numFeatures).map { featureIndex => + val numSplits = metadata.numSplits(featureIndex) + if (metadata.isContinuous(featureIndex)) { + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex) var splitIndex = 0 - while (splitIndex < numSplitsForFeature) { - gains(featureIndex)(splitIndex) = - calculateGainForSplit(leftNodeAgg(featureIndex)(splitIndex), - rightNodeAgg(featureIndex)(splitIndex), nodeImpurity) + while (splitIndex < numSplits) { + binAggregates.mergeForNodeFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) splitIndex += 1 } - featureIndex += 1 - } - gains - } - - /** - * Get the number of splits for a feature. - */ - def getNumSplitsForFeature(featureIndex: Int): Int = { - if (metadata.isContinuous(featureIndex)) { - numBins - 1 + // Find best split. + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { case splitIdx => + val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx) + val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) + rightChildStats.subtract(leftChildStats) + val gainStats = + calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) + (splitIdx, gainStats) + }.maxBy(_._2.gain) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) + } else if (metadata.isUnordered(featureIndex)) { + // Unordered categorical feature + val (leftChildOffset, rightChildOffset) = + binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex) + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { splitIndex => + val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) + val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) + val gainStats = + calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) + (splitIndex, gainStats) + }.maxBy(_._2.gain) + (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { - // Categorical feature - val featureCategories = metadata.featureArity(featureIndex) - if (metadata.isUnordered(featureIndex)) { - (1 << featureCategories - 1) - 1 - } else { - featureCategories - } - } - } - - /** - * Find the best split for a node. - * @param binData Bin data slice for this node, given by getBinDataForNode. - * @param nodeImpurity impurity of the top node - * @return tuple of split and information gain - */ - def binsToBestSplit( - binData: Array[Double], - nodeImpurity: Double): (Split, InformationGainStats) = { - - logDebug("node impurity = " + nodeImpurity) - - // Extract left right node aggregates. - val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) - - // Calculate gains for all splits. - val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) - - val (bestFeatureIndex, bestSplitIndex, gainStats) = { - // Initialize with infeasible values. - var bestFeatureIndex = Int.MinValue - var bestSplitIndex = Int.MinValue - var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0) - // Iterate over features. - var featureIndex = 0 - while (featureIndex < numFeatures) { - // Iterate over all splits. - var splitIndex = 0 - val numSplitsForFeature = getNumSplitsForFeature(featureIndex) - while (splitIndex < numSplitsForFeature) { - val gainStats = gains(featureIndex)(splitIndex) - if (gainStats.gain > bestGainStats.gain) { - bestGainStats = gainStats - bestFeatureIndex = featureIndex - bestSplitIndex = splitIndex + // Ordered categorical feature + val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex) + val numBins = metadata.numBins(featureIndex) + + /* Each bin is one category (feature value). + * The bins are ordered based on centroidForCategories, and this ordering determines which + * splits are considered. (With K categories, we consider K - 1 possible splits.) + * + * centroidForCategories is a list: (category, centroid) + */ + val centroidForCategories = if (metadata.isMulticlass) { + // For categorical variables in multiclass classification, + // the bins are ordered by the impurity of their corresponding labels. + Range(0, numBins).map { case featureValue => + val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val centroid = if (categoryStats.count != 0) { + categoryStats.calculate() + } else { + Double.MaxValue } - splitIndex += 1 + (featureValue, centroid) + } + } else { // regression or binary classification + // For categorical variables in regression and binary classification, + // the bins are ordered by the centroid of their corresponding labels. + Range(0, numBins).map { case featureValue => + val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val centroid = if (categoryStats.count != 0) { + categoryStats.predict + } else { + Double.MaxValue + } + (featureValue, centroid) } - featureIndex += 1 } - (bestFeatureIndex, bestSplitIndex, bestGainStats) - } - logDebug("best split = " + splits(bestFeatureIndex)(bestSplitIndex)) - logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex)) + logDebug("Centroids for categorical variable: " + centroidForCategories.mkString(",")) - (splits(bestFeatureIndex)(bestSplitIndex), gainStats) - } + // bins sorted by centroids + val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2) - /** - * Get bin data for one node. - */ - def getBinDataForNode(node: Int): Array[Double] = { - if (metadata.isClassification) { - if (isMulticlassWithCategoricalFeatures) { - val shift = numClasses * node * numBins * numFeatures - val rightChildShift = numClasses * numBins * numFeatures * numNodes - val binsForNode = { - val leftChildData - = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) - val rightChildData - = binAggregates.slice(rightChildShift + shift, - rightChildShift + shift + numClasses * numBins * numFeatures) - leftChildData ++ rightChildData - } - binsForNode - } else { - val shift = numClasses * node * numBins * numFeatures - val binsForNode = binAggregates.slice(shift, shift + numClasses * numBins * numFeatures) - binsForNode + logDebug("Sorted centroids for categorical variable = " + + categoriesSortedByCentroid.mkString(",")) + + // Cumulative sum (scanLeft) of bin statistics. + // Afterwards, binAggregates for a bin is the sum of aggregates for + // that bin + all preceding bins. + var splitIndex = 0 + while (splitIndex < numSplits) { + val currentCategory = categoriesSortedByCentroid(splitIndex)._1 + val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1 + binAggregates.mergeForNodeFeature(nodeFeatureOffset, nextCategory, currentCategory) + splitIndex += 1 } - } else { - // Regression - val shift = 3 * node * numBins * numFeatures - val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures) - binsForNode + // lastCategory = index of bin with total aggregates for this (node, feature) + val lastCategory = categoriesSortedByCentroid.last._1 + // Find best split. + val (bestFeatureSplitIndex, bestFeatureGainStats) = + Range(0, numSplits).map { splitIndex => + val featureValue = categoriesSortedByCentroid(splitIndex)._1 + val leftChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue) + val rightChildStats = + binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) + rightChildStats.subtract(leftChildStats) + val gainStats = + calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata) + (splitIndex, gainStats) + }.maxBy(_._2.gain) + val categoriesForSplit = + categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) + val bestFeatureSplit = + new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit) + (bestFeatureSplit, bestFeatureGainStats) } - } - - // Calculate best splits for all nodes at a given level - timer.start("chooseSplits") - val bestSplits = new Array[(Split, InformationGainStats)](numNodes) - // Iterating over all nodes at this level - var node = 0 - while (node < numNodes) { - val nodeImpurityIndex = (1 << level) - 1 + node + groupShift - val binsForNode: Array[Double] = getBinDataForNode(node) - logDebug("nodeImpurityIndex = " + nodeImpurityIndex) - val parentNodeImpurity = parentImpurities(nodeImpurityIndex) - logDebug("parent node impurity = " + parentNodeImpurity) - bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) - node += 1 - } - timer.stop("chooseSplits") - - bestSplits + }.maxBy(_._2.gain) } /** * Get the number of values to be stored per node in the bin aggregates. - * - * @param numBins Number of bins = 1 + number of possible splits. */ - private def getElementsPerNode(metadata: DecisionTreeMetadata, numBins: Int): Int = { + private def getElementsPerNode(metadata: DecisionTreeMetadata): Int = { + val totalBins = metadata.numBins.sum if (metadata.isClassification) { - if (metadata.isMulticlassWithCategoricalFeatures) { - 2 * metadata.numClasses * numBins * metadata.numFeatures - } else { - metadata.numClasses * numBins * metadata.numFeatures - } + metadata.numClasses * totalBins } else { - 3 * numBins * metadata.numFeatures + 3 * totalBins } } @@ -1284,6 +919,7 @@ object DecisionTree extends Serializable with Logging { * Continuous features: * For each feature, there are numBins - 1 possible splits representing the possible binary * decisions at each node in the tree. + * This finds locations (feature values) for splits using a subsample of the data. * * Categorical features: * For each feature, there is 1 bin per split. @@ -1292,7 +928,6 @@ object DecisionTree extends Serializable with Logging { * For multiclass classification with a low-arity feature * (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits), * the feature is split based on subsets of categories. - * There are (1 << maxFeatureValue - 1) - 1 splits. * (b) "ordered features" * For regression and binary classification, * and for multiclass classification with a high-arity feature, @@ -1302,7 +937,7 @@ object DecisionTree extends Serializable with Logging { * @param metadata Learning and dataset metadata * @return A tuple of (splits, bins). * Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]] - * of size (numFeatures, numBins - 1). + * of size (numFeatures, numSplits). * Bins is an Array of [[org.apache.spark.mllib.tree.model.Bin]] * of size (numFeatures, numBins). */ @@ -1310,84 +945,80 @@ object DecisionTree extends Serializable with Logging { input: RDD[LabeledPoint], metadata: DecisionTreeMetadata): (Array[Array[Split]], Array[Array[Bin]]) = { - val count = input.count() + logDebug("isMulticlass = " + metadata.isMulticlass) - // Find the number of features by looking at the first sample - val numFeatures = input.take(1)(0).features.size - - val maxBins = metadata.maxBins - val numBins = if (maxBins <= count) maxBins else count.toInt - logDebug("numBins = " + numBins) - val isMulticlass = metadata.isMulticlass - logDebug("isMulticlass = " + isMulticlass) - - /* - * Ensure numBins is always greater than the categories. For multiclass classification, - * numBins should be greater than 2^(maxCategories - 1) - 1. - * It's a limitation of the current implementation but a reasonable trade-off since features - * with large number of categories get favored over continuous features. - * - * This needs to be checked here instead of in Strategy since numBins can be determined - * by the number of training examples. - * TODO: Allow this case, where we simply will know nothing about some categories. - */ - if (metadata.featureArity.size > 0) { - val maxCategoriesForFeatures = metadata.featureArity.maxBy(_._2)._2 - require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " + - "in categorical features") - } - - // Calculate the number of sample for approximate quantile calculation. - val requiredSamples = numBins*numBins - val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 - logDebug("fraction of data used for calculating quantiles = " + fraction) + val numFeatures = metadata.numFeatures - // sampled input for RDD calculation - val sampledInput = + // Sample the input only if there are continuous features. + val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous) + val sampledInput = if (hasContinuousFeatures) { + // Calculate the number of samples for approximate quantile calculation. + val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000) + val fraction = if (requiredSamples < metadata.numExamples) { + requiredSamples.toDouble / metadata.numExamples + } else { + 1.0 + } + logDebug("fraction of data used for calculating quantiles = " + fraction) input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect() - val numSamples = sampledInput.length - - val stride: Double = numSamples.toDouble / numBins - logDebug("stride = " + stride) + } else { + new Array[LabeledPoint](0) + } metadata.quantileStrategy match { case Sort => - val splits = Array.ofDim[Split](numFeatures, numBins - 1) - val bins = Array.ofDim[Bin](numFeatures, numBins) + val splits = new Array[Array[Split]](numFeatures) + val bins = new Array[Array[Bin]](numFeatures) // Find all splits. - // Iterate over all features. var featureIndex = 0 while (featureIndex < numFeatures) { - // Check whether the feature is continuous. - val isFeatureContinuous = metadata.isContinuous(featureIndex) - if (isFeatureContinuous) { + val numSplits = metadata.numSplits(featureIndex) + val numBins = metadata.numBins(featureIndex) + if (metadata.isContinuous(featureIndex)) { + val numSamples = sampledInput.length + splits(featureIndex) = new Array[Split](numSplits) + bins(featureIndex) = new Array[Bin](numBins) val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted - val stride: Double = numSamples.toDouble / numBins + val stride: Double = numSamples.toDouble / metadata.numBins(featureIndex) logDebug("stride = " + stride) - for (index <- 0 until numBins - 1) { - val sampleIndex = index * stride.toInt + for (splitIndex <- 0 until numSplits) { + val sampleIndex = splitIndex * stride.toInt // Set threshold halfway in between 2 samples. val threshold = (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0 - val split = new Split(featureIndex, threshold, Continuous, List()) - splits(featureIndex)(index) = split + splits(featureIndex)(splitIndex) = + new Split(featureIndex, threshold, Continuous, List()) } - } else { // Categorical feature - val featureCategories = metadata.featureArity(featureIndex) - - // Use different bin/split calculation strategy for categorical features in multiclass - // classification that satisfy the space constraint. + bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), + splits(featureIndex)(0), Continuous, Double.MinValue) + for (splitIndex <- 1 until numSplits) { + bins(featureIndex)(splitIndex) = + new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex), + Continuous, Double.MinValue) + } + bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1), + new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) + } else { + // Categorical feature + val featureArity = metadata.featureArity(featureIndex) if (metadata.isUnordered(featureIndex)) { - // 2^(maxFeatureValue- 1) - 1 combinations - var index = 0 - while (index < (1 << featureCategories - 1) - 1) { - val categories: List[Double] - = extractMultiClassCategories(index + 1, featureCategories) - splits(featureIndex)(index) - = new Split(featureIndex, Double.MinValue, Categorical, categories) - bins(featureIndex)(index) = { - if (index == 0) { + // TODO: The second half of the bins are unused. Actually, we could just use + // splits and not build bins for unordered features. That should be part of + // a later PR since it will require changing other code (using splits instead + // of bins in a few places). + // Unordered features + // 2^(maxFeatureValue - 1) - 1 combinations + splits(featureIndex) = new Array[Split](numSplits) + bins(featureIndex) = new Array[Bin](numBins) + var splitIndex = 0 + while (splitIndex < numSplits) { + val categories: List[Double] = + extractMultiClassCategories(splitIndex + 1, featureArity) + splits(featureIndex)(splitIndex) = + new Split(featureIndex, Double.MinValue, Categorical, categories) + bins(featureIndex)(splitIndex) = { + if (splitIndex == 0) { new Bin( new DummyCategoricalSplit(featureIndex, Categorical), splits(featureIndex)(0), @@ -1395,96 +1026,24 @@ object DecisionTree extends Serializable with Logging { Double.MinValue) } else { new Bin( - splits(featureIndex)(index - 1), - splits(featureIndex)(index), + splits(featureIndex)(splitIndex - 1), + splits(featureIndex)(splitIndex), Categorical, Double.MinValue) } } - index += 1 - } - } else { // ordered feature - /* For a given categorical feature, use a subsample of the data - * to choose how to arrange possible splits. - * This examines each category and computes a centroid. - * These centroids are later used to sort the possible splits. - * centroidForCategories is a mapping: category (for the given feature) --> centroid - */ - val centroidForCategories = { - if (isMulticlass) { - // For categorical variables in multiclass classification, - // each bin is a category. The bins are sorted and they - // are ordered by calculating the impurity of their corresponding labels. - sampledInput.map(lp => (lp.features(featureIndex), lp.label)) - .groupBy(_._1) - .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble)) - .map(x => (x._1, x._2.values.toArray)) - .map(x => (x._1, metadata.impurity.calculate(x._2, x._2.sum))) - } else { // regression or binary classification - // For categorical variables in regression and binary classification, - // each bin is a category. The bins are sorted and they - // are ordered by calculating the centroid of their corresponding labels. - sampledInput.map(lp => (lp.features(featureIndex), lp.label)) - .groupBy(_._1) - .mapValues(x => x.map(_._2).sum / x.map(_._1).length) - } - } - - logDebug("centroid for categories = " + centroidForCategories.mkString(",")) - - // Check for missing categorical variables and putting them last in the sorted list. - val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() - for (i <- 0 until featureCategories) { - if (centroidForCategories.contains(i)) { - fullCentroidForCategories(i) = centroidForCategories(i) - } else { - fullCentroidForCategories(i) = Double.MaxValue - } - } - - // bins sorted by centroids - val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2) - - logDebug("centroid for categorical variable = " + categoriesSortedByCentroid) - - var categoriesForSplit = List[Double]() - categoriesSortedByCentroid.iterator.zipWithIndex.foreach { - case ((key, value), index) => - categoriesForSplit = key :: categoriesForSplit - splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, - Categorical, categoriesForSplit) - bins(featureIndex)(index) = { - if (index == 0) { - new Bin(new DummyCategoricalSplit(featureIndex, Categorical), - splits(featureIndex)(0), Categorical, key) - } else { - new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), - Categorical, key) - } - } + splitIndex += 1 } + } else { + // Ordered features + // Bins correspond to feature values, so we do not need to compute splits or bins + // beforehand. Splits are constructed as needed during training. + splits(featureIndex) = new Array[Split](0) + bins(featureIndex) = new Array[Bin](0) } } featureIndex += 1 } - - // Find all bins. - featureIndex = 0 - while (featureIndex < numFeatures) { - val isFeatureContinuous = metadata.isContinuous(featureIndex) - if (isFeatureContinuous) { // Bins for categorical variables are already assigned. - bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), - splits(featureIndex)(0), Continuous, Double.MinValue) - for (index <- 1 until numBins - 1) { - val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), - Continuous, Double.MinValue) - bins(featureIndex)(index) = bin - } - bins(featureIndex)(numBins-1) = new Bin(splits(featureIndex)(numBins-2), - new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) - } - featureIndex += 1 - } (splits, bins) case MinMax => throw new UnsupportedOperationException("minmax not supported yet.") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala new file mode 100644 index 0000000000000..866d85a79bea1 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import org.apache.spark.mllib.tree.impurity._ + +/** + * DecisionTree statistics aggregator. + * This holds a flat array of statistics for a set of (nodes, features, bins) + * and helps with indexing. + */ +private[tree] class DTStatsAggregator( + val metadata: DecisionTreeMetadata, + val numNodes: Int) extends Serializable { + + /** + * [[ImpurityAggregator]] instance specifying the impurity type. + */ + val impurityAggregator: ImpurityAggregator = metadata.impurity match { + case Gini => new GiniAggregator(metadata.numClasses) + case Entropy => new EntropyAggregator(metadata.numClasses) + case Variance => new VarianceAggregator() + case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}") + } + + /** + * Number of elements (Double values) used for the sufficient statistics of each bin. + */ + val statsSize: Int = impurityAggregator.statsSize + + val numFeatures: Int = metadata.numFeatures + + /** + * Number of bins for each feature. This is indexed by the feature index. + */ + val numBins: Array[Int] = metadata.numBins + + /** + * Number of splits for the given feature. + */ + def numSplits(featureIndex: Int): Int = metadata.numSplits(featureIndex) + + /** + * Indicator for each feature of whether that feature is an unordered feature. + * TODO: Is Array[Boolean] any faster? + */ + def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex) + + /** + * Offset for each feature for calculating indices into the [[allStats]] array. + */ + private val featureOffsets: Array[Int] = { + def featureOffsetsCalc(total: Int, featureIndex: Int): Int = { + if (isUnordered(featureIndex)) { + total + 2 * numBins(featureIndex) + } else { + total + numBins(featureIndex) + } + } + Range(0, numFeatures).scanLeft(0)(featureOffsetsCalc).map(statsSize * _).toArray + } + + /** + * Number of elements for each node, corresponding to stride between nodes in [[allStats]]. + */ + private val nodeStride: Int = featureOffsets.last + + /** + * Total number of elements stored in this aggregator. + */ + val allStatsSize: Int = numNodes * nodeStride + + /** + * Flat array of elements. + * Index for start of stats for a (node, feature, bin) is: + * index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize + * Note: For unordered features, the left child stats have binIndex in [0, numBins(featureIndex)) + * and the right child stats in [numBins(featureIndex), 2 * numBins(featureIndex)) + */ + val allStats: Array[Double] = new Array[Double](allStatsSize) + + /** + * Get an [[ImpurityCalculator]] for a given (node, feature, bin). + * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset + * from [[getNodeFeatureOffset]]. + * For unordered features, this is a pre-computed + * (node, feature, left/right child) offset from + * [[getLeftRightNodeFeatureOffsets]]. + */ + def getImpurityCalculator(nodeFeatureOffset: Int, binIndex: Int): ImpurityCalculator = { + impurityAggregator.getCalculator(allStats, nodeFeatureOffset + binIndex * statsSize) + } + + /** + * Update the stats for a given (node, feature, bin) for ordered features, using the given label. + */ + def update(nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = { + val i = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize + impurityAggregator.update(allStats, i, label) + } + + /** + * Pre-compute node offset for use with [[nodeUpdate]]. + */ + def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride + + /** + * Faster version of [[update]]. + * Update the stats for a given (node, feature, bin) for ordered features, using the given label. + * @param nodeOffset Pre-computed node offset from [[getNodeOffset]]. + */ + def nodeUpdate(nodeOffset: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = { + val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize + impurityAggregator.update(allStats, i, label) + } + + /** + * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]]. + * For ordered features only. + */ + def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = { + require(!isUnordered(featureIndex), + s"DTStatsAggregator.getNodeFeatureOffset is for ordered features only, but was called" + + s" for unordered feature $featureIndex.") + nodeIndex * nodeStride + featureOffsets(featureIndex) + } + + /** + * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]]. + * For unordered features only. + */ + def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = { + require(isUnordered(featureIndex), + s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," + + s" but was called for ordered feature $featureIndex.") + val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex) + (baseOffset, baseOffset + numBins(featureIndex) * statsSize) + } + + /** + * Faster version of [[update]]. + * Update the stats for a given (node, feature, bin), using the given label. + * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset + * from [[getNodeFeatureOffset]]. + * For unordered features, this is a pre-computed + * (node, feature, left/right child) offset from + * [[getLeftRightNodeFeatureOffsets]]. + */ + def nodeFeatureUpdate(nodeFeatureOffset: Int, binIndex: Int, label: Double): Unit = { + impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label) + } + + /** + * For a given (node, feature), merge the stats for two bins. + * @param nodeFeatureOffset For ordered features, this is a pre-computed (node, feature) offset + * from [[getNodeFeatureOffset]]. + * For unordered features, this is a pre-computed + * (node, feature, left/right child) offset from + * [[getLeftRightNodeFeatureOffsets]]. + * @param binIndex The other bin is merged into this bin. + * @param otherBinIndex This bin is not modified. + */ + def mergeForNodeFeature(nodeFeatureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = { + impurityAggregator.merge(allStats, nodeFeatureOffset + binIndex * statsSize, + nodeFeatureOffset + otherBinIndex * statsSize) + } + + /** + * Merge this aggregator with another, and returns this aggregator. + * This method modifies this aggregator in-place. + */ + def merge(other: DTStatsAggregator): DTStatsAggregator = { + require(allStatsSize == other.allStatsSize, + s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors." + + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.") + var i = 0 + // TODO: Test BLAS.axpy + while (i < allStatsSize) { + allStats(i) += other.allStats(i) + i += 1 + } + this + } + +} + +private[tree] object DTStatsAggregator extends Serializable { + + /** + * Combines two aggregates (modifying the first) and returns the combination. + */ + def binCombOp( + agg1: DTStatsAggregator, + agg2: DTStatsAggregator): DTStatsAggregator = { + agg1.merge(agg2) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index d9eda354dc986..e95add7558bcf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -26,14 +26,15 @@ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impurity.Impurity import org.apache.spark.rdd.RDD - /** * Learning and dataset metadata for DecisionTree. * * @param numClasses For classification: labels can take values {0, ..., numClasses - 1}. * For regression: fixed at 0 (no meaning). + * @param maxBins Maximum number of bins, for all features. * @param featureArity Map: categorical feature index --> arity. * I.e., the feature takes values in {0, ..., arity - 1}. + * @param numBins Number of bins for each feature. */ private[tree] class DecisionTreeMetadata( val numFeatures: Int, @@ -42,6 +43,7 @@ private[tree] class DecisionTreeMetadata( val maxBins: Int, val featureArity: Map[Int, Int], val unorderedFeatures: Set[Int], + val numBins: Array[Int], val impurity: Impurity, val quantileStrategy: QuantileStrategy) extends Serializable { @@ -57,10 +59,26 @@ private[tree] class DecisionTreeMetadata( def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex) + /** + * Number of splits for the given feature. + * For unordered features, there are 2 bins per split. + * For ordered features, there is 1 more bin than split. + */ + def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) { + numBins(featureIndex) >> 1 + } else { + numBins(featureIndex) - 1 + } + } private[tree] object DecisionTreeMetadata { + /** + * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters. + * This computes which categorical features will be ordered vs. unordered, + * as well as the number of splits and bins for each feature. + */ def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = { val numFeatures = input.take(1)(0).features.size @@ -70,32 +88,55 @@ private[tree] object DecisionTreeMetadata { case Regression => 0 } - val maxBins = math.min(strategy.maxBins, numExamples).toInt - val log2MaxBinsp1 = math.log(maxBins + 1) / math.log(2.0) + val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt + + // We check the number of bins here against maxPossibleBins. + // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified + // based on the number of training examples. + if (strategy.categoricalFeaturesInfo.nonEmpty) { + val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max + require(maxCategoriesPerFeature <= maxPossibleBins, + s"DecisionTree requires maxBins (= $maxPossibleBins) >= max categories " + + s"in categorical features (= $maxCategoriesPerFeature)") + } val unorderedFeatures = new mutable.HashSet[Int]() + val numBins = Array.fill[Int](numFeatures)(maxPossibleBins) if (numClasses > 2) { - strategy.categoricalFeaturesInfo.foreach { case (f, k) => - if (k - 1 < log2MaxBinsp1) { - // Note: The above check is equivalent to checking: - // numUnorderedBins = (1 << k - 1) - 1 < maxBins - unorderedFeatures.add(f) + // Multiclass classification + val maxCategoriesForUnorderedFeature = + ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt + strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => + // Decide if some categorical features should be treated as unordered features, + // which require 2 * ((1 << numCategories - 1) - 1) bins. + // We do this check with log values to prevent overflows in case numCategories is large. + // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins + if (numCategories <= maxCategoriesForUnorderedFeature) { + unorderedFeatures.add(featureIndex) + numBins(featureIndex) = numUnorderedBins(numCategories) } else { - // TODO: Allow this case, where we simply will know nothing about some categories? - require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " + - s"in categorical features (>= $k)") + numBins(featureIndex) = numCategories } } } else { - strategy.categoricalFeaturesInfo.foreach { case (f, k) => - require(k < maxBins, s"maxBins (= $maxBins) should be greater than max categories " + - s"in categorical features (>= $k)") + // Binary classification or regression + strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => + numBins(featureIndex) = numCategories } } - new DecisionTreeMetadata(numFeatures, numExamples, numClasses, maxBins, - strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, + new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, + strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, strategy.impurity, strategy.quantileCalculationStrategy) } + /** + * Given the arity of a categorical feature (arity = number of categories), + * return the number of bins for the feature if it is to be treated as an unordered feature. + * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets; + * there are math.pow(2, arity - 1) - 1 such splits. + * Each split has 2 corresponding bins. + */ + def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1) + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala index 170e43e222083..35e361ae309cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala @@ -48,54 +48,63 @@ private[tree] object TreePoint { * binning feature values in preparation for DecisionTree training. * @param input Input dataset. * @param bins Bins for features, of size (numFeatures, numBins). - * @param metadata Learning and dataset metadata + * @param metadata Learning and dataset metadata * @return TreePoint dataset representation */ def convertToTreeRDD( input: RDD[LabeledPoint], bins: Array[Array[Bin]], metadata: DecisionTreeMetadata): RDD[TreePoint] = { + // Construct arrays for featureArity and isUnordered for efficiency in the inner loop. + val featureArity: Array[Int] = new Array[Int](metadata.numFeatures) + val isUnordered: Array[Boolean] = new Array[Boolean](metadata.numFeatures) + var featureIndex = 0 + while (featureIndex < metadata.numFeatures) { + featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0) + isUnordered(featureIndex) = metadata.isUnordered(featureIndex) + featureIndex += 1 + } input.map { x => - TreePoint.labeledPointToTreePoint(x, bins, metadata) + TreePoint.labeledPointToTreePoint(x, bins, featureArity, isUnordered) } } /** * Convert one LabeledPoint into its TreePoint representation. * @param bins Bins for features, of size (numFeatures, numBins). + * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories + * for categorical features. + * @param isUnordered Array index by feature, with value true for unordered categorical features. */ private def labeledPointToTreePoint( labeledPoint: LabeledPoint, bins: Array[Array[Bin]], - metadata: DecisionTreeMetadata): TreePoint = { - + featureArity: Array[Int], + isUnordered: Array[Boolean]): TreePoint = { val numFeatures = labeledPoint.features.size - val numBins = bins(0).size val arr = new Array[Int](numFeatures) var featureIndex = 0 while (featureIndex < numFeatures) { - arr(featureIndex) = findBin(featureIndex, labeledPoint, metadata.isContinuous(featureIndex), - metadata.isUnordered(featureIndex), bins, metadata.featureArity) + arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex), + isUnordered(featureIndex), bins) featureIndex += 1 } - new TreePoint(labeledPoint.label, arr) } /** * Find bin for one (labeledPoint, feature). * + * @param featureArity 0 for continuous features; number of categories for categorical features. * @param isUnorderedFeature (only applies if feature is categorical) * @param bins Bins for features, of size (numFeatures, numBins). - * @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity */ private def findBin( featureIndex: Int, labeledPoint: LabeledPoint, - isFeatureContinuous: Boolean, + featureArity: Int, isUnorderedFeature: Boolean, - bins: Array[Array[Bin]], - categoricalFeaturesInfo: Map[Int, Int]): Int = { + bins: Array[Array[Bin]]): Int = { /** * Binary search helper method for continuous feature. @@ -121,44 +130,7 @@ private[tree] object TreePoint { -1 } - /** - * Sequential search helper method to find bin for categorical feature in multiclass - * classification. The category is returned since each category can belong to multiple - * splits. The actual left/right child allocation per split is performed in the - * sequential phase of the bin aggregate operation. - */ - def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = { - labeledPoint.features(featureIndex).toInt - } - - /** - * Sequential search helper method to find bin for categorical feature - * (for classification and regression). - */ - def sequentialBinSearchForOrderedCategoricalFeature(): Int = { - val featureCategories = categoricalFeaturesInfo(featureIndex) - val featureValue = labeledPoint.features(featureIndex) - var binIndex = 0 - while (binIndex < featureCategories) { - val bin = bins(featureIndex)(binIndex) - val categories = bin.highSplit.categories - if (categories.contains(featureValue)) { - return binIndex - } - binIndex += 1 - } - if (featureValue < 0 || featureValue >= featureCategories) { - throw new IllegalArgumentException( - s"DecisionTree given invalid data:" + - s" Feature $featureIndex is categorical with values in" + - s" {0,...,${featureCategories - 1}," + - s" but a data point gives it value $featureValue.\n" + - " Bad data point: " + labeledPoint.toString) - } - -1 - } - - if (isFeatureContinuous) { + if (featureArity == 0) { // Perform binary search for finding bin for continuous features. val binIndex = binarySearchForBins() if (binIndex == -1) { @@ -168,18 +140,17 @@ private[tree] object TreePoint { } binIndex } else { - // Perform sequential search to find bin for categorical features. - val binIndex = if (isUnorderedFeature) { - sequentialBinSearchForUnorderedCategoricalFeatureInClassification() - } else { - sequentialBinSearchForOrderedCategoricalFeature() - } - if (binIndex == -1) { - throw new RuntimeException("No bin was found for categorical feature." + - " This error can occur when given invalid data values (such as NaN)." + - s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}") + // Categorical feature bins are indexed by feature values. + val featureValue = labeledPoint.features(featureIndex) + if (featureValue < 0 || featureValue >= featureArity) { + throw new IllegalArgumentException( + s"DecisionTree given invalid data:" + + s" Feature $featureIndex is categorical with values in" + + s" {0,...,${featureArity - 1}," + + s" but a data point gives it value $featureValue.\n" + + " Bad data point: " + labeledPoint.toString) } - binIndex + featureValue.toInt } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 96d2471e1f88c..1c8afc2d0f4bc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -74,3 +74,87 @@ object Entropy extends Impurity { def instance = this } + +/** + * Class for updating views of a vector of sufficient statistics, + * in order to compute impurity from a sample. + * Note: Instances of this class do not hold the data; they operate on views of the data. + * @param numClasses Number of classes for label. + */ +private[tree] class EntropyAggregator(numClasses: Int) + extends ImpurityAggregator(numClasses) with Serializable { + + /** + * Update stats for one (node, feature, bin) with the given label. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def update(allStats: Array[Double], offset: Int, label: Double): Unit = { + if (label >= statsSize) { + throw new IllegalArgumentException(s"EntropyAggregator given label $label" + + s" but requires label < numClasses (= $statsSize).") + } + allStats(offset + label.toInt) += 1 + } + + /** + * Get an [[ImpurityCalculator]] for a (node, feature, bin). + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = { + new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray) + } + +} + +/** + * Stores statistics for one (node, feature, bin) for calculating impurity. + * Unlike [[EntropyAggregator]], this class stores its own data and is for a specific + * (node, feature, bin). + * @param stats Array of sufficient statistics for a (node, feature, bin). + */ +private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { + + /** + * Make a deep copy of this [[ImpurityCalculator]]. + */ + def copy: EntropyCalculator = new EntropyCalculator(stats.clone()) + + /** + * Calculate the impurity from the stored sufficient statistics. + */ + def calculate(): Double = Entropy.calculate(stats, stats.sum) + + /** + * Number of data points accounted for in the sufficient statistics. + */ + def count: Long = stats.sum.toLong + + /** + * Prediction which should be made based on the sufficient statistics. + */ + def predict: Double = if (count == 0) { + 0 + } else { + indexOfLargestArrayElement(stats) + } + + /** + * Probability of the label given by [[predict]]. + */ + override def prob(label: Double): Double = { + val lbl = label.toInt + require(lbl < stats.length, + s"EntropyCalculator.prob given invalid label: $lbl (should be < ${stats.length}") + val cnt = count + if (cnt == 0) { + 0 + } else { + stats(lbl) / cnt + } + } + + override def toString: String = s"EntropyCalculator(stats = [${stats.mkString(", ")}])" + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index d586f449048bb..5cfdf345d163c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -70,3 +70,87 @@ object Gini extends Impurity { def instance = this } + +/** + * Class for updating views of a vector of sufficient statistics, + * in order to compute impurity from a sample. + * Note: Instances of this class do not hold the data; they operate on views of the data. + * @param numClasses Number of classes for label. + */ +private[tree] class GiniAggregator(numClasses: Int) + extends ImpurityAggregator(numClasses) with Serializable { + + /** + * Update stats for one (node, feature, bin) with the given label. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def update(allStats: Array[Double], offset: Int, label: Double): Unit = { + if (label >= statsSize) { + throw new IllegalArgumentException(s"GiniAggregator given label $label" + + s" but requires label < numClasses (= $statsSize).") + } + allStats(offset + label.toInt) += 1 + } + + /** + * Get an [[ImpurityCalculator]] for a (node, feature, bin). + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = { + new GiniCalculator(allStats.view(offset, offset + statsSize).toArray) + } + +} + +/** + * Stores statistics for one (node, feature, bin) for calculating impurity. + * Unlike [[GiniAggregator]], this class stores its own data and is for a specific + * (node, feature, bin). + * @param stats Array of sufficient statistics for a (node, feature, bin). + */ +private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { + + /** + * Make a deep copy of this [[ImpurityCalculator]]. + */ + def copy: GiniCalculator = new GiniCalculator(stats.clone()) + + /** + * Calculate the impurity from the stored sufficient statistics. + */ + def calculate(): Double = Gini.calculate(stats, stats.sum) + + /** + * Number of data points accounted for in the sufficient statistics. + */ + def count: Long = stats.sum.toLong + + /** + * Prediction which should be made based on the sufficient statistics. + */ + def predict: Double = if (count == 0) { + 0 + } else { + indexOfLargestArrayElement(stats) + } + + /** + * Probability of the label given by [[predict]]. + */ + override def prob(label: Double): Double = { + val lbl = label.toInt + require(lbl < stats.length, + s"GiniCalculator.prob given invalid label: $lbl (should be < ${stats.length}") + val cnt = count + if (cnt == 0) { + 0 + } else { + stats(lbl) / cnt + } + } + + override def toString: String = s"GiniCalculator(stats = [${stats.mkString(", ")}])" + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 92b0c7b4a6fbc..5a047d6cb5480 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -22,6 +22,9 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} /** * :: Experimental :: * Trait for calculating information gain. + * This trait is used for + * (a) setting the impurity parameter in [[org.apache.spark.mllib.tree.configuration.Strategy]] + * (b) calculating impurity values from sufficient statistics. */ @Experimental trait Impurity extends Serializable { @@ -47,3 +50,127 @@ trait Impurity extends Serializable { @DeveloperApi def calculate(count: Double, sum: Double, sumSquares: Double): Double } + +/** + * Interface for updating views of a vector of sufficient statistics, + * in order to compute impurity from a sample. + * Note: Instances of this class do not hold the data; they operate on views of the data. + * @param statsSize Length of the vector of sufficient statistics for one bin. + */ +private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Serializable { + + /** + * Merge the stats from one bin into another. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for (node, feature, bin) which is modified by the merge. + * @param otherOffset Start index of stats for (node, feature, other bin) which is not modified. + */ + def merge(allStats: Array[Double], offset: Int, otherOffset: Int): Unit = { + var i = 0 + while (i < statsSize) { + allStats(offset + i) += allStats(otherOffset + i) + i += 1 + } + } + + /** + * Update stats for one (node, feature, bin) with the given label. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def update(allStats: Array[Double], offset: Int, label: Double): Unit + + /** + * Get an [[ImpurityCalculator]] for a (node, feature, bin). + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def getCalculator(allStats: Array[Double], offset: Int): ImpurityCalculator + +} + +/** + * Stores statistics for one (node, feature, bin) for calculating impurity. + * Unlike [[ImpurityAggregator]], this class stores its own data and is for a specific + * (node, feature, bin). + * @param stats Array of sufficient statistics for a (node, feature, bin). + */ +private[tree] abstract class ImpurityCalculator(val stats: Array[Double]) { + + /** + * Make a deep copy of this [[ImpurityCalculator]]. + */ + def copy: ImpurityCalculator + + /** + * Calculate the impurity from the stored sufficient statistics. + */ + def calculate(): Double + + /** + * Add the stats from another calculator into this one, modifying and returning this calculator. + */ + def add(other: ImpurityCalculator): ImpurityCalculator = { + require(stats.size == other.stats.size, + s"Two ImpurityCalculator instances cannot be added with different counts sizes." + + s" Sizes are ${stats.size} and ${other.stats.size}.") + var i = 0 + while (i < other.stats.size) { + stats(i) += other.stats(i) + i += 1 + } + this + } + + /** + * Subtract the stats from another calculator from this one, modifying and returning this + * calculator. + */ + def subtract(other: ImpurityCalculator): ImpurityCalculator = { + require(stats.size == other.stats.size, + s"Two ImpurityCalculator instances cannot be subtracted with different counts sizes." + + s" Sizes are ${stats.size} and ${other.stats.size}.") + var i = 0 + while (i < other.stats.size) { + stats(i) -= other.stats(i) + i += 1 + } + this + } + + /** + * Number of data points accounted for in the sufficient statistics. + */ + def count: Long + + /** + * Prediction which should be made based on the sufficient statistics. + */ + def predict: Double + + /** + * Probability of the label given by [[predict]], or -1 if no probability is available. + */ + def prob(label: Double): Double = -1 + + /** + * Return the index of the largest array element. + * Fails if the array is empty. + */ + protected def indexOfLargestArrayElement(array: Array[Double]): Int = { + val result = array.foldLeft(-1, Double.MinValue, 0) { + case ((maxIndex, maxValue, currentIndex), currentValue) => + if (currentValue > maxValue) { + (currentIndex, currentValue, currentIndex + 1) + } else { + (maxIndex, maxValue, currentIndex + 1) + } + } + if (result._1 < 0) { + throw new RuntimeException("ImpurityCalculator internal error:" + + " indexOfLargestArrayElement failed") + } + result._1 + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index f7d99a40eb380..e9ccecb1b8067 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -61,3 +61,75 @@ object Variance extends Impurity { def instance = this } + +/** + * Class for updating views of a vector of sufficient statistics, + * in order to compute impurity from a sample. + * Note: Instances of this class do not hold the data; they operate on views of the data. + */ +private[tree] class VarianceAggregator() + extends ImpurityAggregator(statsSize = 3) with Serializable { + + /** + * Update stats for one (node, feature, bin) with the given label. + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def update(allStats: Array[Double], offset: Int, label: Double): Unit = { + allStats(offset) += 1 + allStats(offset + 1) += label + allStats(offset + 2) += label * label + } + + /** + * Get an [[ImpurityCalculator]] for a (node, feature, bin). + * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. + * @param offset Start index of stats for this (node, feature, bin). + */ + def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = { + new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray) + } + +} + +/** + * Stores statistics for one (node, feature, bin) for calculating impurity. + * Unlike [[GiniAggregator]], this class stores its own data and is for a specific + * (node, feature, bin). + * @param stats Array of sufficient statistics for a (node, feature, bin). + */ +private[tree] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { + + require(stats.size == 3, + s"VarianceCalculator requires sufficient statistics array stats to be of length 3," + + s" but was given array of length ${stats.size}.") + + /** + * Make a deep copy of this [[ImpurityCalculator]]. + */ + def copy: VarianceCalculator = new VarianceCalculator(stats.clone()) + + /** + * Calculate the impurity from the stored sufficient statistics. + */ + def calculate(): Double = Variance.calculate(stats(0), stats(1), stats(2)) + + /** + * Number of data points accounted for in the sufficient statistics. + */ + def count: Long = stats(0).toLong + + /** + * Prediction which should be made based on the sufficient statistics. + */ + def predict: Double = if (count == 0) { + 0 + } else { + stats(1) / count + } + + override def toString: String = { + s"VarianceAggregator(cnt = ${stats(0)}, sum = ${stats(1)}, sum2 = ${stats(2)})" + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala index af35d88f713e5..0cad473782af1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.mllib.tree.configuration.FeatureType._ /** - * Used for "binning" the features bins for faster best split calculation. + * Used for "binning" the feature values for faster best split calculation. * * For a continuous feature, the bin is determined by a low and a high split, * where an example with featureValue falls into the bin s.t. @@ -30,13 +30,16 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._ * bins, splits, and feature values. The bin is determined by category/feature value. * However, the bins are not necessarily ordered by feature value; * they are ordered using impurity. + * * For unordered categorical features, there is a 1-1 correspondence between bins, splits, * where bins and splits correspond to subsets of feature values (in highSplit.categories). + * An unordered feature with k categories uses (1 << k - 1) - 1 bins, corresponding to all + * partitionings of categories into 2 disjoint, non-empty sets. * * @param lowSplit signifying the lower threshold for the continuous feature to be * accepted in the bin * @param highSplit signifying the upper threshold for the continuous feature to be - * accepted in the bin + * accepted in the bin * @param featureType type of feature -- categorical or continuous * @param category categorical label value accepted in the bin for ordered features */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 0eee6262781c1..5b8a4cbed2306 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -24,8 +24,13 @@ import org.apache.spark.mllib.linalg.Vector /** * :: DeveloperApi :: - * Node in a decision tree - * @param id integer node id + * Node in a decision tree. + * + * About node indexing: + * Nodes are indexed from 1. Node 1 is the root; nodes 2, 3 are the left, right children. + * Node index 0 is not used. + * + * @param id integer node id, from 1 * @param predict predicted value at the node * @param isLeaf whether the leaf is a node * @param split split to calculate left and right nodes @@ -51,17 +56,13 @@ class Node ( * @param nodes array of nodes */ def build(nodes: Array[Node]): Unit = { - - logDebug("building node " + id + " at level " + - (scala.math.log(id + 1)/scala.math.log(2)).toInt ) + logDebug("building node " + id + " at level " + Node.indexToLevel(id)) logDebug("id = " + id + ", split = " + split) logDebug("stats = " + stats) logDebug("predict = " + predict) if (!isLeaf) { - val leftNodeIndex = id * 2 + 1 - val rightNodeIndex = id * 2 + 2 - leftNode = Some(nodes(leftNodeIndex)) - rightNode = Some(nodes(rightNodeIndex)) + leftNode = Some(nodes(Node.leftChildIndex(id))) + rightNode = Some(nodes(Node.rightChildIndex(id))) leftNode.get.build(nodes) rightNode.get.build(nodes) } @@ -96,24 +97,20 @@ class Node ( * Get the number of nodes in tree below this node, including leaf nodes. * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2. */ - private[tree] def numDescendants: Int = { - if (isLeaf) { - 0 - } else { - 2 + leftNode.get.numDescendants + rightNode.get.numDescendants - } + private[tree] def numDescendants: Int = if (isLeaf) { + 0 + } else { + 2 + leftNode.get.numDescendants + rightNode.get.numDescendants } /** * Get depth of tree from this node. * E.g.: Depth 0 means this is a leaf node. */ - private[tree] def subtreeDepth: Int = { - if (isLeaf) { - 0 - } else { - 1 + math.max(leftNode.get.subtreeDepth, rightNode.get.subtreeDepth) - } + private[tree] def subtreeDepth: Int = if (isLeaf) { + 0 + } else { + 1 + math.max(leftNode.get.subtreeDepth, rightNode.get.subtreeDepth) } /** @@ -148,3 +145,49 @@ class Node ( } } + +private[tree] object Node { + + /** + * Return the index of the left child of this node. + */ + def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1 + + /** + * Return the index of the right child of this node. + */ + def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1 + + /** + * Get the parent index of the given node, or 0 if it is the root. + */ + def parentIndex(nodeIndex: Int): Int = nodeIndex >> 1 + + /** + * Return the level of a tree which the given node is in. + */ + def indexToLevel(nodeIndex: Int): Int = if (nodeIndex == 0) { + throw new IllegalArgumentException(s"0 is not a valid node index.") + } else { + java.lang.Integer.numberOfTrailingZeros(java.lang.Integer.highestOneBit(nodeIndex)) + } + + /** + * Returns true if this is a left child. + * Note: Returns false for the root. + */ + def isLeftChild(nodeIndex: Int): Boolean = nodeIndex > 1 && nodeIndex % 2 == 0 + + /** + * Return the maximum number of nodes which can be in the given level of the tree. + * @param level Level of tree (0 = root). + */ + def maxNodesInLevel(level: Int): Int = 1 << level + + /** + * Return the index of the first node in the given level. + * @param level Level of tree (0 = root). + */ + def startIndexInLevel(level: Int): Int = 1 << level + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 2f36fd907772c..8e556c917b2e7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -21,15 +21,16 @@ import scala.collection.JavaConverters._ import org.scalatest.FunSuite +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node} -import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.LocalSparkContext -import org.apache.spark.mllib.regression.LabeledPoint + class DecisionTreeSuite extends FunSuite with LocalSparkContext { @@ -59,12 +60,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") } - test("split and bin calculation") { + test("Binary classification with continuous features: split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) val strategy = new Strategy(Classification, Gini, 3, 2, 100) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(bins.length === 2) @@ -72,7 +74,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) } - test("split and bin calculation for categorical variables") { + test("Binary classification with binary (ordered) categorical features:" + + " split and bin calculation") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -83,77 +86,20 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) assert(splits.length === 2) assert(bins.length === 2) - assert(splits(0).length === 99) - assert(bins(0).length === 100) - - // Check splits. - - assert(splits(0)(0).feature === 0) - assert(splits(0)(0).threshold === Double.MinValue) - assert(splits(0)(0).featureType === Categorical) - assert(splits(0)(0).categories.length === 1) - assert(splits(0)(0).categories.contains(1.0)) - - assert(splits(0)(1).feature === 0) - assert(splits(0)(1).threshold === Double.MinValue) - assert(splits(0)(1).featureType === Categorical) - assert(splits(0)(1).categories.length === 2) - assert(splits(0)(1).categories.contains(1.0)) - assert(splits(0)(1).categories.contains(0.0)) - - assert(splits(0)(2) === null) - - assert(splits(1)(0).feature === 1) - assert(splits(1)(0).threshold === Double.MinValue) - assert(splits(1)(0).featureType === Categorical) - assert(splits(1)(0).categories.length === 1) - assert(splits(1)(0).categories.contains(0.0)) - - assert(splits(1)(1).feature === 1) - assert(splits(1)(1).threshold === Double.MinValue) - assert(splits(1)(1).featureType === Categorical) - assert(splits(1)(1).categories.length === 2) - assert(splits(1)(1).categories.contains(1.0)) - assert(splits(1)(1).categories.contains(0.0)) - - assert(splits(1)(2) === null) - - // Check bins. - - assert(bins(0)(0).category === 1.0) - assert(bins(0)(0).lowSplit.categories.length === 0) - assert(bins(0)(0).highSplit.categories.length === 1) - assert(bins(0)(0).highSplit.categories.contains(1.0)) - - assert(bins(0)(1).category === 0.0) - assert(bins(0)(1).lowSplit.categories.length === 1) - assert(bins(0)(1).lowSplit.categories.contains(1.0)) - assert(bins(0)(1).highSplit.categories.length === 2) - assert(bins(0)(1).highSplit.categories.contains(1.0)) - assert(bins(0)(1).highSplit.categories.contains(0.0)) - - assert(bins(0)(2) === null) - - assert(bins(1)(0).category === 0.0) - assert(bins(1)(0).lowSplit.categories.length === 0) - assert(bins(1)(0).highSplit.categories.length === 1) - assert(bins(1)(0).highSplit.categories.contains(0.0)) - - assert(bins(1)(1).category === 1.0) - assert(bins(1)(1).lowSplit.categories.length === 1) - assert(bins(1)(1).lowSplit.categories.contains(0.0)) - assert(bins(1)(1).highSplit.categories.length === 2) - assert(bins(1)(1).highSplit.categories.contains(0.0)) - assert(bins(1)(1).highSplit.categories.contains(1.0)) - - assert(bins(1)(2) === null) + // no bins or splits pre-computed for ordered categorical features + assert(splits(0).length === 0) + assert(bins(0).length === 0) } - test("split and bin calculations for categorical variables with no sample for one category") { + test("Binary classification with 3-ary (ordered) categorical features," + + " with no samples for one category") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -164,104 +110,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - - // Check splits. - - assert(splits(0)(0).feature === 0) - assert(splits(0)(0).threshold === Double.MinValue) - assert(splits(0)(0).featureType === Categorical) - assert(splits(0)(0).categories.length === 1) - assert(splits(0)(0).categories.contains(1.0)) - - assert(splits(0)(1).feature === 0) - assert(splits(0)(1).threshold === Double.MinValue) - assert(splits(0)(1).featureType === Categorical) - assert(splits(0)(1).categories.length === 2) - assert(splits(0)(1).categories.contains(1.0)) - assert(splits(0)(1).categories.contains(0.0)) - - assert(splits(0)(2).feature === 0) - assert(splits(0)(2).threshold === Double.MinValue) - assert(splits(0)(2).featureType === Categorical) - assert(splits(0)(2).categories.length === 3) - assert(splits(0)(2).categories.contains(1.0)) - assert(splits(0)(2).categories.contains(0.0)) - assert(splits(0)(2).categories.contains(2.0)) - - assert(splits(0)(3) === null) - - assert(splits(1)(0).feature === 1) - assert(splits(1)(0).threshold === Double.MinValue) - assert(splits(1)(0).featureType === Categorical) - assert(splits(1)(0).categories.length === 1) - assert(splits(1)(0).categories.contains(0.0)) - - assert(splits(1)(1).feature === 1) - assert(splits(1)(1).threshold === Double.MinValue) - assert(splits(1)(1).featureType === Categorical) - assert(splits(1)(1).categories.length === 2) - assert(splits(1)(1).categories.contains(1.0)) - assert(splits(1)(1).categories.contains(0.0)) - - assert(splits(1)(2).feature === 1) - assert(splits(1)(2).threshold === Double.MinValue) - assert(splits(1)(2).featureType === Categorical) - assert(splits(1)(2).categories.length === 3) - assert(splits(1)(2).categories.contains(1.0)) - assert(splits(1)(2).categories.contains(0.0)) - assert(splits(1)(2).categories.contains(2.0)) - - assert(splits(1)(3) === null) - - // Check bins. - - assert(bins(0)(0).category === 1.0) - assert(bins(0)(0).lowSplit.categories.length === 0) - assert(bins(0)(0).highSplit.categories.length === 1) - assert(bins(0)(0).highSplit.categories.contains(1.0)) - - assert(bins(0)(1).category === 0.0) - assert(bins(0)(1).lowSplit.categories.length === 1) - assert(bins(0)(1).lowSplit.categories.contains(1.0)) - assert(bins(0)(1).highSplit.categories.length === 2) - assert(bins(0)(1).highSplit.categories.contains(1.0)) - assert(bins(0)(1).highSplit.categories.contains(0.0)) - - assert(bins(0)(2).category === 2.0) - assert(bins(0)(2).lowSplit.categories.length === 2) - assert(bins(0)(2).lowSplit.categories.contains(1.0)) - assert(bins(0)(2).lowSplit.categories.contains(0.0)) - assert(bins(0)(2).highSplit.categories.length === 3) - assert(bins(0)(2).highSplit.categories.contains(1.0)) - assert(bins(0)(2).highSplit.categories.contains(0.0)) - assert(bins(0)(2).highSplit.categories.contains(2.0)) - - assert(bins(0)(3) === null) - - assert(bins(1)(0).category === 0.0) - assert(bins(1)(0).lowSplit.categories.length === 0) - assert(bins(1)(0).highSplit.categories.length === 1) - assert(bins(1)(0).highSplit.categories.contains(0.0)) - - assert(bins(1)(1).category === 1.0) - assert(bins(1)(1).lowSplit.categories.length === 1) - assert(bins(1)(1).lowSplit.categories.contains(0.0)) - assert(bins(1)(1).highSplit.categories.length === 2) - assert(bins(1)(1).highSplit.categories.contains(0.0)) - assert(bins(1)(1).highSplit.categories.contains(1.0)) - - assert(bins(1)(2).category === 2.0) - assert(bins(1)(2).lowSplit.categories.length === 2) - assert(bins(1)(2).lowSplit.categories.contains(0.0)) - assert(bins(1)(2).lowSplit.categories.contains(1.0)) - assert(bins(1)(2).highSplit.categories.length === 3) - assert(bins(1)(2).highSplit.categories.contains(0.0)) - assert(bins(1)(2).highSplit.categories.contains(1.0)) - assert(bins(1)(2).highSplit.categories.contains(2.0)) - - assert(bins(1)(3) === null) + assert(splits.length === 2) + assert(bins.length === 2) + // no bins or splits pre-computed for ordered categorical features + assert(splits(0).length === 0) + assert(bins(0).length === 0) } test("extract categories from a number for multiclass classification") { @@ -270,8 +128,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq) } - test("split and bin calculations for unordered categorical variables with multiclass " + - "classification") { + test("Multiclass classification with unordered categorical features:" + + " split and bin calculations") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -282,8 +140,15 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(metadata.isUnordered(featureIndex = 0)) + assert(metadata.isUnordered(featureIndex = 1)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + assert(splits.length === 2) + assert(bins.length === 2) + assert(splits(0).length === 3) + assert(bins(0).length === 6) // Expecting 2^2 - 1 = 3 bins/splits assert(splits(0)(0).feature === 0) @@ -321,10 +186,6 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(1)(2).categories.contains(0.0)) assert(splits(1)(2).categories.contains(1.0)) - assert(splits(0)(3) === null) - assert(splits(1)(3) === null) - - // Check bins. assert(bins(0)(0).category === Double.MinValue) @@ -360,13 +221,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(1)(2).highSplit.categories.contains(1.0)) assert(bins(1)(2).highSplit.categories.contains(0.0)) - assert(bins(0)(3) === null) - assert(bins(1)(3) === null) - } - test("split and bin calculations for ordered categorical variables with multiclass " + - "classification") { + test("Multiclass classification with ordered categorical features: split and bin calculations") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() assert(arr.length === 3000) val rdd = sc.parallelize(arr) @@ -377,52 +234,21 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) + // 2^10 - 1 > 100, so categorical features will be ordered + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - - // 2^10 - 1 > 100, so categorical variables will be ordered - - assert(splits(0)(0).feature === 0) - assert(splits(0)(0).threshold === Double.MinValue) - assert(splits(0)(0).featureType === Categorical) - assert(splits(0)(0).categories.length === 1) - assert(splits(0)(0).categories.contains(1.0)) - - assert(splits(0)(1).feature === 0) - assert(splits(0)(1).threshold === Double.MinValue) - assert(splits(0)(1).featureType === Categorical) - assert(splits(0)(1).categories.length === 2) - assert(splits(0)(1).categories.contains(2.0)) - - assert(splits(0)(2).feature === 0) - assert(splits(0)(2).threshold === Double.MinValue) - assert(splits(0)(2).featureType === Categorical) - assert(splits(0)(2).categories.length === 3) - assert(splits(0)(2).categories.contains(2.0)) - assert(splits(0)(2).categories.contains(1.0)) - - assert(splits(0)(10) === null) - assert(splits(1)(10) === null) - - - // Check bins. - - assert(bins(0)(0).category === 1.0) - assert(bins(0)(0).lowSplit.categories.length === 0) - assert(bins(0)(0).highSplit.categories.length === 1) - assert(bins(0)(0).highSplit.categories.contains(1.0)) - assert(bins(0)(1).category === 2.0) - assert(bins(0)(1).lowSplit.categories.length === 1) - assert(bins(0)(1).highSplit.categories.length === 2) - assert(bins(0)(1).highSplit.categories.contains(1.0)) - assert(bins(0)(1).highSplit.categories.contains(2.0)) - - assert(bins(0)(10) === null) - + assert(splits.length === 2) + assert(bins.length === 2) + // no bins or splits pre-computed for ordered categorical features + assert(splits(0).length === 0) + assert(bins(0).length === 0) } - test("classification stump with all categorical variables") { + test("Binary classification stump with ordered categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -433,15 +259,23 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) + assert(splits.length === 2) + assert(bins.length === 2) + // no bins or splits pre-computed for ordered categorical features + assert(splits(0).length === 0) + assert(bins(0).length === 0) + val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, new Array[Node](0), splits, bins, 10) val split = bestSplits(0)._1 - assert(split.categories.length === 1) - assert(split.categories.contains(1.0)) + assert(split.categories === List(1.0)) assert(split.featureType === Categorical) assert(split.threshold === Double.MinValue) @@ -452,7 +286,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.impurity > 0.2) } - test("regression stump with all categorical variables") { + test("Regression stump with 3-ary (ordered) categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -462,10 +296,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, new Array[Node](0), splits, bins, 10) val split = bestSplits(0)._1 @@ -480,7 +318,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(stats.impurity > 0.2) } - test("regression stump with categorical variables of arity 2") { + test("Regression stump with binary (ordered) categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -490,6 +328,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { maxDepth = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) val model = DecisionTree.train(rdd, strategy) validateRegressor(model, arr, 0.0) @@ -497,12 +338,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(model.depth === 1) } - test("stump with fixed label 0 for Gini") { + test("Binary classification stump with fixed label 0 for Gini") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 2, 100) + val strategy = new Strategy(Classification, Gini, maxDepth = 3, + numClassesForClassification = 2, maxBins = 100) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) @@ -512,7 +357,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -521,12 +366,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.rightImpurity === 0) } - test("stump with fixed label 1 for Gini") { + test("Binary classification stump with fixed label 1 for Gini") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 2, 100) + val strategy = new Strategy(Classification, Gini, maxDepth = 3, + numClassesForClassification = 2, maxBins = 100) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) @@ -536,7 +385,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -546,12 +395,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.predict === 1) } - test("stump with fixed label 0 for Entropy") { + test("Binary classification stump with fixed label 0 for Entropy") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 2, 100) + val strategy = new Strategy(Classification, Entropy, maxDepth = 3, + numClassesForClassification = 2, maxBins = 100) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) @@ -561,7 +414,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -571,12 +424,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.predict === 0) } - test("stump with fixed label 1 for Entropy") { + test("Binary classification stump with fixed label 1 for Entropy") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 2, 100) + val strategy = new Strategy(Classification, Entropy, maxDepth = 3, + numClassesForClassification = 2, maxBins = 100) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) assert(splits.length === 2) assert(splits(0).length === 99) @@ -586,7 +443,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(0).length === 100) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) assert(bestSplits(0)._1.feature === 0) @@ -596,7 +453,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.predict === 1) } - test("second level node building with/without groups") { + test("Second level node building with vs. without groups") { val arr = DecisionTreeSuite.generateOrderedLabeledPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) @@ -613,12 +470,12 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { // Train a 1-node model val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100) val modelOneNode = DecisionTree.train(rdd, strategyOneNode) - val nodes: Array[Node] = new Array[Node](7) - nodes(0) = modelOneNode.topNode - nodes(0).leftNode = None - nodes(0).rightNode = None + val nodes: Array[Node] = new Array[Node](8) + nodes(1) = modelOneNode.topNode + nodes(1).leftNode = None + nodes(1).rightNode = None - val parentImpurities = Array(0.5, 0.5, 0.5) + val parentImpurities = Array(0, 0.5, 0.5, 0.5) // Single group second level tree construction. val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) @@ -648,16 +505,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } } - test("stump with categorical variables for multiclass classification") { + test("Multiclass classification stump with 3-ary (unordered) categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(strategy.isMulticlassClassification) + assert(metadata.isUnordered(featureIndex = 0)) + assert(metadata.isUnordered(featureIndex = 1)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) @@ -668,7 +528,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplit.featureType === Categorical) } - test("stump with 1 continuous variable for binary classification, to check off-by-1 error") { + test("Binary classification stump with 1 continuous feature, to check off-by-1 error") { val arr = new Array[LabeledPoint](4) arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0)) arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0)) @@ -684,26 +544,27 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(model.depth === 1) } - test("stump with 2 continuous variables for binary classification") { + test("Binary classification stump with 2 continuous features") { val arr = new Array[LabeledPoint](4) arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))) arr(3) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0)))) - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 2) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) assert(model.topNode.split.get.feature === 1) } - test("stump with categorical variables for multiclass classification, with just enough bins") { - val maxBins = math.pow(2, 3 - 1).toInt // just enough bins to allow unordered features + test("Multiclass classification stump with unordered categorical features," + + " with just enough bins") { + val maxBins = 2 * (math.pow(2, 3 - 1).toInt - 1) // just enough bins to allow unordered features val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, @@ -711,6 +572,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) assert(strategy.isMulticlassClassification) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(metadata.isUnordered(featureIndex = 0)) + assert(metadata.isUnordered(featureIndex = 1)) val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 1.0) @@ -719,7 +582,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) @@ -733,7 +596,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(gain.rightImpurity === 0) } - test("stump with continuous variables for multiclass classification") { + test("Multiclass classification stump with continuous features") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, @@ -746,7 +609,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) @@ -759,20 +622,21 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } - test("stump with continuous + categorical variables for multiclass classification") { + test("Multiclass classification stump with continuous + unordered categorical features") { val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3)) assert(strategy.isMulticlassClassification) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(metadata.isUnordered(featureIndex = 0)) val model = DecisionTree.train(rdd, strategy) validateClassifier(model, arr, 0.9) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) @@ -784,17 +648,19 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplit.threshold < 2020) } - test("stump with categorical variables for ordered multiclass classification") { + test("Multiclass classification stump with 10-ary (ordered) categorical features") { val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) assert(strategy.isMulticlassClassification) val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) + assert(!metadata.isUnordered(featureIndex = 0)) + assert(!metadata.isUnordered(featureIndex = 1)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), metadata, 0, + val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0, new Array[Node](0), splits, bins, 10) assert(bestSplits.length === 1) @@ -805,6 +671,18 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplit.featureType === Categorical) } + test("Multiclass classification tree with 10-ary (ordered) categorical features," + + " with just enough bins") { + val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() + val rdd = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4, + numClassesForClassification = 3, maxBins = 10, + categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) + assert(strategy.isMulticlassClassification) + + val model = DecisionTree.train(rdd, strategy) + validateClassifier(model, arr, 0.6) + } } @@ -899,5 +777,4 @@ object DecisionTreeSuite { arr } - } diff --git a/pom.xml b/pom.xml index a5eaea80afd71..d05190512f742 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ @@ -221,6 +221,18 @@ false + + + spark-staging-1030 + Spark 1.1.0 Staging (1030) + https://repository.apache.org/content/repositories/orgapachespark-1030/ + + true + + + false + + diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 034ba6a7bf50f..0f5d71afcf616 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -85,7 +85,7 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "1.0.0" + val previousSparkVersion = "1.1.0" val fullId = "spark-" + projectRef.project + "_2.10" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a2f1b3582ab71..46b78bd5c7061 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -33,6 +33,18 @@ import com.typesafe.tools.mima.core._ object MimaExcludes { def excludes(version: String) = version match { + case v if v.startsWith("1.2") => + Seq( + MimaBuild.excludeSparkPackage("deploy"), + MimaBuild.excludeSparkPackage("graphx") + ) ++ + // This is @DeveloperAPI, but Mima still gives false-positives: + MimaBuild.excludeSparkClass("scheduler.SparkListenerApplicationStart") ++ + Seq( + // This is @Experimental, but Mima still gives false-positives: + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.foreachAsync") + ) case v if v.startsWith("1.1") => Seq( MimaBuild.excludeSparkPackage("deploy"), @@ -111,6 +123,8 @@ object MimaExcludes { MimaBuild.excludeSparkClass("storage.Values") ++ MimaBuild.excludeSparkClass("storage.Entry") ++ MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++ + // Class was missing "@DeveloperApi" annotation in 1.0. + MimaBuild.excludeSparkClass("scheduler.SparkListenerApplicationStart") ++ Seq( ProblemFilters.exclude[IncompatibleMethTypeProblem]( "org.apache.spark.mllib.tree.impurity.Gini.calculate"), @@ -119,14 +133,14 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleMethTypeProblem]( "org.apache.spark.mllib.tree.impurity.Variance.calculate") ) ++ - Seq ( // Package-private classes removed in SPARK-2341 + Seq( // Package-private classes removed in SPARK-2341 ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.BinaryLabelParser$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.LabelParser$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.util.MulticlassLabelParser$") - ) ++ + ) ++ Seq( // package-private classes removed in MLlib ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.regression.GeneralizedLinearAlgorithm.org$apache$spark$mllib$regression$GeneralizedLinearAlgorithm$$prependOne") diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 4c696d3d385fb..45f6d2973ea90 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -184,7 +184,7 @@ object OldDeps { def versionArtifact(id: String): Option[sbt.ModuleID] = { val fullId = id + "_2.10" - Some("org.apache.spark" % fullId % "1.0.0") + Some("org.apache.spark" % fullId % "1.1.0") } def oldDepsSettings() = Defaults.defaultSettings ++ Seq( @@ -290,9 +290,9 @@ object Unidoc { publish := {}, unidocProjectFilter in(ScalaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, catalyst, yarn, yarnAlpha), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, catalyst, streamingFlumeSink, yarn, yarnAlpha), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, bagel, graphx, examples, tools, catalyst, yarn, yarnAlpha), + inAnyProject -- inProjects(OldDeps.project, repl, bagel, graphx, examples, tools, catalyst, streamingFlumeSink, yarn, yarnAlpha), // Skip class names containing $ and some internal packages in Javadocs unidocAllSources in (JavaUnidoc, unidoc) := { @@ -314,7 +314,7 @@ object Unidoc { "-group", "Core Java API", packageList("api.java", "api.java.function"), "-group", "Spark Streaming", packageList( "streaming.api.java", "streaming.flume", "streaming.kafka", - "streaming.mqtt", "streaming.twitter", "streaming.zeromq" + "streaming.mqtt", "streaming.twitter", "streaming.zeromq", "streaming.kinesis" ), "-group", "MLlib", packageList( "mllib.classification", "mllib.clustering", "mllib.evaluation.binary", "mllib.linalg", diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index c58555fc9d2c5..1a2e774738fe7 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -61,13 +61,17 @@ from pyspark.conf import SparkConf from pyspark.context import SparkContext -from pyspark.sql import SQLContext from pyspark.rdd import RDD -from pyspark.sql import SchemaRDD -from pyspark.sql import Row from pyspark.files import SparkFiles from pyspark.storagelevel import StorageLevel +from pyspark.accumulators import Accumulator, AccumulatorParam +from pyspark.broadcast import Broadcast +from pyspark.serializers import MarshalSerializer, PickleSerializer +# for back compatibility +from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row -__all__ = ["SparkConf", "SparkContext", "SQLContext", "RDD", "SchemaRDD", - "SparkFiles", "StorageLevel", "Row"] +__all__ = [ + "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast", + "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer", +] diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index f133cf6f7befc..ccbca67656c8d 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -94,6 +94,9 @@ from pyspark.serializers import read_int, PickleSerializer +__all__ = ['Accumulator', 'AccumulatorParam'] + + pickleSer = PickleSerializer() # Holds accumulators registered on the current machine, keyed by ID. This is then used to send diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 675a2fcd2ff4e..5c7c9cc161dff 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -31,6 +31,10 @@ from pyspark.serializers import CompressedSerializer, PickleSerializer + +__all__ = ['Broadcast'] + + # Holds broadcasted data received from Java, keyed by its id. _broadcastRegistry = {} @@ -59,11 +63,20 @@ def __init__(self, bid, value, java_broadcast=None, """ self.bid = bid if path is None: - self.value = value + self._value = value self._jbroadcast = java_broadcast self._pickle_registry = pickle_registry self.path = path + @property + def value(self): + """ Return the broadcasted value + """ + if not hasattr(self, "_value") and self.path is not None: + ser = CompressedSerializer(PickleSerializer()) + self._value = ser.load_stream(open(self.path)).next() + return self._value + def unpersist(self, blocking=False): self._jbroadcast.unpersist(blocking) os.unlink(self.path) @@ -72,15 +85,6 @@ def __reduce__(self): self._pickle_registry.add(self) return (_from_id, (self.bid, )) - def __getattr__(self, item): - if item == 'value' and self.path is not None: - ser = CompressedSerializer(PickleSerializer()) - value = ser.load_stream(open(self.path)).next() - self.value = value - return value - - raise AttributeError(item) - if __name__ == "__main__": import doctest diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 68062483dedaa..80e51d1a583a0 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -657,7 +657,6 @@ def save_partial(self, obj): def save_file(self, obj): """Save a file""" import StringIO as pystringIO #we can't use cStringIO as it lacks the name attribute - from ..transport.adapter import SerializingAdapter if not hasattr(obj, 'name') or not hasattr(obj, 'mode'): raise pickle.PicklingError("Cannot pickle files that do not map to an actual file") @@ -691,13 +690,10 @@ def save_file(self, obj): tmpfile.close() if tst != '': raise pickle.PicklingError("Cannot pickle file %s as it does not appear to map to a physical, real file" % name) - elif fsize > SerializingAdapter.max_transmit_data: - raise pickle.PicklingError("Cannot pickle file %s as it exceeds cloudconf.py's max_transmit_data of %d" % - (name,SerializingAdapter.max_transmit_data)) else: try: tmpfile = file(name) - contents = tmpfile.read(SerializingAdapter.max_transmit_data) + contents = tmpfile.read() tmpfile.close() except IOError: raise pickle.PicklingError("Cannot pickle file %s as it cannot be read" % name) diff --git a/python/pyspark/conf.py b/python/pyspark/conf.py index fb716f6753a45..b64875a3f495a 100644 --- a/python/pyspark/conf.py +++ b/python/pyspark/conf.py @@ -54,6 +54,8 @@ (u'spark.executorEnv.VAR4', u'value4'), (u'spark.home', u'/path')] """ +__all__ = ['SparkConf'] + class SparkConf(object): diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 6e4fdaa6eec9d..5a30431568b16 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -37,6 +37,9 @@ from py4j.java_collections import ListConverter +__all__ = ['SparkContext'] + + # These are special default configs for PySpark, they will overwrite # the default ones for Spark if they are not configured by user. DEFAULT_CONFIGS = { diff --git a/python/pyspark/files.py b/python/pyspark/files.py index 331de9a9b2212..797573f49dac8 100644 --- a/python/pyspark/files.py +++ b/python/pyspark/files.py @@ -18,6 +18,9 @@ import os +__all__ = ['SparkFiles'] + + class SparkFiles(object): """ diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index ffdda7ee19302..71ab46b61d7fa 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -30,6 +30,10 @@ from math import exp, log +__all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'SVMModel', + 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes'] + + class LogisticRegressionModel(LinearModel): """A linear binary classification model derived from logistic regression. diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index a0630d1d5c58b..f3e952a1d842a 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -25,6 +25,8 @@ _get_initial_weights, _serialize_rating, _regression_train_wrapper from pyspark.mllib.linalg import SparseVector +__all__ = ['KMeansModel', 'KMeans'] + class KMeansModel(object): diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index f485a69db1fa2..e69051c104e37 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -27,6 +27,9 @@ from numpy import array, array_equal, ndarray, float64, int32 +__all__ = ['SparseVector', 'Vectors'] + + class SparseVector(object): """ diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 4dc1a4a912421..3e59c73db85e3 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -25,6 +25,9 @@ from pyspark.serializers import NoOpSerializer +__all__ = ['RandomRDDs', ] + + class RandomRDDs: """ Generator methods for creating RDDs comprised of i.i.d samples from diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index e863fc249ec36..2df23394da6f8 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -24,6 +24,8 @@ _serialize_tuple, RatingDeserializer from pyspark.rdd import RDD +__all__ = ['MatrixFactorizationModel', 'ALS'] + class MatrixFactorizationModel(object): diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index d8792cf44872f..f572dcfb840b6 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -17,15 +17,15 @@ from numpy import array, ndarray from pyspark import SparkContext -from pyspark.mllib._common import \ - _dot, _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \ - _serialize_double_matrix, _deserialize_double_matrix, \ - _serialize_double_vector, _deserialize_double_vector, \ - _get_initial_weights, _serialize_rating, _regression_train_wrapper, \ +from pyspark.mllib._common import _dot, _regression_train_wrapper, \ _linear_predictor_typecheck, _have_scipy, _scipy_issparse from pyspark.mllib.linalg import SparseVector, Vectors +__all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel' + 'LinearRegressionWithSGD', 'LassoWithSGD', 'RidgeRegressionWithSGD'] + + class LabeledPoint(object): """ diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py index feef0d16cd644..8c726f171c978 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat.py @@ -21,8 +21,10 @@ from pyspark.mllib._common import \ _get_unmangled_double_vector_rdd, _get_unmangled_rdd, \ - _serialize_double, _serialize_double_vector, \ - _deserialize_double, _deserialize_double_matrix, _deserialize_double_vector + _serialize_double, _deserialize_double_matrix, _deserialize_double_vector + + +__all__ = ['MultivariateStatisticalSummary', 'Statistics'] class MultivariateStatisticalSummary(object): diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index e9d778df5a24b..a2fade61e9a71 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -26,6 +26,9 @@ from pyspark.serializers import NoOpSerializer +__all__ = ['DecisionTreeModel', 'DecisionTree'] + + class DecisionTreeModel(object): """ @@ -88,6 +91,7 @@ class DecisionTree(object): It will probably be modified for Spark v1.2. Example usage: + >>> from numpy import array >>> import sys >>> from pyspark.mllib.regression import LabeledPoint diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 2d80fad796957..5667154cb84a8 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -48,6 +48,7 @@ from py4j.java_collections import ListConverter, MapConverter + __all__ = ["RDD"] @@ -62,7 +63,7 @@ def portable_hash(x): >>> portable_hash(None) 0 - >>> portable_hash((None, 1)) + >>> portable_hash((None, 1)) & 0xffffffff 219750521 """ if x is None: @@ -72,7 +73,7 @@ def portable_hash(x): for i in x: h ^= portable_hash(i) h *= 1000003 - h &= 0xffffffff + h &= sys.maxint h ^= len(x) if h == -1: h = -2 @@ -211,11 +212,16 @@ def cache(self): self.persist(StorageLevel.MEMORY_ONLY_SER) return self - def persist(self, storageLevel): + def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): """ Set this RDD's storage level to persist its values across operations after the first time it is computed. This can only be used to assign a new storage level if the RDD does not have a storage level set yet. + If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). + + >>> rdd = sc.parallelize(["b", "a", "c"]) + >>> rdd.persist().is_cached + True """ self.is_cached = True javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel) @@ -514,6 +520,30 @@ def __add__(self, other): raise TypeError return self.union(other) + def repartitionAndSortWithinPartitions(self, numPartitions=None, partitionFunc=portable_hash, + ascending=True, keyfunc=lambda x: x): + """ + Repartition the RDD according to the given partitioner and, within each resulting partition, + sort records by their keys. + + >>> rdd = sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)]) + >>> rdd2 = rdd.repartitionAndSortWithinPartitions(2, lambda x: x % 2, 2) + >>> rdd2.glom().collect() + [[(0, 5), (0, 8), (2, 6)], [(1, 3), (3, 8), (3, 8)]] + """ + if numPartitions is None: + numPartitions = self._defaultReducePartitions() + + spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == "true") + memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) + serializer = self._jrdd_deserializer + + def sortPartition(iterator): + sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted + return iter(sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending))) + + return self.partitionBy(numPartitions, partitionFunc).mapPartitions(sortPartition, True) + def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x): """ Sorts this RDD, which is assumed to consist of (key, value) pairs. @@ -1088,11 +1118,11 @@ def take(self, num): # we actually cap it at totalParts in runJob. numPartsToTry = 1 if partsScanned > 0: - # If we didn't find any rows after the first iteration, just - # try all partitions next. Otherwise, interpolate the number - # of partitions we need to try, but overestimate it by 50%. + # If we didn't find any rows after the previous iteration, + # quadruple and retry. Otherwise, interpolate the number of + # partitions we need to try, but overestimate it by 50%. if len(items) == 0: - numPartsToTry = totalParts - 1 + numPartsToTry = partsScanned * 4 else: numPartsToTry = int(1.5 * num * partsScanned / len(items)) @@ -1942,7 +1972,7 @@ def _is_pickled(self): return True return False - def _to_jrdd(self): + def _to_java_object_rdd(self): """ Return an JavaRDD of Object by unpickling It will convert each Python object into Java object by Pyrolite, whenever the @@ -1977,7 +2007,7 @@ def sumApprox(self, timeout, confidence=0.95): >>> (rdd.sumApprox(1000) - r) / r < 0.05 True """ - jrdd = self.mapPartitions(lambda it: [float(sum(it))])._to_jrdd() + jrdd = self.mapPartitions(lambda it: [float(sum(it))])._to_java_object_rdd() jdrdd = self.ctx._jvm.JavaDoubleRDD.fromRDD(jrdd.rdd()) r = jdrdd.sumApprox(timeout, confidence).getFinalValue() return BoundedFloat(r.mean(), r.confidence(), r.low(), r.high()) @@ -1993,11 +2023,40 @@ def meanApprox(self, timeout, confidence=0.95): >>> (rdd.meanApprox(1000) - r) / r < 0.05 True """ - jrdd = self.map(float)._to_jrdd() + jrdd = self.map(float)._to_java_object_rdd() jdrdd = self.ctx._jvm.JavaDoubleRDD.fromRDD(jrdd.rdd()) r = jdrdd.meanApprox(timeout, confidence).getFinalValue() return BoundedFloat(r.mean(), r.confidence(), r.low(), r.high()) + def countApproxDistinct(self, relativeSD=0.05): + """ + :: Experimental :: + Return approximate number of distinct elements in the RDD. + + The algorithm used is based on streamlib's implementation of + "HyperLogLog in Practice: Algorithmic Engineering of a State + of The Art Cardinality Estimation Algorithm", available + here. + + @param relativeSD Relative accuracy. Smaller values create + counters that require more space. + It must be greater than 0.000017. + + >>> n = sc.parallelize(range(1000)).map(str).countApproxDistinct() + >>> 950 < n < 1050 + True + >>> n = sc.parallelize([i % 20 for i in range(1000)]).countApproxDistinct() + >>> 18 < n < 22 + True + """ + if relativeSD < 0.000017: + raise ValueError("relativeSD should be greater than 0.000017") + if relativeSD > 0.37: + raise ValueError("relativeSD should be smaller than 0.37") + # the hash space in Java is 2^32 + hashRDD = self.map(lambda x: portable_hash(x) & 0xFFFFFFFF) + return hashRDD._to_java_object_rdd().countApproxDistinct(relativeSD) + class PipelinedRDD(RDD): @@ -2040,6 +2099,7 @@ def pipeline_func(split, iterator): self.ctx = prev.ctx self.prev = prev self._jrdd_val = None + self._id = None self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None @@ -2070,6 +2130,11 @@ def _jrdd(self): self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val + def id(self): + if self._id is None: + self._id = self._jrdd.id() + return self._id + def _is_pipelinable(self): return not (self.is_cached or self.is_checkpointed) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index fc49aa42dbaf9..55e6cf3308611 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -409,7 +409,7 @@ def loads(self, obj): class CompressedSerializer(FramedSerializer): """ - compress the serialized data + Compress the serialized data """ def __init__(self, serializer): diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index e1e7cd954189f..89cf76920e353 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -28,6 +28,7 @@ sys.exit(1) +import atexit import os import platform import pyspark @@ -42,14 +43,15 @@ SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) sc = SparkContext(appName="PySparkShell", pyFiles=add_files) +atexit.register(lambda: sc.stop()) print("""Welcome to ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ - /__ / .__/\_,_/_/ /_/\_\ version 1.0.0-SNAPSHOT + /__ / .__/\_,_/_/ /_/\_\ version %s /_/ -""") +""" % sc.version) print("Using Python version %s (%s, %s)" % ( platform.python_version(), platform.python_build()[0], diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 0ff6a548a85f1..004d4937cbe1c 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -29,6 +29,7 @@ from pyspark.rdd import RDD, PipelinedRDD from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer +from pyspark.storagelevel import StorageLevel from itertools import chain, ifilter, imap @@ -40,8 +41,7 @@ "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType", - "SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", - "SchemaRDD", "Row"] + "SQLContext", "HiveContext", "SchemaRDD", "Row"] class DataType(object): @@ -901,7 +901,7 @@ def __reduce__(self): class SQLContext: - """Main entry point for SparkSQL functionality. + """Main entry point for Spark SQL functionality. A SQLContext can be used create L{SchemaRDD}s, register L{SchemaRDD}s as tables, execute SQL over tables, cache tables, and read parquet files. @@ -943,18 +943,16 @@ def __init__(self, sparkContext, sqlContext=None): self._jsc = self._sc._jsc self._jvm = self._sc._jvm self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray - - if sqlContext: - self._scala_SQLContext = sqlContext + self._scala_SQLContext = sqlContext @property def _ssql_ctx(self): - """Accessor for the JVM SparkSQL context. + """Accessor for the JVM Spark SQL context. Subclasses can override this property to provide their own JVM Contexts. """ - if not hasattr(self, '_scala_SQLContext'): + if self._scala_SQLContext is None: self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) return self._scala_SQLContext @@ -971,23 +969,26 @@ def registerFunction(self, name, f, returnType=StringType()): >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() [Row(c0=4)] - >>> sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) - >>> sqlCtx.sql("SELECT twoArgs('test', 1)").collect() - [Row(c0=5)] """ func = lambda _, it: imap(lambda x: f(*x), it) command = (func, BatchedSerializer(PickleSerializer(), 1024), BatchedSerializer(PickleSerializer(), 1024)) + pickled_command = CloudPickleSerializer().dumps(command) + broadcast_vars = ListConverter().convert( + [x._jbroadcast for x in self._sc._pickled_broadcast_vars], + self._sc._gateway._gateway_client) + self._sc._pickled_broadcast_vars.clear() env = MapConverter().convert(self._sc.environment, self._sc._gateway._gateway_client) includes = ListConverter().convert(self._sc._python_includes, self._sc._gateway._gateway_client) self._ssql_ctx.registerPython(name, - bytearray(CloudPickleSerializer().dumps(command)), + bytearray(pickled_command), env, includes, self._sc.pythonExec, + broadcast_vars, self._sc._javaAccumulator, str(returnType)) @@ -1037,7 +1038,7 @@ def inferSchema(self, rdd): "can not infer schema") if type(first) is dict: warnings.warn("Using RDD of dict to inferSchema is deprecated," - "please use pyspark.Row instead") + "please use pyspark.sql.Row instead") schema = _infer_schema(first) rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema)) @@ -1487,12 +1488,27 @@ def __repr__(self): return "" % ", ".join(self) +def inherit_doc(cls): + for name, func in vars(cls).items(): + # only inherit docstring for public functions + if name.startswith("_"): + continue + if not func.__doc__: + for parent in cls.__bases__: + parent_func = getattr(parent, name, None) + if parent_func and getattr(parent_func, "__doc__", None): + func.__doc__ = parent_func.__doc__ + break + return cls + + +@inherit_doc class SchemaRDD(RDD): """An RDD of L{Row} objects that has an associated schema. The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can - utilize the relational query api exposed by SparkSQL. + utilize the relational query api exposed by Spark SQL. For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the L{SchemaRDD} is not operated on directly, as it's underlying @@ -1509,7 +1525,7 @@ def __init__(self, jschema_rdd, sql_ctx): self.sql_ctx = sql_ctx self._sc = sql_ctx._sc self._jschema_rdd = jschema_rdd - + self._id = None self.is_cached = False self.is_checkpointed = False self.ctx = self.sql_ctx._sc @@ -1527,9 +1543,10 @@ def _jrdd(self): self._lazy_jrdd = self._jschema_rdd.javaToPython() return self._lazy_jrdd - @property - def _id(self): - return self._jrdd.id() + def id(self): + if self._id is None: + self._id = self._jrdd.id() + return self._id def saveAsParquetFile(self, path): """Save the contents as a Parquet file, preserving the schema. @@ -1563,6 +1580,7 @@ def registerTempTable(self, name): self._jschema_rdd.registerTempTable(name) def registerAsTable(self, name): + """DEPRECATED: use registerTempTable() instead""" warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) self.registerTempTable(name) @@ -1649,7 +1667,7 @@ def cache(self): self._jschema_rdd.cache() return self - def persist(self, storageLevel): + def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): self.is_cached = True javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel) self._jschema_rdd.persist(javaStorageLevel) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 3e7040eade1ab..0bd2a9e6c507d 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -43,6 +43,7 @@ from pyspark.files import SparkFiles from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter +from pyspark.sql import SQLContext, IntegerType _have_scipy = False _have_numpy = False @@ -168,6 +169,17 @@ def test_namedtuple(self): self.assertEquals(p1, p2) +# Regression test for SPARK-3415 +class CloudPickleTest(unittest.TestCase): + def test_pickling_file_handles(self): + from pyspark.cloudpickle import dumps + from StringIO import StringIO + from pickle import load + out1 = sys.stderr + out2 = load(StringIO(dumps(out1))) + self.assertEquals(out1, out2) + + class PySparkTestCase(unittest.TestCase): def setUp(self): @@ -280,6 +292,15 @@ def func(): class TestRDDFunctions(PySparkTestCase): + def test_id(self): + rdd = self.sc.parallelize(range(10)) + id = rdd.id() + self.assertEqual(id, rdd.id()) + rdd2 = rdd.map(str).filter(bool) + id2 = rdd2.id() + self.assertEqual(id + 1, id2) + self.assertEqual(id2, rdd2.id()) + def test_failed_sparkcontext_creation(self): # Regression test for SPARK-1550 self.sc.stop() @@ -404,6 +425,22 @@ def test_zip_with_different_number_of_items(self): self.assertEquals(a.count(), b.count()) self.assertRaises(Exception, lambda: a.zip(b).count()) + def test_count_approx_distinct(self): + rdd = self.sc.parallelize(range(1000)) + self.assertTrue(950 < rdd.countApproxDistinct(0.04) < 1050) + self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.04) < 1050) + self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.04) < 1050) + self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.04) < 1050) + + rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7) + self.assertTrue(18 < rdd.countApproxDistinct() < 22) + self.assertTrue(18 < rdd.map(float).countApproxDistinct() < 22) + self.assertTrue(18 < rdd.map(str).countApproxDistinct() < 22) + self.assertTrue(18 < rdd.map(lambda x: (x, -x)).countApproxDistinct() < 22) + + self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.00000001)) + self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.5)) + def test_histogram(self): # empty rdd = self.sc.parallelize([]) @@ -508,6 +545,35 @@ def test_histogram(self): self.assertEquals(([1, "b"], [5]), rdd.histogram(1)) self.assertRaises(TypeError, lambda: rdd.histogram(2)) + def test_repartitionAndSortWithinPartitions(self): + rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2) + + repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2) + partitions = repartitioned.glom().collect() + self.assertEquals(partitions[0], [(0, 5), (0, 8), (2, 6)]) + self.assertEquals(partitions[1], [(1, 3), (3, 8), (3, 8)]) + + +class TestSQL(PySparkTestCase): + + def setUp(self): + PySparkTestCase.setUp(self) + self.sqlCtx = SQLContext(self.sc) + + def test_udf(self): + self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) + [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + def test_broadcast_in_udf(self): + bar = {"a": "aa", "b": "bb", "c": "abc"} + foo = self.sc.broadcast(bar) + self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '') + [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect() + self.assertEqual("abc", res[0]) + [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect() + self.assertEqual("", res[0]) + class TestIO(PySparkTestCase): diff --git a/python/run-tests b/python/run-tests index 7b1ee3e1cddba..d98840de59d2c 100755 --- a/python/run-tests +++ b/python/run-tests @@ -19,7 +19,7 @@ # Figure out where the Spark framework is installed -FWDIR="$(cd `dirname $0`; cd ../; pwd)" +FWDIR="$(cd "`dirname "$0"`"; cd ../; pwd)" # CD into the python directory to find things on the right path cd "$FWDIR/python" @@ -28,12 +28,14 @@ FAILED=0 rm -f unit-tests.log -# Remove the metastore and warehouse directory created by the HiveContext tests in SparkSQL +# Remove the metastore and warehouse directory created by the HiveContext tests in Spark SQL rm -rf metastore warehouse function run_test() { echo "Running test: $1" - SPARK_TESTING=1 $FWDIR/bin/pyspark $1 2>&1 | tee -a > unit-tests.log + + SPARK_TESTING=1 "$FWDIR"/bin/pyspark $1 2>&1 | tee -a unit-tests.log + FAILED=$((PIPESTATUS[0]||$FAILED)) # Fail and exit on the first test failure. @@ -48,6 +50,8 @@ function run_test() { echo "Running PySpark tests. Output is in python/unit-tests.log." +export PYSPARK_PYTHON="python" + # Try to test with Python 2.6, since that's the minimum version that we support: if [ $(which python2.6) ]; then export PYSPARK_PYTHON="python2.6" diff --git a/repl/pom.xml b/repl/pom.xml index 68f4504450778..fcc5f90d870e8 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index 910b31d209e13..7667a9c11979e 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -14,6 +14,8 @@ import scala.reflect.internal.util.Position import scala.util.control.Exception.ignoring import scala.tools.nsc.util.stackTraceString +import org.apache.spark.SPARK_VERSION + /** * Machinery for the asynchronous initialization of the repl. */ @@ -26,9 +28,9 @@ trait SparkILoopInit { ____ __ / __/__ ___ _____/ /__ _\ \/ _ \/ _ `/ __/ '_/ - /___/ .__/\_,_/_/ /_/\_\ version 1.0.0-SNAPSHOT + /___/ .__/\_,_/_/ /_/\_\ version %s /_/ -""") +""".format(SPARK_VERSION)) import Properties._ val welcomeMsg = "Using Scala %s (%s, Java %s)".format( versionString, javaVmName, javaVersion) diff --git a/sbin/slaves.sh b/sbin/slaves.sh index f89547fef9e46..1d4dc5edf9858 100755 --- a/sbin/slaves.sh +++ b/sbin/slaves.sh @@ -36,29 +36,29 @@ if [ $# -le 0 ]; then exit 1 fi -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" . "$sbin/spark-config.sh" # If the slaves file is specified in the command line, # then it takes precedence over the definition in # spark-env.sh. Save it here. -HOSTLIST=$SPARK_SLAVES +HOSTLIST="$SPARK_SLAVES" # Check if --config is passed as an argument. It is an optional parameter. # Exit if the argument is not a directory. if [ "$1" == "--config" ] then shift - conf_dir=$1 + conf_dir="$1" if [ ! -d "$conf_dir" ] then echo "ERROR : $conf_dir is not a directory" echo $usage exit 1 else - export SPARK_CONF_DIR=$conf_dir + export SPARK_CONF_DIR="$conf_dir" fi shift fi @@ -79,7 +79,7 @@ if [ "$SPARK_SSH_OPTS" = "" ]; then fi for slave in `cat "$HOSTLIST"|sed "s/#.*$//;/^$/d"`; do - ssh $SPARK_SSH_OPTS $slave $"${@// /\\ }" \ + ssh $SPARK_SSH_OPTS "$slave" $"${@// /\\ }" \ 2>&1 | sed "s/^/$slave: /" & if [ "$SPARK_SLAVE_SLEEP" != "" ]; then sleep $SPARK_SLAVE_SLEEP diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index 5c87da5815b64..2718d6cba1c9a 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -21,19 +21,19 @@ # resolve links - $0 may be a softlink this="${BASH_SOURCE-$0}" -common_bin=$(cd -P -- "$(dirname -- "$this")" && pwd -P) +common_bin="$(cd -P -- "$(dirname -- "$this")" && pwd -P)" script="$(basename -- "$this")" this="$common_bin/$script" # convert relative path to absolute path -config_bin=`dirname "$this"` -script=`basename "$this"` -config_bin=`cd "$config_bin"; pwd` +config_bin="`dirname "$this"`" +script="`basename "$this"`" +config_bin="`cd "$config_bin"; pwd`" this="$config_bin/$script" -export SPARK_PREFIX=`dirname "$this"`/.. -export SPARK_HOME=${SPARK_PREFIX} +export SPARK_PREFIX="`dirname "$this"`"/.. +export SPARK_HOME="${SPARK_PREFIX}" export SPARK_CONF_DIR="$SPARK_HOME/conf" # Add the PySpark classes to the PYTHONPATH: -export PYTHONPATH=$SPARK_HOME/python:$PYTHONPATH -export PYTHONPATH=$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH +export PYTHONPATH="$SPARK_HOME/python:$PYTHONPATH" +export PYTHONPATH="$SPARK_HOME/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh index 9032f23ea8eff..bd476b400e1c3 100755 --- a/sbin/spark-daemon.sh +++ b/sbin/spark-daemon.sh @@ -37,8 +37,8 @@ if [ $# -le 1 ]; then exit 1 fi -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" . "$sbin/spark-config.sh" @@ -50,14 +50,14 @@ sbin=`cd "$sbin"; pwd` if [ "$1" == "--config" ] then shift - conf_dir=$1 + conf_dir="$1" if [ ! -d "$conf_dir" ] then echo "ERROR : $conf_dir is not a directory" echo $usage exit 1 else - export SPARK_CONF_DIR=$conf_dir + export SPARK_CONF_DIR="$conf_dir" fi shift fi @@ -100,12 +100,12 @@ if [ "$SPARK_LOG_DIR" = "" ]; then export SPARK_LOG_DIR="$SPARK_HOME/logs" fi mkdir -p "$SPARK_LOG_DIR" -touch $SPARK_LOG_DIR/.spark_test > /dev/null 2>&1 +touch "$SPARK_LOG_DIR"/.spark_test > /dev/null 2>&1 TEST_LOG_DIR=$? if [ "${TEST_LOG_DIR}" = "0" ]; then - rm -f $SPARK_LOG_DIR/.spark_test + rm -f "$SPARK_LOG_DIR"/.spark_test else - chown $SPARK_IDENT_STRING $SPARK_LOG_DIR + chown "$SPARK_IDENT_STRING" "$SPARK_LOG_DIR" fi if [ "$SPARK_PID_DIR" = "" ]; then @@ -113,8 +113,8 @@ if [ "$SPARK_PID_DIR" = "" ]; then fi # some variables -log=$SPARK_LOG_DIR/spark-$SPARK_IDENT_STRING-$command-$instance-$HOSTNAME.out -pid=$SPARK_PID_DIR/spark-$SPARK_IDENT_STRING-$command-$instance.pid +log="$SPARK_LOG_DIR/spark-$SPARK_IDENT_STRING-$command-$instance-$HOSTNAME.out" +pid="$SPARK_PID_DIR/spark-$SPARK_IDENT_STRING-$command-$instance.pid" # Set default scheduling priority if [ "$SPARK_NICENESS" = "" ]; then @@ -136,7 +136,7 @@ case $startStop in fi if [ "$SPARK_MASTER" != "" ]; then - echo rsync from $SPARK_MASTER + echo rsync from "$SPARK_MASTER" rsync -a -e ssh --delete --exclude=.svn --exclude='logs/*' --exclude='contrib/hod/logs/*' $SPARK_MASTER/ "$SPARK_HOME" fi diff --git a/sbin/spark-executor b/sbin/spark-executor index 3621321a9bc8d..674ce906d9421 100755 --- a/sbin/spark-executor +++ b/sbin/spark-executor @@ -17,10 +17,10 @@ # limitations under the License. # -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" -export PYTHONPATH=$FWDIR/python:$PYTHONPATH -export PYTHONPATH=$FWDIR/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH +export PYTHONPATH="$FWDIR/python:$PYTHONPATH" +export PYTHONPATH="$FWDIR/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH" echo "Running spark-executor with framework dir = $FWDIR" -exec $FWDIR/bin/spark-class org.apache.spark.executor.MesosExecutorBackend +exec "$FWDIR"/bin/spark-class org.apache.spark.executor.MesosExecutorBackend diff --git a/sbin/start-all.sh b/sbin/start-all.sh index 5c89ab4d86b3a..1baf57cea09ee 100755 --- a/sbin/start-all.sh +++ b/sbin/start-all.sh @@ -21,8 +21,8 @@ # Starts the master on this node. # Starts a worker on each node specified in conf/slaves -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" TACHYON_STR="" diff --git a/sbin/start-history-server.sh b/sbin/start-history-server.sh index 580ab471b8a79..7172ad15d88fc 100755 --- a/sbin/start-history-server.sh +++ b/sbin/start-history-server.sh @@ -24,8 +24,8 @@ # Use the SPARK_HISTORY_OPTS environment variable to set history server configuration. # -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" . "$sbin/spark-config.sh" . "$SPARK_PREFIX/bin/load-spark-env.sh" diff --git a/sbin/start-master.sh b/sbin/start-master.sh index c5c02491f78e1..17fff58f4f768 100755 --- a/sbin/start-master.sh +++ b/sbin/start-master.sh @@ -19,8 +19,8 @@ # Starts the master on the machine this script is executed on. -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" START_TACHYON=false diff --git a/sbin/start-slave.sh b/sbin/start-slave.sh index b563400dc24f3..2fc35309f4ca5 100755 --- a/sbin/start-slave.sh +++ b/sbin/start-slave.sh @@ -20,7 +20,7 @@ # Usage: start-slave.sh # where is like "spark://localhost:7077" -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" "$sbin"/spark-daemon.sh start org.apache.spark.deploy.worker.Worker "$@" diff --git a/sbin/start-slaves.sh b/sbin/start-slaves.sh index 4912d0c0c7dfd..ba1a84abc1fef 100755 --- a/sbin/start-slaves.sh +++ b/sbin/start-slaves.sh @@ -17,8 +17,8 @@ # limitations under the License. # -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" START_TACHYON=false @@ -46,11 +46,11 @@ if [ "$SPARK_MASTER_PORT" = "" ]; then fi if [ "$SPARK_MASTER_IP" = "" ]; then - SPARK_MASTER_IP=`hostname` + SPARK_MASTER_IP="`hostname`" fi if [ "$START_TACHYON" == "true" ]; then - "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin"/../tachyon/bin/tachyon bootstrap-conf $SPARK_MASTER_IP + "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin"/../tachyon/bin/tachyon bootstrap-conf "$SPARK_MASTER_IP" # set -t so we can call sudo SPARK_SSH_OPTS="-o StrictHostKeyChecking=no -t" "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/../tachyon/bin/tachyon-start.sh" worker SudoMount \; sleep 1 @@ -58,12 +58,12 @@ fi # Launch the slaves if [ "$SPARK_WORKER_INSTANCES" = "" ]; then - exec "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" 1 spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT + exec "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" 1 "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" else if [ "$SPARK_WORKER_WEBUI_PORT" = "" ]; then SPARK_WORKER_WEBUI_PORT=8081 fi for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do - "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" $(( $i + 1 )) spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT --webui-port $(( $SPARK_WORKER_WEBUI_PORT + $i )) + "$sbin/slaves.sh" cd "$SPARK_HOME" \; "$sbin/start-slave.sh" $(( $i + 1 )) "spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT" --webui-port $(( $SPARK_WORKER_WEBUI_PORT + $i )) done fi diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh index c519a77df4a14..4ce40fe750384 100755 --- a/sbin/start-thriftserver.sh +++ b/sbin/start-thriftserver.sh @@ -24,7 +24,7 @@ set -o posix # Figure out where Spark is installed -FWDIR="$(cd `dirname $0`/..; pwd)" +FWDIR="$(cd "`dirname "$0"`"/..; pwd)" CLASS="org.apache.spark.sql.hive.thriftserver.HiveThriftServer2" CLASS_NOT_FOUND_EXIT_STATUS=1 @@ -38,10 +38,10 @@ function usage { pattern+="\|=======" pattern+="\|--help" - $FWDIR/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 + "$FWDIR"/bin/spark-submit --help 2>&1 | grep -v Usage 1>&2 echo echo "Thrift server options:" - $FWDIR/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 + "$FWDIR"/bin/spark-class $CLASS --help 2>&1 | grep -v "$pattern" 1>&2 } if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then @@ -49,7 +49,7 @@ if [[ "$@" = *--help ]] || [[ "$@" = *-h ]]; then exit 0 fi -source $FWDIR/bin/utils.sh +source "$FWDIR"/bin/utils.sh SUBMIT_USAGE_FUNCTION=usage gatherSparkSubmitOpts "$@" diff --git a/sbin/stop-all.sh b/sbin/stop-all.sh index 60b358d374565..298c6a9859795 100755 --- a/sbin/stop-all.sh +++ b/sbin/stop-all.sh @@ -21,8 +21,8 @@ # Run this on the master nde -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" # Load the Spark configuration . "$sbin/spark-config.sh" diff --git a/sbin/stop-history-server.sh b/sbin/stop-history-server.sh index c0034ad641cbe..6e6056359510f 100755 --- a/sbin/stop-history-server.sh +++ b/sbin/stop-history-server.sh @@ -19,7 +19,7 @@ # Stops the history server on the machine this script is executed on. -sbin=`dirname "$0"` -sbin=`cd "$sbin"; pwd` +sbin="`dirname "$0"`" +sbin="`cd "$sbin"; pwd`" "$sbin"/spark-daemon.sh stop org.apache.spark.deploy.history.HistoryServer 1 diff --git a/sbt/sbt b/sbt/sbt index 1b1aa1483a829..c172fa74bc771 100755 --- a/sbt/sbt +++ b/sbt/sbt @@ -3,32 +3,32 @@ # When creating new tests for Spark SQL Hive, the HADOOP_CLASSPATH must contain the hive jars so # that we can run Hive to generate the golden answer. This is not required for normal development # or testing. -for i in $HIVE_HOME/lib/* -do HADOOP_CLASSPATH=$HADOOP_CLASSPATH:$i +for i in "$HIVE_HOME"/lib/* +do HADOOP_CLASSPATH="$HADOOP_CLASSPATH:$i" done export HADOOP_CLASSPATH realpath () { ( - TARGET_FILE=$1 + TARGET_FILE="$1" - cd $(dirname $TARGET_FILE) - TARGET_FILE=$(basename $TARGET_FILE) + cd "$(dirname "$TARGET_FILE")" + TARGET_FILE="$(basename "$TARGET_FILE")" COUNT=0 while [ -L "$TARGET_FILE" -a $COUNT -lt 100 ] do - TARGET_FILE=$(readlink $TARGET_FILE) - cd $(dirname $TARGET_FILE) - TARGET_FILE=$(basename $TARGET_FILE) + TARGET_FILE="$(readlink "$TARGET_FILE")" + cd $(dirname "$TARGET_FILE") + TARGET_FILE="$(basename $TARGET_FILE)" COUNT=$(($COUNT + 1)) done - echo $(pwd -P)/$TARGET_FILE + echo "$(pwd -P)/"$TARGET_FILE"" ) } -. $(dirname $(realpath $0))/sbt-launch-lib.bash +. "$(dirname "$(realpath "$0")")"/sbt-launch-lib.bash declare -r noshare_opts="-Dsbt.global.base=project/.sbtboot -Dsbt.boot.directory=project/.boot -Dsbt.ivy.home=project/.ivy" diff --git a/sbt/sbt-launch-lib.bash b/sbt/sbt-launch-lib.bash index c91fecf024ad4..7f05d2ef491a3 100755 --- a/sbt/sbt-launch-lib.bash +++ b/sbt/sbt-launch-lib.bash @@ -7,7 +7,7 @@ # TODO - Should we merge the main SBT script with this library? if test -z "$HOME"; then - declare -r script_dir="$(dirname $script_path)" + declare -r script_dir="$(dirname "$script_path")" else declare -r script_dir="$HOME/.sbt" fi @@ -46,20 +46,20 @@ acquire_sbt_jar () { if [[ ! -f "$sbt_jar" ]]; then # Download sbt launch jar if it hasn't been downloaded yet - if [ ! -f ${JAR} ]; then + if [ ! -f "${JAR}" ]; then # Download printf "Attempting to fetch sbt\n" - JAR_DL=${JAR}.part + JAR_DL="${JAR}.part" if hash curl 2>/dev/null; then - (curl --progress-bar ${URL1} > ${JAR_DL} || curl --progress-bar ${URL2} > ${JAR_DL}) && mv ${JAR_DL} ${JAR} + (curl --silent ${URL1} > "${JAR_DL}" || curl --silent ${URL2} > "${JAR_DL}") && mv "${JAR_DL}" "${JAR}" elif hash wget 2>/dev/null; then - (wget --progress=bar ${URL1} -O ${JAR_DL} || wget --progress=bar ${URL2} -O ${JAR_DL}) && mv ${JAR_DL} ${JAR} + (wget --quiet ${URL1} -O "${JAR_DL}" || wget --quiet ${URL2} -O "${JAR_DL}") && mv "${JAR_DL}" "${JAR}" else printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n" exit -1 fi fi - if [ ! -f ${JAR} ]; then + if [ ! -f "${JAR}" ]; then # We failed to download printf "Our attempt to download sbt locally to ${JAR} failed. Please install sbt manually from http://www.scala-sbt.org/\n" exit -1 diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 830711a46a35b..0d756f873e486 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala new file mode 100644 index 0000000000000..8364379644c90 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +/** + * Builds a map that is keyed by an Attribute's expression id. Using the expression id allows values + * to be looked up even when the attributes used differ cosmetically (i.e., the capitalization + * of the name, or the expected nullability). + */ +object AttributeMap { + def apply[A](kvs: Seq[(Attribute, A)]) = + new AttributeMap(kvs.map(kv => (kv._1.exprId, (kv._1, kv._2))).toMap) +} + +class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)]) + extends Map[Attribute, A] with Serializable { + + override def get(k: Attribute): Option[A] = baseMap.get(k.exprId).map(_._2) + + override def + [B1 >: A](kv: (Attribute, B1)): Map[Attribute, B1] = + (baseMap.map(_._2) + kv).toMap + + override def iterator: Iterator[(Attribute, A)] = baseMap.map(_._2).iterator + + override def -(key: Attribute): Map[Attribute, A] = (baseMap.map(_._2) - key).toMap +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 54c6baf1af3bf..fa80b07f8e6be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -38,12 +38,20 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) } object BindReferences extends Logging { - def bindReference[A <: Expression](expression: A, input: Seq[Attribute]): A = { + + def bindReference[A <: Expression]( + expression: A, + input: Seq[Attribute], + allowFailures: Boolean = false): A = { expression.transform { case a: AttributeReference => attachTree(a, "Binding attribute") { val ordinal = input.indexWhere(_.exprId == a.exprId) if (ordinal == -1) { - sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") + if (allowFailures) { + a + } else { + sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") + } } else { BoundReference(ordinal, a.dataType, a.nullable) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 56f042891a2e6..f988fb010b107 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -36,7 +36,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression { case class Sqrt(child: Expression) extends UnaryExpression { type EvaluatedType = Any - def dataType = child.dataType + def dataType = DoubleType override def foldable = child.foldable def nullable = child.nullable override def toString = s"SQRT($child)" diff --git a/sql/core/pom.xml b/sql/core/pom.xml index c8016e41256d5..bd110218d34f7 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 64d49354dadcd..f6f4cf3b80d41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -26,6 +26,7 @@ import java.util.Properties private[spark] object SQLConf { val COMPRESS_CACHED = "spark.sql.inMemoryColumnarStorage.compressed" val COLUMN_BATCH_SIZE = "spark.sql.inMemoryColumnarStorage.batchSize" + val IN_MEMORY_PARTITION_PRUNING = "spark.sql.inMemoryColumnarStorage.partitionPruning" val AUTO_BROADCASTJOIN_THRESHOLD = "spark.sql.autoBroadcastJoinThreshold" val DEFAULT_SIZE_IN_BYTES = "spark.sql.defaultSizeInBytes" val SHUFFLE_PARTITIONS = "spark.sql.shuffle.partitions" @@ -52,7 +53,7 @@ private[spark] object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ -trait SQLConf { +private[sql] trait SQLConf { import SQLConf._ /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @@ -124,6 +125,12 @@ trait SQLConf { private[spark] def isParquetBinaryAsString: Boolean = getConf(PARQUET_BINARY_AS_STRING, "false").toBoolean + /** + * When set to true, partition pruning for in-memory columnar tables is enabled. + */ + private[spark] def inMemoryPartitionPruning: Boolean = + getConf(IN_MEMORY_PARTITION_PRUNING, "false").toBoolean + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a75af94d29303..5acb45c155ba5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -272,7 +272,7 @@ class SQLContext(@transient val sparkContext: SparkContext) val currentTable = table(tableName).queryExecution.analyzed val asInMemoryRelation = currentTable match { case _: InMemoryRelation => - currentTable.logicalPlan + currentTable case _ => InMemoryRelation(useCompression, columnBatchSize, executePlan(currentTable).executedPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala index 0b48e9e659faa..595b4aa36eae3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.util.{List => JList, Map => JMap} import org.apache.spark.Accumulator +import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf} import org.apache.spark.sql.execution.PythonUDF @@ -29,7 +30,7 @@ import scala.reflect.runtime.universe.{TypeTag, typeTag} /** * Functions for registering scala lambda functions as UDFs in a SQLContext. */ -protected[sql] trait UDFRegistration { +private[sql] trait UDFRegistration { self: SQLContext => private[spark] def registerPython( @@ -38,6 +39,7 @@ protected[sql] trait UDFRegistration { envVars: JMap[String, String], pythonIncludes: JList[String], pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], accumulator: Accumulator[JList[Array[Byte]]], stringDataType: String): Unit = { log.debug( @@ -61,6 +63,7 @@ protected[sql] trait UDFRegistration { envVars, pythonIncludes, pythonExec, + broadcastVars, accumulator, dataType, e) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala index 6c67934bda5b8..e9d04ce7aae4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala @@ -25,7 +25,7 @@ import scala.math.BigDecimal import org.apache.spark.sql.catalyst.expressions.{Row => ScalaRow} /** - * A result row from a SparkSQL query. + * A result row from a Spark SQL query. */ class Row(private[spark] val row: ScalaRow) extends Serializable { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 7e7bb2859bbcd..b3ec5ded22422 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -38,7 +38,7 @@ private[sql] trait ColumnBuilder { /** * Column statistics information */ - def columnStats: ColumnStats[_, _] + def columnStats: ColumnStats /** * Returns the final columnar byte buffer. @@ -47,7 +47,7 @@ private[sql] trait ColumnBuilder { } private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( - val columnStats: ColumnStats[T, JvmType], + val columnStats: ColumnStats, val columnType: ColumnType[T, JvmType]) extends ColumnBuilder { @@ -75,25 +75,24 @@ private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( } override def build() = { - buffer.limit(buffer.position()).rewind() - buffer + buffer.flip().asInstanceOf[ByteBuffer] } } private[sql] abstract class ComplexColumnBuilder[T <: DataType, JvmType]( columnType: ColumnType[T, JvmType]) - extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType) + extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType) with NullableColumnBuilder private[sql] abstract class NativeColumnBuilder[T <: NativeType]( - override val columnStats: NativeColumnStats[T], + override val columnStats: ColumnStats, override val columnType: NativeColumnType[T]) extends BasicColumnBuilder[T, T#JvmType](columnStats, columnType) with NullableColumnBuilder with AllCompressionSchemes with CompressibleColumnBuilder[T] -private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) +private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new NoopColumnStats, BOOLEAN) private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) @@ -129,7 +128,6 @@ private[sql] object ColumnBuilder { val newSize = capacity + size.max(capacity / 8 + 1) val pos = orig.position() - orig.clear() ByteBuffer .allocate(newSize) .order(ByteOrder.nativeOrder()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 6502110e903fe..fc343ccb995c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -17,381 +17,193 @@ package org.apache.spark.sql.columnar +import java.sql.Timestamp + import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute, AttributeReference} import org.apache.spark.sql.catalyst.types._ +private[sql] class ColumnStatisticsSchema(a: Attribute) extends Serializable { + val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = false)() + val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = false)() + val nullCount = AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)() + + val schema = Seq(lowerBound, upperBound, nullCount) +} + +private[sql] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable { + val (forAttribute, schema) = { + val allStats = tableSchema.map(a => a -> new ColumnStatisticsSchema(a)) + (AttributeMap(allStats), allStats.map(_._2.schema).foldLeft(Seq.empty[Attribute])(_ ++ _)) + } +} + /** * Used to collect statistical information when building in-memory columns. * * NOTE: we intentionally avoid using `Ordering[T]` to compare values here because `Ordering[T]` * brings significant performance penalty. */ -private[sql] sealed abstract class ColumnStats[T <: DataType, JvmType] extends Serializable { - /** - * Closed lower bound of this column. - */ - def lowerBound: JvmType - - /** - * Closed upper bound of this column. - */ - def upperBound: JvmType - +private[sql] sealed trait ColumnStats extends Serializable { /** * Gathers statistics information from `row(ordinal)`. */ - def gatherStats(row: Row, ordinal: Int) - - /** - * Returns `true` if `lower <= row(ordinal) <= upper`. - */ - def contains(row: Row, ordinal: Int): Boolean + def gatherStats(row: Row, ordinal: Int): Unit /** - * Returns `true` if `row(ordinal) < upper` holds. + * Column statistics represented as a single row, currently including closed lower bound, closed + * upper bound and null count. */ - def isAbove(row: Row, ordinal: Int): Boolean - - /** - * Returns `true` if `lower < row(ordinal)` holds. - */ - def isBelow(row: Row, ordinal: Int): Boolean - - /** - * Returns `true` if `row(ordinal) <= upper` holds. - */ - def isAtOrAbove(row: Row, ordinal: Int): Boolean - - /** - * Returns `true` if `lower <= row(ordinal)` holds. - */ - def isAtOrBelow(row: Row, ordinal: Int): Boolean -} - -private[sql] sealed abstract class NativeColumnStats[T <: NativeType] - extends ColumnStats[T, T#JvmType] { - - type JvmType = T#JvmType - - protected var (_lower, _upper) = initialBounds - - def initialBounds: (JvmType, JvmType) - - protected def columnType: NativeColumnType[T] - - override def lowerBound: T#JvmType = _lower - - override def upperBound: T#JvmType = _upper - - override def isAtOrAbove(row: Row, ordinal: Int) = { - contains(row, ordinal) || isAbove(row, ordinal) - } - - override def isAtOrBelow(row: Row, ordinal: Int) = { - contains(row, ordinal) || isBelow(row, ordinal) - } + def collectedStatistics: Row } -private[sql] class NoopColumnStats[T <: DataType, JvmType] extends ColumnStats[T, JvmType] { - override def isAtOrBelow(row: Row, ordinal: Int) = true - - override def isAtOrAbove(row: Row, ordinal: Int) = true - - override def isBelow(row: Row, ordinal: Int) = true - - override def isAbove(row: Row, ordinal: Int) = true +private[sql] class NoopColumnStats extends ColumnStats { - override def contains(row: Row, ordinal: Int) = true + override def gatherStats(row: Row, ordinal: Int): Unit = {} - override def gatherStats(row: Row, ordinal: Int) {} - - override def upperBound = null.asInstanceOf[JvmType] - - override def lowerBound = null.asInstanceOf[JvmType] + override def collectedStatistics = Row() } -private[sql] abstract class BasicColumnStats[T <: NativeType]( - protected val columnType: NativeColumnType[T]) - extends NativeColumnStats[T] - -private[sql] class BooleanColumnStats extends BasicColumnStats(BOOLEAN) { - override def initialBounds = (true, false) - - override def isBelow(row: Row, ordinal: Int) = { - lowerBound < columnType.getField(row, ordinal) - } - - override def isAbove(row: Row, ordinal: Int) = { - columnType.getField(row, ordinal) < upperBound - } - - override def contains(row: Row, ordinal: Int) = { - val field = columnType.getField(row, ordinal) - lowerBound <= field && field <= upperBound - } +private[sql] class ByteColumnStats extends ColumnStats { + var upper = Byte.MinValue + var lower = Byte.MaxValue + var nullCount = 0 override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if (field > upperBound) _upper = field - if (field < lowerBound) _lower = field - } -} - -private[sql] class ByteColumnStats extends BasicColumnStats(BYTE) { - override def initialBounds = (Byte.MaxValue, Byte.MinValue) - - override def isBelow(row: Row, ordinal: Int) = { - lowerBound < columnType.getField(row, ordinal) - } - - override def isAbove(row: Row, ordinal: Int) = { - columnType.getField(row, ordinal) < upperBound - } - - override def contains(row: Row, ordinal: Int) = { - val field = columnType.getField(row, ordinal) - lowerBound <= field && field <= upperBound + if (!row.isNullAt(ordinal)) { + val value = row.getByte(ordinal) + if (value > upper) upper = value + if (value < lower) lower = value + } else { + nullCount += 1 + } } - override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if (field > upperBound) _upper = field - if (field < lowerBound) _lower = field - } + def collectedStatistics = Row(lower, upper, nullCount) } -private[sql] class ShortColumnStats extends BasicColumnStats(SHORT) { - override def initialBounds = (Short.MaxValue, Short.MinValue) - - override def isBelow(row: Row, ordinal: Int) = { - lowerBound < columnType.getField(row, ordinal) - } - - override def isAbove(row: Row, ordinal: Int) = { - columnType.getField(row, ordinal) < upperBound - } - - override def contains(row: Row, ordinal: Int) = { - val field = columnType.getField(row, ordinal) - lowerBound <= field && field <= upperBound - } +private[sql] class ShortColumnStats extends ColumnStats { + var upper = Short.MinValue + var lower = Short.MaxValue + var nullCount = 0 override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if (field > upperBound) _upper = field - if (field < lowerBound) _lower = field - } -} - -private[sql] class LongColumnStats extends BasicColumnStats(LONG) { - override def initialBounds = (Long.MaxValue, Long.MinValue) - - override def isBelow(row: Row, ordinal: Int) = { - lowerBound < columnType.getField(row, ordinal) - } - - override def isAbove(row: Row, ordinal: Int) = { - columnType.getField(row, ordinal) < upperBound - } - - override def contains(row: Row, ordinal: Int) = { - val field = columnType.getField(row, ordinal) - lowerBound <= field && field <= upperBound + if (!row.isNullAt(ordinal)) { + val value = row.getShort(ordinal) + if (value > upper) upper = value + if (value < lower) lower = value + } else { + nullCount += 1 + } } - override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if (field > upperBound) _upper = field - if (field < lowerBound) _lower = field - } + def collectedStatistics = Row(lower, upper, nullCount) } -private[sql] class DoubleColumnStats extends BasicColumnStats(DOUBLE) { - override def initialBounds = (Double.MaxValue, Double.MinValue) - - override def isBelow(row: Row, ordinal: Int) = { - lowerBound < columnType.getField(row, ordinal) - } - - override def isAbove(row: Row, ordinal: Int) = { - columnType.getField(row, ordinal) < upperBound - } - - override def contains(row: Row, ordinal: Int) = { - val field = columnType.getField(row, ordinal) - lowerBound <= field && field <= upperBound - } +private[sql] class LongColumnStats extends ColumnStats { + var upper = Long.MinValue + var lower = Long.MaxValue + var nullCount = 0 override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if (field > upperBound) _upper = field - if (field < lowerBound) _lower = field - } -} - -private[sql] class FloatColumnStats extends BasicColumnStats(FLOAT) { - override def initialBounds = (Float.MaxValue, Float.MinValue) - - override def isBelow(row: Row, ordinal: Int) = { - lowerBound < columnType.getField(row, ordinal) + if (!row.isNullAt(ordinal)) { + val value = row.getLong(ordinal) + if (value > upper) upper = value + if (value < lower) lower = value + } else { + nullCount += 1 + } } - override def isAbove(row: Row, ordinal: Int) = { - columnType.getField(row, ordinal) < upperBound - } + def collectedStatistics = Row(lower, upper, nullCount) +} - override def contains(row: Row, ordinal: Int) = { - val field = columnType.getField(row, ordinal) - lowerBound <= field && field <= upperBound - } +private[sql] class DoubleColumnStats extends ColumnStats { + var upper = Double.MinValue + var lower = Double.MaxValue + var nullCount = 0 override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if (field > upperBound) _upper = field - if (field < lowerBound) _lower = field + if (!row.isNullAt(ordinal)) { + val value = row.getDouble(ordinal) + if (value > upper) upper = value + if (value < lower) lower = value + } else { + nullCount += 1 + } } -} -private[sql] object IntColumnStats { - val UNINITIALIZED = 0 - val INITIALIZED = 1 - val ASCENDING = 2 - val DESCENDING = 3 - val UNORDERED = 4 + def collectedStatistics = Row(lower, upper, nullCount) } -/** - * Statistical information for `Int` columns. More information is collected since `Int` is - * frequently used. Extra information include: - * - * - Ordering state (ascending/descending/unordered), may be used to decide whether binary search - * is applicable when searching elements. - * - Maximum delta between adjacent elements, may be used to guide the `IntDelta` compression - * scheme. - * - * (This two kinds of information are not used anywhere yet and might be removed later.) - */ -private[sql] class IntColumnStats extends BasicColumnStats(INT) { - import IntColumnStats._ - - private var orderedState = UNINITIALIZED - private var lastValue: Int = _ - private var _maxDelta: Int = _ - - def isAscending = orderedState != DESCENDING && orderedState != UNORDERED - def isDescending = orderedState != ASCENDING && orderedState != UNORDERED - def isOrdered = isAscending || isDescending - def maxDelta = _maxDelta - - override def initialBounds = (Int.MaxValue, Int.MinValue) +private[sql] class FloatColumnStats extends ColumnStats { + var upper = Float.MinValue + var lower = Float.MaxValue + var nullCount = 0 - override def isBelow(row: Row, ordinal: Int) = { - lowerBound < columnType.getField(row, ordinal) + override def gatherStats(row: Row, ordinal: Int) { + if (!row.isNullAt(ordinal)) { + val value = row.getFloat(ordinal) + if (value > upper) upper = value + if (value < lower) lower = value + } else { + nullCount += 1 + } } - override def isAbove(row: Row, ordinal: Int) = { - columnType.getField(row, ordinal) < upperBound - } + def collectedStatistics = Row(lower, upper, nullCount) +} - override def contains(row: Row, ordinal: Int) = { - val field = columnType.getField(row, ordinal) - lowerBound <= field && field <= upperBound - } +private[sql] class IntColumnStats extends ColumnStats { + var upper = Int.MinValue + var lower = Int.MaxValue + var nullCount = 0 override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - - if (field > upperBound) _upper = field - if (field < lowerBound) _lower = field - - orderedState = orderedState match { - case UNINITIALIZED => - lastValue = field - INITIALIZED - - case INITIALIZED => - // If all the integers in the column are the same, ordered state is set to Ascending. - // TODO (lian) Confirm whether this is the standard behaviour. - val nextState = if (field >= lastValue) ASCENDING else DESCENDING - _maxDelta = math.abs(field - lastValue) - lastValue = field - nextState - - case ASCENDING if field < lastValue => - UNORDERED - - case DESCENDING if field > lastValue => - UNORDERED - - case state @ (ASCENDING | DESCENDING) => - _maxDelta = _maxDelta.max(field - lastValue) - lastValue = field - state - - case _ => - orderedState + if (!row.isNullAt(ordinal)) { + val value = row.getInt(ordinal) + if (value > upper) upper = value + if (value < lower) lower = value + } else { + nullCount += 1 } } + + def collectedStatistics = Row(lower, upper, nullCount) } -private[sql] class StringColumnStats extends BasicColumnStats(STRING) { - override def initialBounds = (null, null) +private[sql] class StringColumnStats extends ColumnStats { + var upper: String = null + var lower: String = null + var nullCount = 0 override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if ((upperBound eq null) || field.compareTo(upperBound) > 0) _upper = field - if ((lowerBound eq null) || field.compareTo(lowerBound) < 0) _lower = field - } - - override def contains(row: Row, ordinal: Int) = { - (upperBound ne null) && { - val field = columnType.getField(row, ordinal) - lowerBound.compareTo(field) <= 0 && field.compareTo(upperBound) <= 0 - } - } - - override def isAbove(row: Row, ordinal: Int) = { - (upperBound ne null) && { - val field = columnType.getField(row, ordinal) - field.compareTo(upperBound) < 0 + if (!row.isNullAt(ordinal)) { + val value = row.getString(ordinal) + if (upper == null || value.compareTo(upper) > 0) upper = value + if (lower == null || value.compareTo(lower) < 0) lower = value + } else { + nullCount += 1 } } - override def isBelow(row: Row, ordinal: Int) = { - (lowerBound ne null) && { - val field = columnType.getField(row, ordinal) - lowerBound.compareTo(field) < 0 - } - } + def collectedStatistics = Row(lower, upper, nullCount) } -private[sql] class TimestampColumnStats extends BasicColumnStats(TIMESTAMP) { - override def initialBounds = (null, null) +private[sql] class TimestampColumnStats extends ColumnStats { + var upper: Timestamp = null + var lower: Timestamp = null + var nullCount = 0 override def gatherStats(row: Row, ordinal: Int) { - val field = columnType.getField(row, ordinal) - if ((upperBound eq null) || field.compareTo(upperBound) > 0) _upper = field - if ((lowerBound eq null) || field.compareTo(lowerBound) < 0) _lower = field - } - - override def contains(row: Row, ordinal: Int) = { - (upperBound ne null) && { - val field = columnType.getField(row, ordinal) - lowerBound.compareTo(field) <= 0 && field.compareTo(upperBound) <= 0 + if (!row.isNullAt(ordinal)) { + val value = row(ordinal).asInstanceOf[Timestamp] + if (upper == null || value.compareTo(upper) > 0) upper = value + if (lower == null || value.compareTo(lower) < 0) lower = value + } else { + nullCount += 1 } } - override def isAbove(row: Row, ordinal: Int) = { - (lowerBound ne null) && { - val field = columnType.getField(row, ordinal) - field.compareTo(upperBound) < 0 - } - } - - override def isBelow(row: Row, ordinal: Int) = { - (lowerBound ne null) && { - val field = columnType.getField(row, ordinal) - lowerBound.compareTo(field) < 0 - } - } + def collectedStatistics = Row(lower, upper, nullCount) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index cb055cd74a5e5..6eab2f23c18e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -19,35 +19,41 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer +import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{LeafNode, SparkPlan} -object InMemoryRelation { +private[sql] object InMemoryRelation { def apply(useCompression: Boolean, batchSize: Int, child: SparkPlan): InMemoryRelation = new InMemoryRelation(child.output, useCompression, batchSize, child)() } +private[sql] case class CachedBatch(buffers: Array[ByteBuffer], stats: Row) + private[sql] case class InMemoryRelation( output: Seq[Attribute], useCompression: Boolean, batchSize: Int, child: SparkPlan) - (private var _cachedColumnBuffers: RDD[Array[ByteBuffer]] = null) + (private var _cachedColumnBuffers: RDD[CachedBatch] = null) extends LogicalPlan with MultiInstanceRelation { override lazy val statistics = Statistics(sizeInBytes = child.sqlContext.defaultSizeInBytes) + val partitionStatistics = new PartitionStatistics(output) + // If the cached column buffers were not passed in, we calculate them in the constructor. // As in Spark, the actual work of caching is lazy. if (_cachedColumnBuffers == null) { val output = child.output val cached = child.execute().mapPartitions { baseIterator => - new Iterator[Array[ByteBuffer]] { + new Iterator[CachedBatch] { def next() = { val columnBuilders = output.map { attribute => val columnType = ColumnType(attribute.dataType) @@ -68,7 +74,10 @@ private[sql] case class InMemoryRelation( rowCount += 1 } - columnBuilders.map(_.build()) + val stats = Row.fromSeq( + columnBuilders.map(_.columnStats.collectedStatistics).foldLeft(Seq.empty[Any])(_ ++ _)) + + CachedBatch(columnBuilders.map(_.build()), stats) } def hasNext = baseIterator.hasNext @@ -79,7 +88,6 @@ private[sql] case class InMemoryRelation( _cachedColumnBuffers = cached } - override def children = Seq.empty override def newInstance() = { @@ -96,13 +104,98 @@ private[sql] case class InMemoryRelation( private[sql] case class InMemoryColumnarTableScan( attributes: Seq[Attribute], + predicates: Seq[Expression], relation: InMemoryRelation) extends LeafNode { + @transient override val sqlContext = relation.child.sqlContext + override def output: Seq[Attribute] = attributes + // Returned filter predicate should return false iff it is impossible for the input expression + // to evaluate to `true' based on statistics collected about this partition batch. + val buildFilter: PartialFunction[Expression, Expression] = { + case And(lhs: Expression, rhs: Expression) + if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) => + buildFilter(lhs) && buildFilter(rhs) + + case Or(lhs: Expression, rhs: Expression) + if buildFilter.isDefinedAt(lhs) && buildFilter.isDefinedAt(rhs) => + buildFilter(lhs) || buildFilter(rhs) + + case EqualTo(a: AttributeReference, l: Literal) => + val aStats = relation.partitionStatistics.forAttribute(a) + aStats.lowerBound <= l && l <= aStats.upperBound + + case EqualTo(l: Literal, a: AttributeReference) => + val aStats = relation.partitionStatistics.forAttribute(a) + aStats.lowerBound <= l && l <= aStats.upperBound + + case LessThan(a: AttributeReference, l: Literal) => + val aStats = relation.partitionStatistics.forAttribute(a) + aStats.lowerBound < l + + case LessThan(l: Literal, a: AttributeReference) => + val aStats = relation.partitionStatistics.forAttribute(a) + l < aStats.upperBound + + case LessThanOrEqual(a: AttributeReference, l: Literal) => + val aStats = relation.partitionStatistics.forAttribute(a) + aStats.lowerBound <= l + + case LessThanOrEqual(l: Literal, a: AttributeReference) => + val aStats = relation.partitionStatistics.forAttribute(a) + l <= aStats.upperBound + + case GreaterThan(a: AttributeReference, l: Literal) => + val aStats = relation.partitionStatistics.forAttribute(a) + l < aStats.upperBound + + case GreaterThan(l: Literal, a: AttributeReference) => + val aStats = relation.partitionStatistics.forAttribute(a) + aStats.lowerBound < l + + case GreaterThanOrEqual(a: AttributeReference, l: Literal) => + val aStats = relation.partitionStatistics.forAttribute(a) + l <= aStats.upperBound + + case GreaterThanOrEqual(l: Literal, a: AttributeReference) => + val aStats = relation.partitionStatistics.forAttribute(a) + aStats.lowerBound <= l + } + + val partitionFilters = { + predicates.flatMap { p => + val filter = buildFilter.lift(p) + val boundFilter = + filter.map( + BindReferences.bindReference( + _, + relation.partitionStatistics.schema, + allowFailures = true)) + + boundFilter.foreach(_ => + filter.foreach(f => logInfo(s"Predicate $p generates partition filter: $f"))) + + // If the filter can't be resolved then we are missing required statistics. + boundFilter.filter(_.resolved) + } + } + + val readPartitions = sparkContext.accumulator(0) + val readBatches = sparkContext.accumulator(0) + + private val inMemoryPartitionPruningEnabled = sqlContext.inMemoryPartitionPruning + override def execute() = { + readPartitions.setValue(0) + readBatches.setValue(0) + relation.cachedColumnBuffers.mapPartitions { iterator => + val partitionFilter = newPredicate( + partitionFilters.reduceOption(And).getOrElse(Literal(true)), + relation.partitionStatistics.schema) + // Find the ordinals of the requested columns. If none are requested, use the first. val requestedColumns = if (attributes.isEmpty) { Seq(0) @@ -110,8 +203,26 @@ private[sql] case class InMemoryColumnarTableScan( attributes.map(a => relation.output.indexWhere(_.exprId == a.exprId)) } - iterator - .map(batch => requestedColumns.map(batch(_)).map(ColumnAccessor(_))) + val rows = iterator + // Skip pruned batches + .filter { cachedBatch => + if (inMemoryPartitionPruningEnabled && !partitionFilter(cachedBatch.stats)) { + def statsString = relation.partitionStatistics.schema + .zip(cachedBatch.stats) + .map { case (a, s) => s"${a.name}: $s" } + .mkString(", ") + logInfo(s"Skipping partition based on stats $statsString") + false + } else { + readBatches += 1 + true + } + } + // Build column accessors + .map { cachedBatch => + requestedColumns.map(cachedBatch.buffers(_)).map(ColumnAccessor(_)) + } + // Extract rows via column accessors .flatMap { columnAccessors => val nextRow = new GenericMutableRow(columnAccessors.length) new Iterator[Row] { @@ -127,6 +238,12 @@ private[sql] case class InMemoryColumnarTableScan( override def hasNext = columnAccessors.head.hasNext } } + + if (rows.hasNext) { + readPartitions += 1 + } + + rows } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala index f631ee76fcd78..a72970eef7aa4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala @@ -49,6 +49,7 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder { } abstract override def appendFrom(row: Row, ordinal: Int) { + columnStats.gatherStats(row, ordinal) if (row.isNullAt(ordinal)) { nulls = ColumnBuilder.ensureFreeSpace(nulls, 4) nulls.putInt(pos) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 4802e40595807..927f40063e47e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -36,25 +36,23 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una override def outputPartitioning = newPartitioning - def output = child.output + override def output = child.output /** We must copy rows when sort based shuffle is on */ protected def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] - def execute() = attachTree(this , "execute") { + override def execute() = attachTree(this , "execute") { newPartitioning match { case HashPartitioning(expressions, numPartitions) => // TODO: Eliminate redundant expressions in grouping key and value. - val rdd = child.execute().mapPartitions { iter => - if (sortBasedShuffleOn) { - @transient val hashExpressions = - newProjection(expressions, child.output) - + val rdd = if (sortBasedShuffleOn) { + child.execute().mapPartitions { iter => + val hashExpressions = newProjection(expressions, child.output) iter.map(r => (hashExpressions(r), r.copy())) - } else { - @transient val hashExpressions = - newMutableProjection(expressions, child.output)() - + } + } else { + child.execute().mapPartitions { iter => + val hashExpressions = newMutableProjection(expressions, child.output)() val mutablePair = new MutablePair[Row, Row]() iter.map(r => mutablePair.update(hashExpressions(r), r)) } @@ -65,17 +63,18 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una shuffled.map(_._2) case RangePartitioning(sortingExpressions, numPartitions) => - // TODO: RangePartitioner should take an Ordering. - implicit val ordering = new RowOrdering(sortingExpressions, child.output) - - val rdd = child.execute().mapPartitions { iter => - if (sortBasedShuffleOn) { - iter.map(row => (row.copy(), null)) - } else { + val rdd = if (sortBasedShuffleOn) { + child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))} + } else { + child.execute().mapPartitions { iter => val mutablePair = new MutablePair[Row, Null](null, null) iter.map(row => mutablePair.update(row, null)) } } + + // TODO: RangePartitioner should take an Ordering. + implicit val ordering = new RowOrdering(sortingExpressions, child.output) + val part = new RangePartitioner(numPartitions, rdd, ascending = true) val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part) shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) @@ -83,10 +82,10 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una shuffled.map(_._1) case SinglePartition => - val rdd = child.execute().mapPartitions { iter => - if (sortBasedShuffleOn) { - iter.map(r => (null, r.copy())) - } else { + val rdd = if (sortBasedShuffleOn) { + child.execute().mapPartitions { iter => iter.map(r => (null, r.copy())) } + } else { + child.execute().mapPartitions { iter => val mutablePair = new MutablePair[Null, Row]() iter.map(r => mutablePair.update(null, r)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 8dacb84c8a17e..7943d6e1b6fb5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -243,8 +243,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { pruneFilterProject( projectList, filters, - identity[Seq[Expression]], // No filters are pushed down. - InMemoryColumnarTableScan(_, mem)) :: Nil + identity[Seq[Expression]], // All filters still need to be evaluated. + InMemoryColumnarTableScan(_, filters, mem)) :: Nil case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 4abda21ffec96..47bff0c730b8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -20,10 +20,10 @@ package org.apache.spark.sql.execution import scala.collection.mutable.ArrayBuffer import scala.reflect.runtime.universe.TypeTag +import org.apache.spark.{SparkEnv, HashPartitioner, SparkConf} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.{HashPartitioner, SparkConf} import org.apache.spark.rdd.{RDD, ShuffledRDD} -import org.apache.spark.sql.SQLContext +import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ @@ -96,6 +96,9 @@ case class Limit(limit: Int, child: SparkPlan) // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan: // partition local limit -> exchange into one partition -> partition local limit again + /** We must copy rows when sort based shuffle is on */ + private def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager] + override def output = child.output /** @@ -143,9 +146,15 @@ case class Limit(limit: Int, child: SparkPlan) } override def execute() = { - val rdd = child.execute().mapPartitions { iter => - val mutablePair = new MutablePair[Boolean, Row]() - iter.take(limit).map(row => mutablePair.update(false, row)) + val rdd: RDD[_ <: Product2[Boolean, Row]] = if (sortBasedShuffleOn) { + child.execute().mapPartitions { iter => + iter.take(limit).map(row => (false, row.copy())) + } + } else { + child.execute().mapPartitions { iter => + val mutablePair = new MutablePair[Boolean, Row]() + iter.take(limit).map(row => mutablePair.update(false, row)) + } } val part = new HashPartitioner(1) val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 031b695169cea..94543fc95b470 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -21,11 +21,13 @@ import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericRow} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.{Row, SQLConf, SQLContext} trait Command { + this: SparkPlan => + /** * A concrete command should override this lazy field to wrap up any side effects caused by the * command or any other computation that should be evaluated exactly once. The value of this field @@ -35,7 +37,11 @@ trait Command { * The `execute()` method of all the physical command classes should reference `sideEffectResult` * so that the command can be executed eagerly right after the command query is created. */ - protected[sql] lazy val sideEffectResult: Seq[Any] = Seq.empty[Any] + protected[sql] lazy val sideEffectResult: Seq[Row] = Seq.empty[Row] + + override def executeCollect(): Array[Row] = sideEffectResult.toArray + + override def execute(): RDD[Row] = sqlContext.sparkContext.parallelize(sideEffectResult, 1) } /** @@ -47,17 +53,17 @@ case class SetCommand( @transient context: SQLContext) extends LeafNode with Command with Logging { - override protected[sql] lazy val sideEffectResult: Seq[String] = (key, value) match { + override protected[sql] lazy val sideEffectResult: Seq[Row] = (key, value) match { // Set value for key k. case (Some(k), Some(v)) => if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS} instead.") context.setConf(SQLConf.SHUFFLE_PARTITIONS, v) - Array(s"${SQLConf.SHUFFLE_PARTITIONS}=$v") + Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=$v")) } else { context.setConf(k, v) - Array(s"$k=$v") + Seq(Row(s"$k=$v")) } // Query the value bound to key k. @@ -72,29 +78,31 @@ case class SetCommand( "hive-hwi-0.12.0.jar", "hive-0.12.0.jar").mkString(":") - Array( - "system:java.class.path=" + hiveJars, - "system:sun.java.command=shark.SharkServer2") - } - else { - Array(s"$k=${context.getConf(k, "")}") + context.getAllConfs.map { case (k, v) => + Row(s"$k=$v") + }.toSeq ++ Seq( + Row("system:java.class.path=" + hiveJars), + Row("system:sun.java.command=shark.SharkServer2")) + } else { + if (k == SQLConf.Deprecated.MAPRED_REDUCE_TASKS) { + logWarning(s"Property ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS} is deprecated, " + + s"showing ${SQLConf.SHUFFLE_PARTITIONS} instead.") + Seq(Row(s"${SQLConf.SHUFFLE_PARTITIONS}=${context.numShufflePartitions}")) + } else { + Seq(Row(s"$k=${context.getConf(k, "")}")) + } } // Query all key-value pairs that are set in the SQLConf of the context. case (None, None) => context.getAllConfs.map { case (k, v) => - s"$k=$v" + Row(s"$k=$v") }.toSeq case _ => throw new IllegalArgumentException() } - def execute(): RDD[Row] = { - val rows = sideEffectResult.map { line => new GenericRow(Array[Any](line)) } - context.sparkContext.parallelize(rows, 1) - } - override def otherCopyArgs = context :: Nil } @@ -113,19 +121,14 @@ case class ExplainCommand( extends LeafNode with Command { // Run through the optimizer to generate the physical plan. - override protected[sql] lazy val sideEffectResult: Seq[String] = try { + override protected[sql] lazy val sideEffectResult: Seq[Row] = try { // TODO in Hive, the "extended" ExplainCommand prints the AST as well, and detailed properties. val queryExecution = context.executePlan(logicalPlan) val outputString = if (extended) queryExecution.toString else queryExecution.simpleString - outputString.split("\n") + outputString.split("\n").map(Row(_)) } catch { case cause: TreeNodeException[_] => - ("Error occurred during query planning: \n" + cause.getMessage).split("\n") - } - - def execute(): RDD[Row] = { - val explanation = sideEffectResult.map(row => new GenericRow(Array[Any](row))) - context.sparkContext.parallelize(explanation, 1) + ("Error occurred during query planning: \n" + cause.getMessage).split("\n").map(Row(_)) } override def otherCopyArgs = context :: Nil @@ -144,12 +147,7 @@ case class CacheCommand(tableName: String, doCache: Boolean)(@transient context: } else { context.uncacheTable(tableName) } - Seq.empty[Any] - } - - override def execute(): RDD[Row] = { - sideEffectResult - context.emptyResult + Seq.empty[Row] } override def output: Seq[Attribute] = Seq.empty @@ -163,15 +161,8 @@ case class DescribeCommand(child: SparkPlan, output: Seq[Attribute])( @transient context: SQLContext) extends LeafNode with Command { - override protected[sql] lazy val sideEffectResult: Seq[(String, String, String)] = { - Seq(("# Registered as a temporary table", null, null)) ++ - child.output.map(field => (field.name, field.dataType.toString, null)) - } - - override def execute(): RDD[Row] = { - val rows = sideEffectResult.map { - case (name, dataType, comment) => new GenericRow(Array[Any](name, dataType, comment)) - } - context.sparkContext.parallelize(rows, 1) + override protected[sql] lazy val sideEffectResult: Seq[Row] = { + Row("# Registered as a temporary table", null, null) +: + child.output.map(field => Row(field.name, field.dataType.toString, null)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 8ff757bbe3508..a9535a750bcd7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -74,22 +74,22 @@ package object debug { } /** - * A collection of stats for each column of output. + * A collection of metrics for each column of output. * @param elementTypes the actual runtime types for the output. Useful when there are bugs * causing the wrong data to be projected. */ - case class ColumnStat( + case class ColumnMetrics( elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty)) val tupleCount = sparkContext.accumulator[Int](0) val numColumns = child.output.size - val columnStats = Array.fill(child.output.size)(new ColumnStat()) + val columnStats = Array.fill(child.output.size)(new ColumnMetrics()) def dumpStats(): Unit = { println(s"== ${child.simpleString} ==") println(s"Tuples output: ${tupleCount.value}") - child.output.zip(columnStats).foreach { case(attr, stat) => - val actualDataTypes =stat.elementTypes.value.mkString("{", ",", "}") + child.output.zip(columnStats).foreach { case(attr, metric) => + val actualDataTypes = metric.elementTypes.value.mkString("{", ",", "}") println(s" ${attr.name} ${attr.dataType}: $actualDataTypes") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 3dc8be2456781..0977da3e8577c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -42,6 +42,7 @@ private[spark] case class PythonUDF( envVars: JMap[String, String], pythonIncludes: JList[String], pythonExec: String, + broadcastVars: JList[Broadcast[Array[Byte]]], accumulator: Accumulator[JList[Array[Byte]]], dataType: DataType, children: Seq[Expression]) extends Expression with SparkLogging { @@ -145,7 +146,7 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: udf.pythonIncludes, false, udf.pythonExec, - Seq[Broadcast[Array[Byte]]](), + udf.broadcastVars, udf.accumulator ).mapPartitions { iter => val pickle = new Unpickler diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 9fd6aed402838..2fc7e1cf23ab7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -382,7 +382,7 @@ private[parquet] class CatalystPrimitiveConverter( parent.updateLong(fieldIndex, value) } -object CatalystArrayConverter { +private[parquet] object CatalystArrayConverter { val INITIAL_ARRAY_SIZE = 20 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index fe28e0d7269e0..7c83f1cad7d71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{Predicate => CatalystPredicate import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer -object ParquetFilters { +private[sql] object ParquetFilters { val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter" // set this to false if pushdown should be disabled val PARQUET_FILTER_PUSHDOWN_ENABLED = "spark.sql.hints.parquetFilterPushdown" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 5f61fb5e16ea3..cde91ceb68c98 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -19,29 +19,30 @@ package org.apache.spark.sql.columnar import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.types._ class ColumnStatsSuite extends FunSuite { - testColumnStats(classOf[BooleanColumnStats], BOOLEAN) - testColumnStats(classOf[ByteColumnStats], BYTE) - testColumnStats(classOf[ShortColumnStats], SHORT) - testColumnStats(classOf[IntColumnStats], INT) - testColumnStats(classOf[LongColumnStats], LONG) - testColumnStats(classOf[FloatColumnStats], FLOAT) - testColumnStats(classOf[DoubleColumnStats], DOUBLE) - testColumnStats(classOf[StringColumnStats], STRING) - testColumnStats(classOf[TimestampColumnStats], TIMESTAMP) - - def testColumnStats[T <: NativeType, U <: NativeColumnStats[T]]( + testColumnStats(classOf[ByteColumnStats], BYTE, Row(Byte.MaxValue, Byte.MinValue, 0)) + testColumnStats(classOf[ShortColumnStats], SHORT, Row(Short.MaxValue, Short.MinValue, 0)) + testColumnStats(classOf[IntColumnStats], INT, Row(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[LongColumnStats], LONG, Row(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[FloatColumnStats], FLOAT, Row(Float.MaxValue, Float.MinValue, 0)) + testColumnStats(classOf[DoubleColumnStats], DOUBLE, Row(Double.MaxValue, Double.MinValue, 0)) + testColumnStats(classOf[StringColumnStats], STRING, Row(null, null, 0)) + testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, Row(null, null, 0)) + + def testColumnStats[T <: NativeType, U <: ColumnStats]( columnStatsClass: Class[U], - columnType: NativeColumnType[T]) { + columnType: NativeColumnType[T], + initialStatistics: Row) { val columnStatsName = columnStatsClass.getSimpleName test(s"$columnStatsName: empty") { val columnStats = columnStatsClass.newInstance() - assertResult(columnStats.initialBounds, "Wrong initial bounds") { - (columnStats.lowerBound, columnStats.upperBound) + columnStats.collectedStatistics.zip(initialStatistics).foreach { case (actual, expected) => + assert(actual === expected) } } @@ -49,14 +50,16 @@ class ColumnStatsSuite extends FunSuite { import ColumnarTestUtils._ val columnStats = columnStatsClass.newInstance() - val rows = Seq.fill(10)(makeRandomRow(columnType)) + val rows = Seq.fill(10)(makeRandomRow(columnType)) ++ Seq.fill(10)(makeNullRow(1)) rows.foreach(columnStats.gatherStats(_, 0)) - val values = rows.map(_.head.asInstanceOf[T#JvmType]) + val values = rows.take(10).map(_.head.asInstanceOf[T#JvmType]) val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]] + val stats = columnStats.collectedStatistics - assertResult(values.min(ordering), "Wrong lower bound")(columnStats.lowerBound) - assertResult(values.max(ordering), "Wrong upper bound")(columnStats.upperBound) + assertResult(values.min(ordering), "Wrong lower bound")(stats(0)) + assertResult(values.max(ordering), "Wrong upper bound")(stats(1)) + assertResult(10, "Wrong null count")(stats(2)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index dc813fe146c47..a77262534a352 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.execution.SparkSqlSerializer class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) - extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType) + extends BasicColumnBuilder[T, JvmType](new NoopColumnStats, columnType) with NullableColumnBuilder object TestNullableColumnBuilder { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala new file mode 100644 index 0000000000000..5d2fd4959197c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar + +import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} + +import org.apache.spark.sql._ +import org.apache.spark.sql.test.TestSQLContext._ + +case class IntegerData(i: Int) + +class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter { + val originalColumnBatchSize = columnBatchSize + val originalInMemoryPartitionPruning = inMemoryPartitionPruning + + override protected def beforeAll() { + // Make a table with 5 partitions, 2 batches per partition, 10 elements per batch + setConf(SQLConf.COLUMN_BATCH_SIZE, "10") + val rawData = sparkContext.makeRDD(1 to 100, 5).map(IntegerData) + rawData.registerTempTable("intData") + + // Enable in-memory partition pruning + setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") + } + + override protected def afterAll() { + setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) + setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) + } + + before { + cacheTable("intData") + } + + after { + uncacheTable("intData") + } + + // Comparisons + checkBatchPruning("i = 1", Seq(1), 1, 1) + checkBatchPruning("1 = i", Seq(1), 1, 1) + checkBatchPruning("i < 12", 1 to 11, 1, 2) + checkBatchPruning("i <= 11", 1 to 11, 1, 2) + checkBatchPruning("i > 88", 89 to 100, 1, 2) + checkBatchPruning("i >= 89", 89 to 100, 1, 2) + checkBatchPruning("12 > i", 1 to 11, 1, 2) + checkBatchPruning("11 >= i", 1 to 11, 1, 2) + checkBatchPruning("88 < i", 89 to 100, 1, 2) + checkBatchPruning("89 <= i", 89 to 100, 1, 2) + + // Conjunction and disjunction + checkBatchPruning("i > 8 AND i <= 21", 9 to 21, 2, 3) + checkBatchPruning("i < 2 OR i > 99", Seq(1, 100), 2, 2) + checkBatchPruning("i < 2 OR (i > 78 AND i < 92)", Seq(1) ++ (79 to 91), 3, 4) + + // With unsupported predicate + checkBatchPruning("i < 12 AND i IS NOT NULL", 1 to 11, 1, 2) + checkBatchPruning("NOT (i < 88)", 88 to 100, 5, 10) + + def checkBatchPruning( + filter: String, + expectedQueryResult: Seq[Int], + expectedReadPartitions: Int, + expectedReadBatches: Int) { + + test(filter) { + val query = sql(s"SELECT * FROM intData WHERE $filter") + assertResult(expectedQueryResult.toArray, "Wrong query result") { + query.collect().map(_.head).toArray + } + + val (readPartitions, readBatches) = query.queryExecution.executedPlan.collect { + case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value) + }.head + + assert(readBatches === expectedReadBatches, "Wrong number of read batches") + assert(readPartitions === expectedReadPartitions, "Wrong number of read partitions") + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala index 5fba00480967c..e01cc8b4d20f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.columnar.compression import org.scalatest.FunSuite import org.apache.spark.sql.Row -import org.apache.spark.sql.columnar.{BOOLEAN, BooleanColumnStats} +import org.apache.spark.sql.columnar.{NoopColumnStats, BOOLEAN} import org.apache.spark.sql.columnar.ColumnarTestUtils._ class BooleanBitSetSuite extends FunSuite { @@ -31,7 +31,7 @@ class BooleanBitSetSuite extends FunSuite { // Tests encoder // ------------- - val builder = TestCompressibleColumnBuilder(new BooleanColumnStats, BOOLEAN, BooleanBitSet) + val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet) val rows = Seq.fill[Row](count)(makeRandomRow(BOOLEAN)) val values = rows.map(_.head) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala index d8ae2a26778c9..d2969d906c943 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala @@ -31,7 +31,7 @@ class DictionaryEncodingSuite extends FunSuite { testDictionaryEncoding(new StringColumnStats, STRING) def testDictionaryEncoding[T <: NativeType]( - columnStats: NativeColumnStats[T], + columnStats: ColumnStats, columnType: NativeColumnType[T]) { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala index 17619dcf974e3..322f447c24840 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala @@ -29,7 +29,7 @@ class IntegralDeltaSuite extends FunSuite { testIntegralDelta(new LongColumnStats, LONG, LongDelta) def testIntegralDelta[I <: IntegralType]( - columnStats: NativeColumnStats[I], + columnStats: ColumnStats, columnType: NativeColumnType[I], scheme: IntegralDelta[I]) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala index 40115beb98899..218c09ac26362 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.columnar._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ class RunLengthEncodingSuite extends FunSuite { - testRunLengthEncoding(new BooleanColumnStats, BOOLEAN) + testRunLengthEncoding(new NoopColumnStats, BOOLEAN) testRunLengthEncoding(new ByteColumnStats, BYTE) testRunLengthEncoding(new ShortColumnStats, SHORT) testRunLengthEncoding(new IntColumnStats, INT) @@ -32,7 +32,7 @@ class RunLengthEncodingSuite extends FunSuite { testRunLengthEncoding(new StringColumnStats, STRING) def testRunLengthEncoding[T <: NativeType]( - columnStats: NativeColumnStats[T], + columnStats: ColumnStats, columnType: NativeColumnType[T]) { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala index 72c19fa31d980..7db723d648d80 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar._ class TestCompressibleColumnBuilder[T <: NativeType]( - override val columnStats: NativeColumnStats[T], + override val columnStats: ColumnStats, override val columnType: NativeColumnType[T], override val schemes: Seq[CompressionScheme]) extends NativeColumnBuilder(columnStats, columnType) @@ -33,7 +33,7 @@ class TestCompressibleColumnBuilder[T <: NativeType]( object TestCompressibleColumnBuilder { def apply[T <: NativeType]( - columnStats: NativeColumnStats[T], + columnStats: ColumnStats, columnType: NativeColumnType[T], scheme: CompressionScheme) = { diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index c6f60c18804a4..124fc107cb8aa 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index f12b5a69a09f7..bd3f68d92d8c7 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -39,7 +39,9 @@ import org.apache.spark.sql.hive.thriftserver.ReflectionUtils /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. */ -class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManager with Logging { +private[thriftserver] class SparkSQLOperationManager(hiveContext: HiveContext) + extends OperationManager with Logging { + val handleToOperation = ReflectionUtils .getSuperField[JMap[OperationHandle, Operation]](this, "handleToOperation") diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index b589994bd25fa..ab487d673e813 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -35,26 +35,29 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalTimeZone = TimeZone.getDefault private val originalLocale = Locale.getDefault - private val originalUseCompression = TestHive.useCompression + private val originalColumnBatchSize = TestHive.columnBatchSize + private val originalInMemoryPartitionPruning = TestHive.inMemoryPartitionPruning def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) override def beforeAll() { - // Enable in-memory columnar caching TestHive.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting Locale.setDefault(Locale.US) - // Enable in-memory columnar compression - TestHive.setConf(SQLConf.COMPRESS_CACHED, "true") + // Set a relatively small column batch size for testing purposes + TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, "5") + // Enable in-memory partition pruning for testing purposes + TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, "true") } override def afterAll() { TestHive.cacheTables = false TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) - TestHive.setConf(SQLConf.COMPRESS_CACHED, originalUseCompression.toString) + TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize.toString) + TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning.toString) } /** A list of tests deemed out of scope currently and thus completely disregarded. */ diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 30ff277e67c88..45a4c6dc98da0 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index d9b2bc7348ad2..ced8397972fbd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -389,7 +389,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_,_], MapType(kType, vType, _)) => + case (map: Map[_, _], MapType(kType, vType, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) @@ -409,7 +409,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { // be similar with Hive. describeHiveTableCommand.hiveString case command: PhysicalCommand => - command.sideEffectResult.map(_.toString) + command.sideEffectResult.map(_.head.toString) case other => val result: Seq[Seq[Any]] = toRdd.collect().toSeq diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index a4dd6be5f9e35..c98287c6aa662 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -44,6 +44,8 @@ private[hive] case class SourceCommand(filePath: String) extends Command private[hive] case class AddFile(filePath: String) extends Command +private[hive] case class AddJar(path: String) extends Command + private[hive] case class DropTable(tableName: String, ifExists: Boolean) extends Command private[hive] case class AnalyzeTable(tableName: String) extends Command @@ -231,7 +233,7 @@ private[hive] object HiveQl { } else if (sql.trim.toLowerCase.startsWith("uncache table")) { CacheCommand(sql.trim.drop(14).trim, false) } else if (sql.trim.toLowerCase.startsWith("add jar")) { - NativeCommand(sql) + AddJar(sql.trim.drop(8).trim) } else if (sql.trim.toLowerCase.startsWith("add file")) { AddFile(sql.trim.drop(9)) } else if (sql.trim.toLowerCase.startsWith("dfs")) { @@ -1018,9 +1020,9 @@ private[hive] object HiveQl { /* Other functions */ case Token("TOK_FUNCTION", Token(RAND(), Nil) :: Nil) => Rand - case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) => + case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: Nil) => Substring(nodeToExpr(string), nodeToExpr(pos), Literal(Integer.MAX_VALUE, IntegerType)) - case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) => + case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) => Substring(nodeToExpr(string), nodeToExpr(pos), nodeToExpr(length)) /* UDFs - Must be last otherwise will preempt built in functions */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 47e24f0dec146..72cc01cdf4c84 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -18,17 +18,19 @@ package org.apache.spark.sql.hive import org.apache.spark.annotation.Experimental -import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LowerCaseSchema} -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.hive.execution._ +import org.apache.spark.sql.catalyst.types.StringType import org.apache.spark.sql.columnar.InMemoryRelation -import org.apache.spark.sql.parquet.{ParquetRelation, ParquetTableScan} +import org.apache.spark.sql.execution.{DescribeCommand, OutputFaker, SparkPlan} +import org.apache.spark.sql.hive +import org.apache.spark.sql.hive.execution._ +import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.{SQLContext, SchemaRDD} import scala.collection.JavaConversions._ @@ -193,12 +195,13 @@ private[hive] trait HiveStrategies { case class HiveCommandStrategy(context: HiveContext) extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.NativeCommand(sql) => - NativeCommand(sql, plan.output)(context) :: Nil + case logical.NativeCommand(sql) => NativeCommand(sql, plan.output)(context) :: Nil + + case hive.DropTable(tableName, ifExists) => execution.DropTable(tableName, ifExists) :: Nil - case DropTable(tableName, ifExists) => execution.DropTable(tableName, ifExists) :: Nil + case hive.AddJar(path) => execution.AddJar(path) :: Nil - case AnalyzeTable(tableName) => execution.AnalyzeTable(tableName) :: Nil + case hive.AnalyzeTable(tableName) => execution.AnalyzeTable(tableName) :: Nil case describe: logical.DescribeCommand => val resolvedTable = context.executePlan(describe.table).analyzed diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala index a40e89e0d382b..317801001c7a4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala @@ -23,7 +23,7 @@ import org.apache.hadoop.hive.metastore.api.FieldSchema import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericRow, Row} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Row} import org.apache.spark.sql.execution.{Command, LeafNode} import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation} @@ -41,26 +41,21 @@ case class DescribeHiveTableCommand( extends LeafNode with Command { // Strings with the format like Hive. It is used for result comparison in our unit tests. - lazy val hiveString: Seq[String] = { - val alignment = 20 - val delim = "\t" - - sideEffectResult.map { - case (name, dataType, comment) => - String.format("%-" + alignment + "s", name) + delim + - String.format("%-" + alignment + "s", dataType) + delim + - String.format("%-" + alignment + "s", Option(comment).getOrElse("None")) - } + lazy val hiveString: Seq[String] = sideEffectResult.map { + case Row(name: String, dataType: String, comment) => + Seq(name, dataType, Option(comment.asInstanceOf[String]).getOrElse("None")) + .map(s => String.format(s"%-20s", s)) + .mkString("\t") } - override protected[sql] lazy val sideEffectResult: Seq[(String, String, String)] = { + override protected[sql] lazy val sideEffectResult: Seq[Row] = { // Trying to mimic the format of Hive's output. But not exactly the same. var results: Seq[(String, String, String)] = Nil val columns: Seq[FieldSchema] = table.hiveQlTable.getCols val partitionColumns: Seq[FieldSchema] = table.hiveQlTable.getPartCols results ++= columns.map(field => (field.getName, field.getType, field.getComment)) - if (!partitionColumns.isEmpty) { + if (partitionColumns.nonEmpty) { val partColumnInfo = partitionColumns.map(field => (field.getName, field.getType, field.getComment)) results ++= @@ -74,14 +69,9 @@ case class DescribeHiveTableCommand( results ++= Seq(("Detailed Table Information", table.hiveQlTable.getTTable.toString, "")) } - results - } - - override def execute(): RDD[Row] = { - val rows = sideEffectResult.map { - case (name, dataType, comment) => new GenericRow(Array[Any](name, dataType, comment)) + results.map { case (name, dataType, comment) => + Row(name, dataType, comment) } - context.sparkContext.parallelize(rows, 1) } override def otherCopyArgs = context :: Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala index fe6031678f70f..8f10e1ba7f426 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/NativeCommand.scala @@ -32,16 +32,7 @@ case class NativeCommand( @transient context: HiveContext) extends LeafNode with Command { - override protected[sql] lazy val sideEffectResult: Seq[String] = context.runSqlHive(sql) - - override def execute(): RDD[Row] = { - if (sideEffectResult.size == 0) { - context.emptyResult - } else { - val rows = sideEffectResult.map(r => new GenericRow(Array[Any](r))) - context.sparkContext.parallelize(rows, 1) - } - } + override protected[sql] lazy val sideEffectResult: Seq[Row] = context.runSqlHive(sql).map(Row(_)) override def otherCopyArgs = context :: Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 2985169da033c..d61c5e274a596 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -33,19 +33,13 @@ import org.apache.spark.sql.hive.HiveContext */ @DeveloperApi case class AnalyzeTable(tableName: String) extends LeafNode with Command { - def hiveContext = sqlContext.asInstanceOf[HiveContext] def output = Seq.empty - override protected[sql] lazy val sideEffectResult = { + override protected[sql] lazy val sideEffectResult: Seq[Row] = { hiveContext.analyze(tableName) - Seq.empty[Any] - } - - override def execute(): RDD[Row] = { - sideEffectResult - sparkContext.emptyRDD[Row] + Seq.empty[Row] } } @@ -55,20 +49,30 @@ case class AnalyzeTable(tableName: String) extends LeafNode with Command { */ @DeveloperApi case class DropTable(tableName: String, ifExists: Boolean) extends LeafNode with Command { - def hiveContext = sqlContext.asInstanceOf[HiveContext] def output = Seq.empty - override protected[sql] lazy val sideEffectResult: Seq[Any] = { + override protected[sql] lazy val sideEffectResult: Seq[Row] = { val ifExistsClause = if (ifExists) "IF EXISTS " else "" hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") hiveContext.catalog.unregisterTable(None, tableName) - Seq.empty + Seq.empty[Row] } +} + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class AddJar(path: String) extends LeafNode with Command { + def hiveContext = sqlContext.asInstanceOf[HiveContext] + + override def output = Seq.empty - override def execute(): RDD[Row] = { - sideEffectResult - sparkContext.emptyRDD[Row] + override protected[sql] lazy val sideEffectResult: Seq[Row] = { + hiveContext.runSqlHive(s"ADD JAR $path") + hiveContext.sparkContext.addJar(path) + Seq.empty[Row] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala index 544abfc32423c..abed299cd957f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector import org.apache.hadoop.io.Writable /** - * A placeholder that allows SparkSQL users to create metastore tables that are stored as + * A placeholder that allows Spark SQL users to create metastore tables that are stored as * parquet files. It is only intended to pass the checks that the serde is valid and exists * when a CREATE TABLE is run. The actual work of decoding will be done by ParquetTableScan * when "spark.sql.hive.convertMetastoreParquet" is set to true. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index c4abb3eb4861f..f4217a52c3822 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -17,9 +17,11 @@ package org.apache.spark.sql.hive.execution +import java.io.File + import scala.util.Try -import org.apache.spark.sql.{SchemaRDD, Row} +import org.apache.spark.SparkException import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ @@ -313,7 +315,7 @@ class HiveQuerySuite extends HiveComparisonTest { "SELECT srcalias.KEY, SRCALIAS.value FROM sRc SrCAlias WHERE SrCAlias.kEy < 15") test("case sensitivity: registered table") { - val testData: SchemaRDD = + val testData = TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(2, "str2") :: Nil) @@ -467,7 +469,7 @@ class HiveQuerySuite extends HiveComparisonTest { } // Describe a registered temporary table. - val testData: SchemaRDD = + val testData = TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(1, "str2") :: Nil) @@ -495,6 +497,23 @@ class HiveQuerySuite extends HiveComparisonTest { } } + test("ADD JAR command") { + val testJar = TestHive.getHiveFile("data/files/TestSerDe.jar").getCanonicalPath + sql("CREATE TABLE alter1(a INT, b INT)") + intercept[Exception] { + sql( + """ALTER TABLE alter1 SET SERDE 'org.apache.hadoop.hive.serde2.TestSerDe' + |WITH serdeproperties('s1'='9') + """.stripMargin) + } + sql(s"ADD JAR $testJar") + sql( + """ALTER TABLE alter1 SET SERDE 'org.apache.hadoop.hive.serde2.TestSerDe' + |WITH serdeproperties('s1'='9') + """.stripMargin) + sql("DROP TABLE alter1") + } + test("parse HQL set commands") { // Adapted from its SQL counterpart. val testKey = "spark.sql.key.usedfortestonly" diff --git a/streaming/pom.xml b/streaming/pom.xml index ce35520a28609..12f900c91eb98 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index 97abb6b2b63e0..f36674476770c 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml diff --git a/yarn/alpha/pom.xml b/yarn/alpha/pom.xml index 51744ece0412d..7dadbba58fd82 100644 --- a/yarn/alpha/pom.xml +++ b/yarn/alpha/pom.xml @@ -20,7 +20,7 @@ org.apache.spark yarn-parent_2.10 - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 7dae248e3e7db..10cbeb8b94325 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -35,7 +35,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records, ProtoUtils} -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{SecurityManager, SparkConf, Logging} class ExecutorRunnable( @@ -46,7 +46,8 @@ class ExecutorRunnable( slaveId: String, hostname: String, executorMemory: Int, - executorCores: Int) + executorCores: Int, + securityMgr: SecurityManager) extends Runnable with ExecutorRunnableUtil with Logging { var rpc: YarnRPC = YarnRPC.create(conf) @@ -86,6 +87,8 @@ class ExecutorRunnable( logInfo("Setting up executor with commands: " + commands) ctx.setCommands(commands) + ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)) + // Send the start request to the ContainerManager val startReq = Records.newRecord(classOf[StartContainerRequest]) .asInstanceOf[StartContainerRequest] diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index 629cd13f67145..5a1b42c1e17d5 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -17,35 +17,21 @@ package org.apache.spark.deploy.yarn -import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap} +import java.util.concurrent.CopyOnWriteArrayList import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.mutable.{ArrayBuffer, HashMap} -import org.apache.spark.{Logging, SparkConf, SparkEnv} -import org.apache.spark.scheduler.{SplitInfo, TaskSchedulerImpl} -import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.Utils +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.scheduler.SplitInfo import org.apache.hadoop.conf.Configuration import org.apache.hadoop.yarn.api.AMRMProtocol -import org.apache.hadoop.yarn.api.records.{AMResponse, ApplicationAttemptId} -import org.apache.hadoop.yarn.api.records.{Container, ContainerId} -import org.apache.hadoop.yarn.api.records.{Priority, Resource, ResourceRequest} -import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse} +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.protocolrecords.AllocateRequest import org.apache.hadoop.yarn.util.Records -// TODO: -// Too many params. -// Needs to be mt-safe -// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive - should -// make it more proactive and decoupled. - -// Note that right now, we assume all node asks as uniform in terms of capabilities and priority -// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for -// more info on how we are requesting for containers. - /** * Acquires resources for executors from a ResourceManager and launches executors in new containers. */ @@ -55,358 +41,23 @@ private[yarn] class YarnAllocationHandler( resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId, args: ApplicationMasterArguments, - preferredNodes: collection.Map[String, collection.Set[SplitInfo]]) - extends YarnAllocator with Logging { - - // These three are locked on allocatedHostToContainersMap. Complementary data structures - // allocatedHostToContainersMap : containers which are running : host, Set - // allocatedContainerToHostMap: container to host mapping. - private val allocatedHostToContainersMap = - new HashMap[String, collection.mutable.Set[ContainerId]]() - - private val allocatedContainerToHostMap = new HashMap[ContainerId, String]() - - // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an - // allocated node) - // As with the two data structures above, tightly coupled with them, and to be locked on - // allocatedHostToContainersMap - private val allocatedRackCount = new HashMap[String, Int]() - - // Containers which have been released. - private val releasedContainerList = new CopyOnWriteArrayList[ContainerId]() - // Containers to be released in next request to RM - private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean] - - // Additional memory overhead - in mb. - private def memoryOverhead: Int = sparkConf.getInt("spark.yarn.executor.memoryOverhead", - YarnSparkHadoopUtil.DEFAULT_MEMORY_OVERHEAD) - - private val numExecutorsRunning = new AtomicInteger() - // Used to generate a unique id per executor - private val executorIdCounter = new AtomicInteger() - private val lastResponseId = new AtomicInteger() - private val numExecutorsFailed = new AtomicInteger() - - private val maxExecutors = args.numExecutors - private val executorMemory = args.executorMemory - private val executorCores = args.executorCores - private val (preferredHostToCount, preferredRackToCount) = - generateNodeToWeight(conf, preferredNodes) - - def getNumExecutorsRunning: Int = numExecutorsRunning.intValue - - def getNumExecutorsFailed: Int = numExecutorsFailed.intValue - - def isResourceConstraintSatisfied(container: Container): Boolean = { - container.getResource.getMemory >= (executorMemory + memoryOverhead) - } - - override def allocateResources() = { - // We need to send the request only once from what I understand ... but for now, not modifying - // this much. - val executorsToRequest = Math.max(maxExecutors - numExecutorsRunning.get(), 0) - - // Keep polling the Resource Manager for containers - val amResp = allocateExecutorResources(executorsToRequest).getAMResponse - - val _allocatedContainers = amResp.getAllocatedContainers() - - if (_allocatedContainers.size > 0) { - logDebug(""" - Allocated containers: %d - Current executor count: %d - Containers released: %s - Containers to be released: %s - Cluster resources: %s - """.format( - _allocatedContainers.size, - numExecutorsRunning.get(), - releasedContainerList, - pendingReleaseContainers, - amResp.getAvailableResources)) - - val hostToContainers = new HashMap[String, ArrayBuffer[Container]]() - - // Ignore if not satisfying constraints { - for (container <- _allocatedContainers) { - if (isResourceConstraintSatisfied(container)) { - // allocatedContainers += container - - val host = container.getNodeId.getHost - val containers = hostToContainers.getOrElseUpdate(host, new ArrayBuffer[Container]()) - - containers += container - } else { - // Add all ignored containers to released list - releasedContainerList.add(container.getId()) - } - } - - // Find the appropriate containers to use. Slightly non trivial groupBy ... - val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]() - val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]() - val offRackContainers = new HashMap[String, ArrayBuffer[Container]]() - - for (candidateHost <- hostToContainers.keySet) - { - val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0) - val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost) - - var remainingContainers = hostToContainers.get(candidateHost).getOrElse(null) - assert(remainingContainers != null) - - if (requiredHostCount >= remainingContainers.size){ - // Since we got <= required containers, add all to dataLocalContainers - dataLocalContainers.put(candidateHost, remainingContainers) - // all consumed - remainingContainers = null - } else if (requiredHostCount > 0) { - // Container list has more containers than we need for data locality. - // Split into two : data local container count of (remainingContainers.size - - // requiredHostCount) and rest as remainingContainer - val (dataLocal, remaining) = remainingContainers.splitAt( - remainingContainers.size - requiredHostCount) - dataLocalContainers.put(candidateHost, dataLocal) - // remainingContainers = remaining - - // yarn has nasty habit of allocating a tonne of containers on a host - discourage this : - // add remaining to release list. If we have insufficient containers, next allocation - // cycle will reallocate (but wont treat it as data local) - for (container <- remaining) releasedContainerList.add(container.getId()) - remainingContainers = null - } - - // Now rack local - if (remainingContainers != null){ - val rack = YarnSparkHadoopUtil.lookupRack(conf, candidateHost) - - if (rack != null){ - val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0) - val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) - - rackLocalContainers.get(rack).getOrElse(List()).size - - - if (requiredRackCount >= remainingContainers.size){ - // Add all to dataLocalContainers - dataLocalContainers.put(rack, remainingContainers) - // All consumed - remainingContainers = null - } else if (requiredRackCount > 0) { - // container list has more containers than we need for data locality. - // Split into two : data local container count of (remainingContainers.size - - // requiredRackCount) and rest as remainingContainer - val (rackLocal, remaining) = remainingContainers.splitAt( - remainingContainers.size - requiredRackCount) - val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, - new ArrayBuffer[Container]()) - - existingRackLocal ++= rackLocal - remainingContainers = remaining - } - } - } - - // If still not consumed, then it is off rack host - add to that list. - if (remainingContainers != null){ - offRackContainers.put(candidateHost, remainingContainers) - } - } - - // Now that we have split the containers into various groups, go through them in order : - // first host local, then rack local and then off rack (everything else). - // Note that the list we create below tries to ensure that not all containers end up within a - // host if there are sufficiently large number of hosts/containers. - - val allocatedContainers = new ArrayBuffer[Container](_allocatedContainers.size) - allocatedContainers ++= TaskSchedulerImpl.prioritizeContainers(dataLocalContainers) - allocatedContainers ++= TaskSchedulerImpl.prioritizeContainers(rackLocalContainers) - allocatedContainers ++= TaskSchedulerImpl.prioritizeContainers(offRackContainers) - - // Run each of the allocated containers - for (container <- allocatedContainers) { - val numExecutorsRunningNow = numExecutorsRunning.incrementAndGet() - val executorHostname = container.getNodeId.getHost - val containerId = container.getId - - assert( container.getResource.getMemory >= - (executorMemory + memoryOverhead)) - - if (numExecutorsRunningNow > maxExecutors) { - logInfo("""Ignoring container %s at host %s, since we already have the required number of - containers for it.""".format(containerId, executorHostname)) - releasedContainerList.add(containerId) - // reset counter back to old value. - numExecutorsRunning.decrementAndGet() - } else { - // Deallocate + allocate can result in reusing id's wrongly - so use a different counter - // (executorIdCounter) - val executorId = executorIdCounter.incrementAndGet().toString - val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format( - SparkEnv.driverActorSystemName, - sparkConf.get("spark.driver.host"), - sparkConf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) - - logInfo("launching container on " + containerId + " host " + executorHostname) - // Just to be safe, simply remove it from pendingReleaseContainers. - // Should not be there, but .. - pendingReleaseContainers.remove(containerId) - - val rack = YarnSparkHadoopUtil.lookupRack(conf, executorHostname) - allocatedHostToContainersMap.synchronized { - val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname, - new HashSet[ContainerId]()) - - containerSet += containerId - allocatedContainerToHostMap.put(containerId, executorHostname) - if (rack != null) { - allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1) - } - } - - new Thread( - new ExecutorRunnable(container, conf, sparkConf, driverUrl, executorId, - executorHostname, executorMemory, executorCores) - ).start() - } - } - logDebug(""" - Finished processing %d containers. - Current number of executors running: %d, - releasedContainerList: %s, - pendingReleaseContainers: %s - """.format( - allocatedContainers.size, - numExecutorsRunning.get(), - releasedContainerList, - pendingReleaseContainers)) - } - - - val completedContainers = amResp.getCompletedContainersStatuses() - if (completedContainers.size > 0){ - logDebug("Completed %d containers, to-be-released: %s".format( - completedContainers.size, releasedContainerList)) - for (completedContainer <- completedContainers){ - val containerId = completedContainer.getContainerId - - // Was this released by us ? If yes, then simply remove from containerSet and move on. - if (pendingReleaseContainers.containsKey(containerId)) { - pendingReleaseContainers.remove(containerId) - } else { - // Simply decrement count - next iteration of ReporterThread will take care of allocating. - numExecutorsRunning.decrementAndGet() - logInfo("Completed container %s (state: %s, exit status: %s)".format( - containerId, - completedContainer.getState, - completedContainer.getExitStatus())) - // Hadoop 2.2.X added a ContainerExitStatus we should switch to use - // there are some exit status' we shouldn't necessarily count against us, but for - // now I think its ok as none of the containers are expected to exit - if (completedContainer.getExitStatus() != 0) { - logInfo("Container marked as failed: " + containerId) - numExecutorsFailed.incrementAndGet() - } - } - - allocatedHostToContainersMap.synchronized { - if (allocatedContainerToHostMap.containsKey(containerId)) { - val host = allocatedContainerToHostMap.get(containerId).getOrElse(null) - assert (host != null) - - val containerSet = allocatedHostToContainersMap.get(host).getOrElse(null) - assert (containerSet != null) - - containerSet -= containerId - if (containerSet.isEmpty) { - allocatedHostToContainersMap.remove(host) - } else { - allocatedHostToContainersMap.update(host, containerSet) - } - - allocatedContainerToHostMap -= containerId - - // Doing this within locked context, sigh ... move to outside ? - val rack = YarnSparkHadoopUtil.lookupRack(conf, host) - if (rack != null) { - val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1 - if (rackCount > 0) { - allocatedRackCount.put(rack, rackCount) - } else { - allocatedRackCount.remove(rack) - } - } - } - } - } - logDebug(""" - Finished processing %d completed containers. - Current number of executors running: %d, - releasedContainerList: %s, - pendingReleaseContainers: %s - """.format( - completedContainers.size, - numExecutorsRunning.get(), - releasedContainerList, - pendingReleaseContainers)) - } - } - - def createRackResourceRequests(hostContainers: List[ResourceRequest]): List[ResourceRequest] = { - // First generate modified racks and new set of hosts under it : then issue requests - val rackToCounts = new HashMap[String, Int]() - - // Within this lock - used to read/write to the rack related maps too. - for (container <- hostContainers) { - val candidateHost = container.getHostName - val candidateNumContainers = container.getNumContainers - assert(YarnSparkHadoopUtil.ANY_HOST != candidateHost) - - val rack = YarnSparkHadoopUtil.lookupRack(conf, candidateHost) - if (rack != null) { - var count = rackToCounts.getOrElse(rack, 0) - count += candidateNumContainers - rackToCounts.put(rack, count) - } - } - - val requestedContainers: ArrayBuffer[ResourceRequest] = - new ArrayBuffer[ResourceRequest](rackToCounts.size) - for ((rack, count) <- rackToCounts){ - requestedContainers += - createResourceRequest(AllocationType.RACK, rack, count, - YarnSparkHadoopUtil.RM_REQUEST_PRIORITY) - } - - requestedContainers.toList - } - - def allocatedContainersOnHost(host: String): Int = { - var retval = 0 - allocatedHostToContainersMap.synchronized { - retval = allocatedHostToContainersMap.getOrElse(host, Set()).size - } - retval - } + preferredNodes: collection.Map[String, collection.Set[SplitInfo]], + securityMgr: SecurityManager) + extends YarnAllocator(conf, sparkConf, args, preferredNodes, securityMgr) { - def allocatedContainersOnRack(rack: String): Int = { - var retval = 0 - allocatedHostToContainersMap.synchronized { - retval = allocatedRackCount.getOrElse(rack, 0) - } - retval - } - - private def allocateExecutorResources(numExecutors: Int): AllocateResponse = { + private val lastResponseId = new AtomicInteger() + private val releaseList: CopyOnWriteArrayList[ContainerId] = new CopyOnWriteArrayList() + override protected def allocateContainers(count: Int): YarnAllocateResponse = { var resourceRequests: List[ResourceRequest] = null - // default. - if (numExecutors <= 0 || preferredHostToCount.isEmpty) { - logDebug("numExecutors: " + numExecutors + ", host preferences: " + - preferredHostToCount.isEmpty) - resourceRequests = List(createResourceRequest( - AllocationType.ANY, null, numExecutors, YarnSparkHadoopUtil.RM_REQUEST_PRIORITY)) + logDebug("numExecutors: " + count) + if (count <= 0) { + resourceRequests = List() + } else if (preferredHostToCount.isEmpty) { + logDebug("host preferences is empty") + resourceRequests = List(createResourceRequest( + AllocationType.ANY, null, count, YarnSparkHadoopUtil.RM_REQUEST_PRIORITY)) } else { // request for all hosts in preferred nodes and for numExecutors - // candidates.size, request by default allocation policy. @@ -429,7 +80,7 @@ private[yarn] class YarnAllocationHandler( val anyContainerRequests: ResourceRequest = createResourceRequest( AllocationType.ANY, resource = null, - numExecutors, + count, YarnSparkHadoopUtil.RM_REQUEST_PRIORITY) val containerRequests: ArrayBuffer[ResourceRequest] = new ArrayBuffer[ResourceRequest]( @@ -451,8 +102,8 @@ private[yarn] class YarnAllocationHandler( val releasedContainerList = createReleasedContainerList() req.addAllReleases(releasedContainerList) - if (numExecutors > 0) { - logInfo("Allocating %d executor containers with %d of memory each.".format(numExecutors, + if (count > 0) { + logInfo("Allocating %d executor containers with %d of memory each.".format(count, executorMemory + memoryOverhead)) } else { logDebug("Empty allocation req .. release : " + releasedContainerList) @@ -466,9 +117,42 @@ private[yarn] class YarnAllocationHandler( request.getPriority, request.getCapability)) } - resourceManager.allocate(req) + new AlphaAllocateResponse(resourceManager.allocate(req).getAMResponse()) } + override protected def releaseContainer(container: Container) = { + releaseList.add(container.getId()) + } + + private def createRackResourceRequests(hostContainers: List[ResourceRequest]): + List[ResourceRequest] = { + // First generate modified racks and new set of hosts under it : then issue requests + val rackToCounts = new HashMap[String, Int]() + + // Within this lock - used to read/write to the rack related maps too. + for (container <- hostContainers) { + val candidateHost = container.getHostName + val candidateNumContainers = container.getNumContainers + assert(YarnSparkHadoopUtil.ANY_HOST != candidateHost) + + val rack = YarnSparkHadoopUtil.lookupRack(conf, candidateHost) + if (rack != null) { + var count = rackToCounts.getOrElse(rack, 0) + count += candidateNumContainers + rackToCounts.put(rack, count) + } + } + + val requestedContainers: ArrayBuffer[ResourceRequest] = + new ArrayBuffer[ResourceRequest](rackToCounts.size) + for ((rack, count) <- rackToCounts){ + requestedContainers += + createResourceRequest(AllocationType.RACK, rack, count, + YarnSparkHadoopUtil.RM_REQUEST_PRIORITY) + } + + requestedContainers.toList + } private def createResourceRequest( requestType: AllocationType.AllocationType, @@ -521,48 +205,24 @@ private[yarn] class YarnAllocationHandler( rsrcRequest } - def createReleasedContainerList(): ArrayBuffer[ContainerId] = { - + private def createReleasedContainerList(): ArrayBuffer[ContainerId] = { val retval = new ArrayBuffer[ContainerId](1) // Iterator on COW list ... - for (container <- releasedContainerList.iterator()){ + for (container <- releaseList.iterator()){ retval += container } // Remove from the original list. - if (! retval.isEmpty) { - releasedContainerList.removeAll(retval) - for (v <- retval) pendingReleaseContainers.put(v, true) - logInfo("Releasing " + retval.size + " containers. pendingReleaseContainers : " + - pendingReleaseContainers) + if (!retval.isEmpty) { + releaseList.removeAll(retval) + logInfo("Releasing " + retval.size + " containers.") } - retval } - // A simple method to copy the split info map. - private def generateNodeToWeight( - conf: Configuration, - input: collection.Map[String, collection.Set[SplitInfo]]) : - // host to count, rack to count - (Map[String, Int], Map[String, Int]) = { - - if (input == null) return (Map[String, Int](), Map[String, Int]()) - - val hostToCount = new HashMap[String, Int] - val rackToCount = new HashMap[String, Int] - - for ((host, splits) <- input) { - val hostCount = hostToCount.getOrElse(host, 0) - hostToCount.put(host, hostCount + splits.size) - - val rack = YarnSparkHadoopUtil.lookupRack(conf, host) - if (rack != null){ - val rackCount = rackToCount.getOrElse(host, 0) - rackToCount.put(host, rackCount + splits.size) - } - } - - (hostToCount.toMap, rackToCount.toMap) + private class AlphaAllocateResponse(response: AMResponse) extends YarnAllocateResponse { + override def getAllocatedContainers() = response.getAllocatedContainers() + override def getAvailableResources() = response.getAvailableResources() + override def getCompletedContainersStatuses() = response.getCompletedContainersStatuses() } } diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala index cc5392192ec51..ad27a9ab781d2 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{ConverterUtils, Records} -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.scheduler.SplitInfo import org.apache.spark.util.Utils @@ -45,7 +45,8 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC sparkConf: SparkConf, preferredNodeLocations: Map[String, Set[SplitInfo]], uiAddress: String, - uiHistoryAddress: String) = { + uiHistoryAddress: String, + securityMgr: SecurityManager) = { this.rpc = YarnRPC.create(conf) this.uiHistoryAddress = uiHistoryAddress @@ -53,7 +54,7 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC registerApplicationMaster(uiAddress) new YarnAllocationHandler(conf, sparkConf, resourceManager, getAttemptId(), args, - preferredNodeLocations) + preferredNodeLocations, securityMgr) } override def getAttemptId() = { diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 8c548409719da..a879c833a014f 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -35,6 +35,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.AddWebUIFilter import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} @@ -70,6 +71,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, private val sparkContextRef = new AtomicReference[SparkContext](null) final def run(): Int = { + val appAttemptId = client.getAttemptId() + if (isDriver) { // Set the web ui port to be ephemeral for yarn so we don't conflict with // other spark processes running on the same box @@ -77,9 +80,12 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, // Set the master property to match the requested mode. System.setProperty("spark.master", "yarn-cluster") + + // Propagate the application ID so that YarnClusterSchedulerBackend can pick it up. + System.setProperty("spark.yarn.app.id", appAttemptId.getApplicationId().toString()) } - logInfo("ApplicationAttemptId: " + client.getAttemptId()) + logInfo("ApplicationAttemptId: " + appAttemptId) val cleanupHook = new Runnable { override def run() { @@ -110,7 +116,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, val securityMgr = new SecurityManager(sparkConf) if (isDriver) { - runDriver() + runDriver(securityMgr) } else { runExecutorLauncher(securityMgr) } @@ -151,19 +157,27 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, sparkContextRef.compareAndSet(sc, null) } - private def registerAM(uiAddress: String, uiHistoryAddress: String) = { + private def registerAM(uiAddress: String, securityMgr: SecurityManager) = { val sc = sparkContextRef.get() + + val appId = client.getAttemptId().getApplicationId().toString() + val historyAddress = + sparkConf.getOption("spark.yarn.historyServer.address") + .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}" } + .getOrElse("") + allocator = client.register(yarnConf, if (sc != null) sc.getConf else sparkConf, if (sc != null) sc.preferredNodeLocationData else Map(), uiAddress, - uiHistoryAddress) + historyAddress, + securityMgr) allocator.allocateResources() reporterThread = launchReporterThread() } - private def runDriver(): Unit = { + private def runDriver(securityMgr: SecurityManager): Unit = { addAmIpFilter() val userThread = startUserClass() @@ -175,7 +189,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, if (sc == null) { finish(FinalApplicationStatus.FAILED, "Timed out waiting for SparkContext.") } else { - registerAM(sc.ui.appUIHostPort, YarnSparkHadoopUtil.getUIHistoryAddress(sc, sparkConf)) + registerAM(sc.ui.appUIHostPort, securityMgr) try { userThread.join() } finally { @@ -190,8 +204,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, conf = sparkConf, securityManager = securityMgr)._1 actor = waitForSparkDriver() addAmIpFilter() - registerAM(sparkConf.get("spark.driver.appUIAddress", ""), - sparkConf.get("spark.driver.appUIHistoryAddress", "")) + registerAM(sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) // In client mode the actor will stop the reporter thread. reporterThread.join() diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 5d8e5e6dffe7f..8075b7a7fb837 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -430,10 +430,8 @@ trait ClientBase extends Logging { // send the acl settings into YARN to control who has access via YARN interfaces val securityManager = new SecurityManager(sparkConf) - val acls = Map[ApplicationAccessType, String] ( - ApplicationAccessType.VIEW_APP -> securityManager.getViewAcls, - ApplicationAccessType.MODIFY_APP -> securityManager.getModifyAcls) - amContainer.setApplicationACLs(acls) + amContainer.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityManager)) + amContainer } } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index cad94e5e19e1f..02b9a81bf6b50 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -17,18 +17,433 @@ package org.apache.spark.deploy.yarn +import java.util.{List => JList} +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.protocolrecords.AllocateResponse + +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv} +import org.apache.spark.scheduler.{SplitInfo, TaskSchedulerImpl} +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend + object AllocationType extends Enumeration { type AllocationType = Value val HOST, RACK, ANY = Value } +// TODO: +// Too many params. +// Needs to be mt-safe +// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive - should +// make it more proactive and decoupled. + +// Note that right now, we assume all node asks as uniform in terms of capabilities and priority +// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for +// more info on how we are requesting for containers. + /** - * Interface that defines a Yarn allocator. + * Common code for the Yarn container allocator. Contains all the version-agnostic code to + * manage container allocation for a running Spark application. */ -trait YarnAllocator { +private[yarn] abstract class YarnAllocator( + conf: Configuration, + sparkConf: SparkConf, + args: ApplicationMasterArguments, + preferredNodes: collection.Map[String, collection.Set[SplitInfo]], + securityMgr: SecurityManager) + extends Logging { + + // These three are locked on allocatedHostToContainersMap. Complementary data structures + // allocatedHostToContainersMap : containers which are running : host, Set + // allocatedContainerToHostMap: container to host mapping. + private val allocatedHostToContainersMap = + new HashMap[String, collection.mutable.Set[ContainerId]]() + + private val allocatedContainerToHostMap = new HashMap[ContainerId, String]() + + // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an + // allocated node) + // As with the two data structures above, tightly coupled with them, and to be locked on + // allocatedHostToContainersMap + private val allocatedRackCount = new HashMap[String, Int]() + + // Containers to be released in next request to RM + private val releasedContainers = new ConcurrentHashMap[ContainerId, Boolean] + + // Additional memory overhead - in mb. + protected val memoryOverhead: Int = sparkConf.getInt("spark.yarn.executor.memoryOverhead", + YarnSparkHadoopUtil.DEFAULT_MEMORY_OVERHEAD) + + // Number of container requests that have been sent to, but not yet allocated by the + // ApplicationMaster. + private val numPendingAllocate = new AtomicInteger() + private val numExecutorsRunning = new AtomicInteger() + // Used to generate a unique id per executor + private val executorIdCounter = new AtomicInteger() + private val numExecutorsFailed = new AtomicInteger() + + private val maxExecutors = args.numExecutors + + protected val executorMemory = args.executorMemory + protected val executorCores = args.executorCores + protected val (preferredHostToCount, preferredRackToCount) = + generateNodeToWeight(conf, preferredNodes) + + def getNumExecutorsRunning: Int = numExecutorsRunning.intValue + + def getNumExecutorsFailed: Int = numExecutorsFailed.intValue + + def allocateResources() = { + val missing = maxExecutors - numPendingAllocate.get() - numExecutorsRunning.get() + + if (missing > 0) { + numPendingAllocate.addAndGet(missing) + logInfo("Will Allocate %d executor containers, each with %d memory".format( + missing, + (executorMemory + memoryOverhead))) + } else { + logDebug("Empty allocation request ...") + } + + val allocateResponse = allocateContainers(missing) + val allocatedContainers = allocateResponse.getAllocatedContainers() + + if (allocatedContainers.size > 0) { + var numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * allocatedContainers.size) + + if (numPendingAllocateNow < 0) { + numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * numPendingAllocateNow) + } + + logDebug(""" + Allocated containers: %d + Current executor count: %d + Containers released: %s + Cluster resources: %s + """.format( + allocatedContainers.size, + numExecutorsRunning.get(), + releasedContainers, + allocateResponse.getAvailableResources)) + + val hostToContainers = new HashMap[String, ArrayBuffer[Container]]() + + for (container <- allocatedContainers) { + if (isResourceConstraintSatisfied(container)) { + // Add the accepted `container` to the host's list of already accepted, + // allocated containers + val host = container.getNodeId.getHost + val containersForHost = hostToContainers.getOrElseUpdate(host, + new ArrayBuffer[Container]()) + containersForHost += container + } else { + // Release container, since it doesn't satisfy resource constraints. + internalReleaseContainer(container) + } + } + + // Find the appropriate containers to use. + // TODO: Cleanup this group-by... + val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]() + val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]() + val offRackContainers = new HashMap[String, ArrayBuffer[Container]]() + + for (candidateHost <- hostToContainers.keySet) { + val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0) + val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost) + + val remainingContainersOpt = hostToContainers.get(candidateHost) + assert(remainingContainersOpt.isDefined) + var remainingContainers = remainingContainersOpt.get + + if (requiredHostCount >= remainingContainers.size) { + // Since we have <= required containers, add all remaining containers to + // `dataLocalContainers`. + dataLocalContainers.put(candidateHost, remainingContainers) + // There are no more free containers remaining. + remainingContainers = null + } else if (requiredHostCount > 0) { + // Container list has more containers than we need for data locality. + // Split the list into two: one based on the data local container count, + // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining + // containers. + val (dataLocal, remaining) = remainingContainers.splitAt( + remainingContainers.size - requiredHostCount) + dataLocalContainers.put(candidateHost, dataLocal) + + // Invariant: remainingContainers == remaining + + // YARN has a nasty habit of allocating a ton of containers on a host - discourage this. + // Add each container in `remaining` to list of containers to release. If we have an + // insufficient number of containers, then the next allocation cycle will reallocate + // (but won't treat it as data local). + // TODO(harvey): Rephrase this comment some more. + for (container <- remaining) internalReleaseContainer(container) + remainingContainers = null + } + + // For rack local containers + if (remainingContainers != null) { + val rack = YarnSparkHadoopUtil.lookupRack(conf, candidateHost) + if (rack != null) { + val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0) + val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) - + rackLocalContainers.getOrElse(rack, List()).size + + if (requiredRackCount >= remainingContainers.size) { + // Add all remaining containers to to `dataLocalContainers`. + dataLocalContainers.put(rack, remainingContainers) + remainingContainers = null + } else if (requiredRackCount > 0) { + // Container list has more containers that we need for data locality. + // Split the list into two: one based on the data local container count, + // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining + // containers. + val (rackLocal, remaining) = remainingContainers.splitAt( + remainingContainers.size - requiredRackCount) + val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, + new ArrayBuffer[Container]()) + + existingRackLocal ++= rackLocal + + remainingContainers = remaining + } + } + } + + if (remainingContainers != null) { + // Not all containers have been consumed - add them to the list of off-rack containers. + offRackContainers.put(candidateHost, remainingContainers) + } + } + + // Now that we have split the containers into various groups, go through them in order: + // first host-local, then rack-local, and finally off-rack. + // Note that the list we create below tries to ensure that not all containers end up within + // a host if there is a sufficiently large number of hosts/containers. + val allocatedContainersToProcess = new ArrayBuffer[Container](allocatedContainers.size) + allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(dataLocalContainers) + allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(rackLocalContainers) + allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(offRackContainers) + + // Run each of the allocated containers. + for (container <- allocatedContainersToProcess) { + val numExecutorsRunningNow = numExecutorsRunning.incrementAndGet() + val executorHostname = container.getNodeId.getHost + val containerId = container.getId + + val executorMemoryOverhead = (executorMemory + memoryOverhead) + assert(container.getResource.getMemory >= executorMemoryOverhead) + + if (numExecutorsRunningNow > maxExecutors) { + logInfo("""Ignoring container %s at host %s, since we already have the required number of + containers for it.""".format(containerId, executorHostname)) + internalReleaseContainer(container) + numExecutorsRunning.decrementAndGet() + } else { + val executorId = executorIdCounter.incrementAndGet().toString + val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format( + SparkEnv.driverActorSystemName, + sparkConf.get("spark.driver.host"), + sparkConf.get("spark.driver.port"), + CoarseGrainedSchedulerBackend.ACTOR_NAME) + + logInfo("Launching container %s for on host %s".format(containerId, executorHostname)) + + // To be safe, remove the container from `releasedContainers`. + releasedContainers.remove(containerId) + + val rack = YarnSparkHadoopUtil.lookupRack(conf, executorHostname) + allocatedHostToContainersMap.synchronized { + val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname, + new HashSet[ContainerId]()) + + containerSet += containerId + allocatedContainerToHostMap.put(containerId, executorHostname) + + if (rack != null) { + allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1) + } + } + logInfo("Launching ExecutorRunnable. driverUrl: %s, executorHostname: %s".format( + driverUrl, executorHostname)) + val executorRunnable = new ExecutorRunnable( + container, + conf, + sparkConf, + driverUrl, + executorId, + executorHostname, + executorMemory, + executorCores, + securityMgr) + new Thread(executorRunnable).start() + } + } + logDebug(""" + Finished allocating %s containers (from %s originally). + Current number of executors running: %d, + Released containers: %s + """.format( + allocatedContainersToProcess, + allocatedContainers, + numExecutorsRunning.get(), + releasedContainers)) + } + + val completedContainers = allocateResponse.getCompletedContainersStatuses() + if (completedContainers.size > 0) { + logDebug("Completed %d containers".format(completedContainers.size)) + + for (completedContainer <- completedContainers) { + val containerId = completedContainer.getContainerId + + if (releasedContainers.containsKey(containerId)) { + // YarnAllocationHandler already marked the container for release, so remove it from + // `releasedContainers`. + releasedContainers.remove(containerId) + } else { + // Decrement the number of executors running. The next iteration of + // the ApplicationMaster's reporting thread will take care of allocating. + numExecutorsRunning.decrementAndGet() + logInfo("Completed container %s (state: %s, exit status: %s)".format( + containerId, + completedContainer.getState, + completedContainer.getExitStatus())) + // Hadoop 2.2.X added a ContainerExitStatus we should switch to use + // there are some exit status' we shouldn't necessarily count against us, but for + // now I think its ok as none of the containers are expected to exit + if (completedContainer.getExitStatus() != 0) { + logInfo("Container marked as failed: " + containerId) + numExecutorsFailed.incrementAndGet() + } + } + + allocatedHostToContainersMap.synchronized { + if (allocatedContainerToHostMap.containsKey(containerId)) { + val hostOpt = allocatedContainerToHostMap.get(containerId) + assert(hostOpt.isDefined) + val host = hostOpt.get + + val containerSetOpt = allocatedHostToContainersMap.get(host) + assert(containerSetOpt.isDefined) + val containerSet = containerSetOpt.get + + containerSet.remove(containerId) + if (containerSet.isEmpty) { + allocatedHostToContainersMap.remove(host) + } else { + allocatedHostToContainersMap.update(host, containerSet) + } + + allocatedContainerToHostMap.remove(containerId) + + // TODO: Move this part outside the synchronized block? + val rack = YarnSparkHadoopUtil.lookupRack(conf, host) + if (rack != null) { + val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1 + if (rackCount > 0) { + allocatedRackCount.put(rack, rackCount) + } else { + allocatedRackCount.remove(rack) + } + } + } + } + } + logDebug(""" + Finished processing %d completed containers. + Current number of executors running: %d, + Released containers: %s + """.format( + completedContainers.size, + numExecutorsRunning.get(), + releasedContainers)) + } + } + + protected def allocatedContainersOnHost(host: String): Int = { + var retval = 0 + allocatedHostToContainersMap.synchronized { + retval = allocatedHostToContainersMap.getOrElse(host, Set()).size + } + retval + } + + protected def allocatedContainersOnRack(rack: String): Int = { + var retval = 0 + allocatedHostToContainersMap.synchronized { + retval = allocatedRackCount.getOrElse(rack, 0) + } + retval + } + + private def isResourceConstraintSatisfied(container: Container): Boolean = { + container.getResource.getMemory >= (executorMemory + memoryOverhead) + } + + // A simple method to copy the split info map. + private def generateNodeToWeight( + conf: Configuration, + input: collection.Map[String, collection.Set[SplitInfo]] + ): (Map[String, Int], Map[String, Int]) = { + + if (input == null) { + return (Map[String, Int](), Map[String, Int]()) + } + + val hostToCount = new HashMap[String, Int] + val rackToCount = new HashMap[String, Int] + + for ((host, splits) <- input) { + val hostCount = hostToCount.getOrElse(host, 0) + hostToCount.put(host, hostCount + splits.size) + + val rack = YarnSparkHadoopUtil.lookupRack(conf, host) + if (rack != null) { + val rackCount = rackToCount.getOrElse(host, 0) + rackToCount.put(host, rackCount + splits.size) + } + } + + (hostToCount.toMap, rackToCount.toMap) + } + + private def internalReleaseContainer(container: Container) = { + releasedContainers.put(container.getId(), true) + releaseContainer(container) + } + + /** + * Called to allocate containers in the cluster. + * + * @param count Number of containers to allocate. + * If zero, should still contact RM (as a heartbeat). + * @return Response to the allocation request. + */ + protected def allocateContainers(count: Int): YarnAllocateResponse + + /** Called to release a previously allocated container. */ + protected def releaseContainer(container: Container): Unit + + /** + * Defines the interface for an allocate response from the RM. This is needed since the alpha + * and stable interfaces differ here in ways that cannot be fixed using other routes. + */ + protected trait YarnAllocateResponse { + + def getAllocatedContainers(): JList[Container] + + def getAvailableResources(): Resource + + def getCompletedContainersStatuses(): JList[ContainerStatus] - def allocateResources(): Unit - def getNumExecutorsFailed: Int - def getNumExecutorsRunning: Int + } } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index 922d7d1a854a5..ed65e56b3e413 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -22,7 +22,7 @@ import scala.collection.{Map, Set} import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.records._ -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.{SecurityManager, SparkConf, SparkContext} import org.apache.spark.scheduler.SplitInfo /** @@ -45,7 +45,8 @@ trait YarnRMClient { sparkConf: SparkConf, preferredNodeLocations: Map[String, Set[SplitInfo]], uiAddress: String, - uiHistoryAddress: String): YarnAllocator + uiHistoryAddress: String, + securityMgr: SecurityManager): YarnAllocator /** * Shuts down the AM. Guaranteed to only be called once. diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index ffe2731ca1d17..4a33e34c3bfc7 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -32,11 +32,11 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.util.StringInterner import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.api.ApplicationConstants +import org.apache.hadoop.yarn.api.records.ApplicationAccessType import org.apache.hadoop.yarn.util.RackResolver import org.apache.hadoop.conf.Configuration -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.deploy.history.HistoryServer +import org.apache.spark.{SecurityManager, SparkConf, SparkContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.Utils @@ -156,19 +156,6 @@ object YarnSparkHadoopUtil { } } - def getUIHistoryAddress(sc: SparkContext, conf: SparkConf) : String = { - val eventLogDir = sc.eventLogger match { - case Some(logger) => logger.getApplicationLogDir() - case None => "" - } - val historyServerAddress = conf.get("spark.yarn.historyServer.address", "") - if (historyServerAddress != "" && eventLogDir != "") { - historyServerAddress + HistoryServer.UI_PATH_PREFIX + s"/$eventLogDir" - } else { - "" - } - } - /** * Escapes a string for inclusion in a command line executed by Yarn. Yarn executes commands * using `bash -c "command arg1 arg2"` and that means plain quoting doesn't really work. The @@ -225,4 +212,12 @@ object YarnSparkHadoopUtil { } } + private[spark] def getApplicationAclsForYarn(securityMgr: SecurityManager): + Map[ApplicationAccessType, String] = { + Map[ApplicationAccessType, String] ( + ApplicationAccessType.VIEW_APP -> securityMgr.getViewAcls, + ApplicationAccessType.MODIFY_APP -> securityMgr.getModifyAcls + ) + } + } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index a5f537dd9de30..41c662cd7a6de 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -56,7 +56,6 @@ private[spark] class YarnClientSchedulerBackend( val driverPort = conf.get("spark.driver.port") val hostport = driverHost + ":" + driverPort conf.set("spark.driver.appUIAddress", sc.ui.appUIHostPort) - conf.set("spark.driver.appUIHistoryAddress", YarnSparkHadoopUtil.getUIHistoryAddress(sc, conf)) val argsArrayBuf = new ArrayBuffer[String]() argsArrayBuf += ( @@ -150,4 +149,7 @@ private[spark] class YarnClientSchedulerBackend( override def sufficientResourcesRegistered(): Boolean = { totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio } + + override def applicationId(): Option[String] = Option(appId).map(_.toString()) + } diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala index 55665220a6f96..39436d0999663 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterSchedulerBackend.scala @@ -28,7 +28,7 @@ private[spark] class YarnClusterSchedulerBackend( extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) { var totalExpectedExecutors = 0 - + if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { minRegisteredRatio = 0.8 } @@ -47,4 +47,7 @@ private[spark] class YarnClusterSchedulerBackend( override def sufficientResourcesRegistered(): Boolean = { totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio } + + override def applicationId(): Option[String] = sc.getConf.getOption("spark.yarn.app.id") + } diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index 75db8ee6d468f..2cc5abb3a890c 100644 --- a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -23,7 +23,10 @@ import com.google.common.io.{ByteStreams, Files} import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.{FunSuite, Matchers} -import org.apache.spark.{Logging, SparkConf} +import org.apache.hadoop.yarn.api.records.ApplicationAccessType + +import org.apache.spark.{Logging, SecurityManager, SparkConf} + class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { @@ -74,4 +77,75 @@ class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { yarnConf.get(key) should not be default.get(key) } + + test("test getApplicationAclsForYarn acls on") { + + // spark acls on, just pick up default user + val sparkConf = new SparkConf() + sparkConf.set("spark.acls.enable", "true") + + val securityMgr = new SecurityManager(sparkConf) + val acls = YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr) + + val viewAcls = acls.get(ApplicationAccessType.VIEW_APP) + val modifyAcls = acls.get(ApplicationAccessType.MODIFY_APP) + + viewAcls match { + case Some(vacls) => { + val aclSet = vacls.split(',').map(_.trim).toSet + assert(aclSet.contains(System.getProperty("user.name", "invalid"))) + } + case None => { + fail() + } + } + modifyAcls match { + case Some(macls) => { + val aclSet = macls.split(',').map(_.trim).toSet + assert(aclSet.contains(System.getProperty("user.name", "invalid"))) + } + case None => { + fail() + } + } + } + + test("test getApplicationAclsForYarn acls on and specify users") { + + // default spark acls are on and specify acls + val sparkConf = new SparkConf() + sparkConf.set("spark.acls.enable", "true") + sparkConf.set("spark.ui.view.acls", "user1,user2") + sparkConf.set("spark.modify.acls", "user3,user4") + + val securityMgr = new SecurityManager(sparkConf) + val acls = YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr) + + val viewAcls = acls.get(ApplicationAccessType.VIEW_APP) + val modifyAcls = acls.get(ApplicationAccessType.MODIFY_APP) + + viewAcls match { + case Some(vacls) => { + val aclSet = vacls.split(',').map(_.trim).toSet + assert(aclSet.contains("user1")) + assert(aclSet.contains("user2")) + assert(aclSet.contains(System.getProperty("user.name", "invalid"))) + } + case None => { + fail() + } + } + modifyAcls match { + case Some(macls) => { + val aclSet = macls.split(',').map(_.trim).toSet + assert(aclSet.contains("user3")) + assert(aclSet.contains("user4")) + assert(aclSet.contains(System.getProperty("user.name", "invalid"))) + } + case None => { + fail() + } + } + + } } diff --git a/yarn/pom.xml b/yarn/pom.xml index 3faaf053634d6..7fcd7ee0d4547 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml diff --git a/yarn/stable/pom.xml b/yarn/stable/pom.xml index b6c8456d06684..fd934b7726181 100644 --- a/yarn/stable/pom.xml +++ b/yarn/stable/pom.xml @@ -20,7 +20,7 @@ org.apache.spark yarn-parent_2.10 - 1.1.0-SNAPSHOT + 1.2.0-SNAPSHOT ../pom.xml diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 07ba0a4b30bd7..833be12982e71 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -35,7 +35,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records} -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{SecurityManager, SparkConf, Logging} class ExecutorRunnable( @@ -46,7 +46,8 @@ class ExecutorRunnable( slaveId: String, hostname: String, executorMemory: Int, - executorCores: Int) + executorCores: Int, + securityMgr: SecurityManager) extends Runnable with ExecutorRunnableUtil with Logging { var rpc: YarnRPC = YarnRPC.create(conf) @@ -85,6 +86,8 @@ class ExecutorRunnable( logInfo("Setting up executor with commands: " + commands) ctx.setCommands(commands) + ctx.setApplicationACLs(YarnSparkHadoopUtil.getApplicationAclsForYarn(securityMgr)) + // Send the start request to the ContainerManager nmClient.startContainer(container, ctx) } diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index 4d5144989991f..5438f151ac0ad 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -17,36 +17,19 @@ package org.apache.spark.deploy.yarn -import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap} -import java.util.concurrent.atomic.AtomicInteger - import scala.collection.JavaConversions._ -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.collection.mutable.{ArrayBuffer, HashMap} -import org.apache.spark.{Logging, SparkConf, SparkEnv} -import org.apache.spark.scheduler.{SplitInfo,TaskSchedulerImpl} -import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend -import org.apache.spark.util.Utils +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.scheduler.SplitInfo import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.yarn.api.records.ApplicationAttemptId -import org.apache.hadoop.yarn.api.records.{Container, ContainerId} -import org.apache.hadoop.yarn.api.records.{Priority, Resource, ResourceRequest} -import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse} +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.protocolrecords.AllocateResponse import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.util.Records -// TODO: -// Too many params. -// Needs to be mt-safe -// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive - should -// make it more proactive and decoupled. - -// Note that right now, we assume all node asks as uniform in terms of capabilities and priority -// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for -// more info on how we are requesting for containers. - /** * Acquires resources for executors from a ResourceManager and launches executors in new containers. */ @@ -56,330 +39,24 @@ private[yarn] class YarnAllocationHandler( amClient: AMRMClient[ContainerRequest], appAttemptId: ApplicationAttemptId, args: ApplicationMasterArguments, - preferredNodes: collection.Map[String, collection.Set[SplitInfo]]) - extends YarnAllocator with Logging { - - // These three are locked on allocatedHostToContainersMap. Complementary data structures - // allocatedHostToContainersMap : containers which are running : host, Set - // allocatedContainerToHostMap: container to host mapping. - private val allocatedHostToContainersMap = - new HashMap[String, collection.mutable.Set[ContainerId]]() - - private val allocatedContainerToHostMap = new HashMap[ContainerId, String]() - - // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an - // allocated node) - // As with the two data structures above, tightly coupled with them, and to be locked on - // allocatedHostToContainersMap - private val allocatedRackCount = new HashMap[String, Int]() - - // Containers which have been released. - private val releasedContainerList = new CopyOnWriteArrayList[ContainerId]() - // Containers to be released in next request to RM - private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean] - - // Additional memory overhead - in mb. - private def memoryOverhead: Int = sparkConf.getInt("spark.yarn.executor.memoryOverhead", - YarnSparkHadoopUtil.DEFAULT_MEMORY_OVERHEAD) - - // Number of container requests that have been sent to, but not yet allocated by the - // ApplicationMaster. - private val numPendingAllocate = new AtomicInteger() - private val numExecutorsRunning = new AtomicInteger() - // Used to generate a unique id per executor - private val executorIdCounter = new AtomicInteger() - private val lastResponseId = new AtomicInteger() - private val numExecutorsFailed = new AtomicInteger() - - private val maxExecutors = args.numExecutors - private val executorMemory = args.executorMemory - private val executorCores = args.executorCores - private val (preferredHostToCount, preferredRackToCount) = - generateNodeToWeight(conf, preferredNodes) - - override def getNumExecutorsRunning: Int = numExecutorsRunning.intValue + preferredNodes: collection.Map[String, collection.Set[SplitInfo]], + securityMgr: SecurityManager) + extends YarnAllocator(conf, sparkConf, args, preferredNodes, securityMgr) { - override def getNumExecutorsFailed: Int = numExecutorsFailed.intValue - - def isResourceConstraintSatisfied(container: Container): Boolean = { - container.getResource.getMemory >= (executorMemory + memoryOverhead) - } - - def releaseContainer(container: Container) { - val containerId = container.getId - pendingReleaseContainers.put(containerId, true) - amClient.releaseAssignedContainer(containerId) + override protected def releaseContainer(container: Container) = { + amClient.releaseAssignedContainer(container.getId()) } - override def allocateResources() = { - addResourceRequests(maxExecutors - numPendingAllocate.get() - numExecutorsRunning.get()) + override protected def allocateContainers(count: Int): YarnAllocateResponse = { + addResourceRequests(count) // We have already set the container request. Poll the ResourceManager for a response. // This doubles as a heartbeat if there are no pending container requests. val progressIndicator = 0.1f - val allocateResponse = amClient.allocate(progressIndicator) - - val allocatedContainers = allocateResponse.getAllocatedContainers() - if (allocatedContainers.size > 0) { - var numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * allocatedContainers.size) - - if (numPendingAllocateNow < 0) { - numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * numPendingAllocateNow) - } - - logDebug(""" - Allocated containers: %d - Current executor count: %d - Containers released: %s - Containers to-be-released: %s - Cluster resources: %s - """.format( - allocatedContainers.size, - numExecutorsRunning.get(), - releasedContainerList, - pendingReleaseContainers, - allocateResponse.getAvailableResources)) - - val hostToContainers = new HashMap[String, ArrayBuffer[Container]]() - - for (container <- allocatedContainers) { - if (isResourceConstraintSatisfied(container)) { - // Add the accepted `container` to the host's list of already accepted, - // allocated containers - val host = container.getNodeId.getHost - val containersForHost = hostToContainers.getOrElseUpdate(host, - new ArrayBuffer[Container]()) - containersForHost += container - } else { - // Release container, since it doesn't satisfy resource constraints. - releaseContainer(container) - } - } - - // Find the appropriate containers to use. - // TODO: Cleanup this group-by... - val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]() - val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]() - val offRackContainers = new HashMap[String, ArrayBuffer[Container]]() - - for (candidateHost <- hostToContainers.keySet) { - val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0) - val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost) - - val remainingContainersOpt = hostToContainers.get(candidateHost) - assert(remainingContainersOpt.isDefined) - var remainingContainers = remainingContainersOpt.get - - if (requiredHostCount >= remainingContainers.size) { - // Since we have <= required containers, add all remaining containers to - // `dataLocalContainers`. - dataLocalContainers.put(candidateHost, remainingContainers) - // There are no more free containers remaining. - remainingContainers = null - } else if (requiredHostCount > 0) { - // Container list has more containers than we need for data locality. - // Split the list into two: one based on the data local container count, - // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining - // containers. - val (dataLocal, remaining) = remainingContainers.splitAt( - remainingContainers.size - requiredHostCount) - dataLocalContainers.put(candidateHost, dataLocal) - - // Invariant: remainingContainers == remaining - - // YARN has a nasty habit of allocating a ton of containers on a host - discourage this. - // Add each container in `remaining` to list of containers to release. If we have an - // insufficient number of containers, then the next allocation cycle will reallocate - // (but won't treat it as data local). - // TODO(harvey): Rephrase this comment some more. - for (container <- remaining) releaseContainer(container) - remainingContainers = null - } - - // For rack local containers - if (remainingContainers != null) { - val rack = YarnSparkHadoopUtil.lookupRack(conf, candidateHost) - if (rack != null) { - val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0) - val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) - - rackLocalContainers.getOrElse(rack, List()).size - - if (requiredRackCount >= remainingContainers.size) { - // Add all remaining containers to to `dataLocalContainers`. - dataLocalContainers.put(rack, remainingContainers) - remainingContainers = null - } else if (requiredRackCount > 0) { - // Container list has more containers that we need for data locality. - // Split the list into two: one based on the data local container count, - // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining - // containers. - val (rackLocal, remaining) = remainingContainers.splitAt( - remainingContainers.size - requiredRackCount) - val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, - new ArrayBuffer[Container]()) - - existingRackLocal ++= rackLocal - - remainingContainers = remaining - } - } - } - - if (remainingContainers != null) { - // Not all containers have been consumed - add them to the list of off-rack containers. - offRackContainers.put(candidateHost, remainingContainers) - } - } - - // Now that we have split the containers into various groups, go through them in order: - // first host-local, then rack-local, and finally off-rack. - // Note that the list we create below tries to ensure that not all containers end up within - // a host if there is a sufficiently large number of hosts/containers. - val allocatedContainersToProcess = new ArrayBuffer[Container](allocatedContainers.size) - allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(dataLocalContainers) - allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(rackLocalContainers) - allocatedContainersToProcess ++= TaskSchedulerImpl.prioritizeContainers(offRackContainers) - - // Run each of the allocated containers. - for (container <- allocatedContainersToProcess) { - val numExecutorsRunningNow = numExecutorsRunning.incrementAndGet() - val executorHostname = container.getNodeId.getHost - val containerId = container.getId - - val executorMemoryOverhead = (executorMemory + memoryOverhead) - assert(container.getResource.getMemory >= executorMemoryOverhead) - - if (numExecutorsRunningNow > maxExecutors) { - logInfo("""Ignoring container %s at host %s, since we already have the required number of - containers for it.""".format(containerId, executorHostname)) - releaseContainer(container) - numExecutorsRunning.decrementAndGet() - } else { - val executorId = executorIdCounter.incrementAndGet().toString - val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format( - SparkEnv.driverActorSystemName, - sparkConf.get("spark.driver.host"), - sparkConf.get("spark.driver.port"), - CoarseGrainedSchedulerBackend.ACTOR_NAME) - - logInfo("Launching container %s for on host %s".format(containerId, executorHostname)) - - // To be safe, remove the container from `pendingReleaseContainers`. - pendingReleaseContainers.remove(containerId) - - val rack = YarnSparkHadoopUtil.lookupRack(conf, executorHostname) - allocatedHostToContainersMap.synchronized { - val containerSet = allocatedHostToContainersMap.getOrElseUpdate(executorHostname, - new HashSet[ContainerId]()) - - containerSet += containerId - allocatedContainerToHostMap.put(containerId, executorHostname) - - if (rack != null) { - allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1) - } - } - logInfo("Launching ExecutorRunnable. driverUrl: %s, executorHostname: %s".format( - driverUrl, executorHostname)) - val executorRunnable = new ExecutorRunnable( - container, - conf, - sparkConf, - driverUrl, - executorId, - executorHostname, - executorMemory, - executorCores) - new Thread(executorRunnable).start() - } - } - logDebug(""" - Finished allocating %s containers (from %s originally). - Current number of executors running: %d, - releasedContainerList: %s, - pendingReleaseContainers: %s - """.format( - allocatedContainersToProcess, - allocatedContainers, - numExecutorsRunning.get(), - releasedContainerList, - pendingReleaseContainers)) - } - - val completedContainers = allocateResponse.getCompletedContainersStatuses() - if (completedContainers.size > 0) { - logDebug("Completed %d containers".format(completedContainers.size)) - - for (completedContainer <- completedContainers) { - val containerId = completedContainer.getContainerId - - if (pendingReleaseContainers.containsKey(containerId)) { - // YarnAllocationHandler already marked the container for release, so remove it from - // `pendingReleaseContainers`. - pendingReleaseContainers.remove(containerId) - } else { - // Decrement the number of executors running. The next iteration of - // the ApplicationMaster's reporting thread will take care of allocating. - numExecutorsRunning.decrementAndGet() - logInfo("Completed container %s (state: %s, exit status: %s)".format( - containerId, - completedContainer.getState, - completedContainer.getExitStatus())) - // Hadoop 2.2.X added a ContainerExitStatus we should switch to use - // there are some exit status' we shouldn't necessarily count against us, but for - // now I think its ok as none of the containers are expected to exit - if (completedContainer.getExitStatus() != 0) { - logInfo("Container marked as failed: " + containerId) - numExecutorsFailed.incrementAndGet() - } - } - - allocatedHostToContainersMap.synchronized { - if (allocatedContainerToHostMap.containsKey(containerId)) { - val hostOpt = allocatedContainerToHostMap.get(containerId) - assert(hostOpt.isDefined) - val host = hostOpt.get - - val containerSetOpt = allocatedHostToContainersMap.get(host) - assert(containerSetOpt.isDefined) - val containerSet = containerSetOpt.get - - containerSet.remove(containerId) - if (containerSet.isEmpty) { - allocatedHostToContainersMap.remove(host) - } else { - allocatedHostToContainersMap.update(host, containerSet) - } - - allocatedContainerToHostMap.remove(containerId) - - // TODO: Move this part outside the synchronized block? - val rack = YarnSparkHadoopUtil.lookupRack(conf, host) - if (rack != null) { - val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1 - if (rackCount > 0) { - allocatedRackCount.put(rack, rackCount) - } else { - allocatedRackCount.remove(rack) - } - } - } - } - } - logDebug(""" - Finished processing %d completed containers. - Current number of executors running: %d, - releasedContainerList: %s, - pendingReleaseContainers: %s - """.format( - completedContainers.size, - numExecutorsRunning.get(), - releasedContainerList, - pendingReleaseContainers)) - } + new StableAllocateResponse(amClient.allocate(progressIndicator)) } - def createRackResourceRequests( + private def createRackResourceRequests( hostContainers: ArrayBuffer[ContainerRequest] ): ArrayBuffer[ContainerRequest] = { // Generate modified racks and new set of hosts under it before issuing requests. @@ -409,27 +86,13 @@ private[yarn] class YarnAllocationHandler( requestedContainers } - def allocatedContainersOnHost(host: String): Int = { - var retval = 0 - allocatedHostToContainersMap.synchronized { - retval = allocatedHostToContainersMap.getOrElse(host, Set()).size - } - retval - } - - def allocatedContainersOnRack(rack: String): Int = { - var retval = 0 - allocatedHostToContainersMap.synchronized { - retval = allocatedRackCount.getOrElse(rack, 0) - } - retval - } - private def addResourceRequests(numExecutors: Int) { val containerRequests: List[ContainerRequest] = - if (numExecutors <= 0 || preferredHostToCount.isEmpty) { - logDebug("numExecutors: " + numExecutors + ", host preferences: " + - preferredHostToCount.isEmpty) + if (numExecutors <= 0) { + logDebug("numExecutors: " + numExecutors) + List() + } else if (preferredHostToCount.isEmpty) { + logDebug("host preferences is empty") createResourceRequests( AllocationType.ANY, resource = null, @@ -472,15 +135,6 @@ private[yarn] class YarnAllocationHandler( amClient.addContainerRequest(request) } - if (numExecutors > 0) { - numPendingAllocate.addAndGet(numExecutors) - logInfo("Will Allocate %d executor containers, each with %d memory".format( - numExecutors, - (executorMemory + memoryOverhead))) - } else { - logDebug("Empty allocation request ...") - } - for (request <- containerRequests) { val nodes = request.getNodes var hostStr = if (nodes == null || nodes.isEmpty) { @@ -549,31 +203,10 @@ private[yarn] class YarnAllocationHandler( requests } - // A simple method to copy the split info map. - private def generateNodeToWeight( - conf: Configuration, - input: collection.Map[String, collection.Set[SplitInfo]] - ): (Map[String, Int], Map[String, Int]) = { - - if (input == null) { - return (Map[String, Int](), Map[String, Int]()) - } - - val hostToCount = new HashMap[String, Int] - val rackToCount = new HashMap[String, Int] - - for ((host, splits) <- input) { - val hostCount = hostToCount.getOrElse(host, 0) - hostToCount.put(host, hostCount + splits.size) - - val rack = YarnSparkHadoopUtil.lookupRack(conf, host) - if (rack != null){ - val rackCount = rackToCount.getOrElse(host, 0) - rackToCount.put(host, rackCount + splits.size) - } - } - - (hostToCount.toMap, rackToCount.toMap) + private class StableAllocateResponse(response: AllocateResponse) extends YarnAllocateResponse { + override def getAllocatedContainers() = response.getAllocatedContainers() + override def getAvailableResources() = response.getAvailableResources() + override def getCompletedContainersStatuses() = response.getCompletedContainersStatuses() } } diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala index e8b8d9bc722bd..54bc6b14c44ce 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClientImpl.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.util.ConverterUtils import org.apache.hadoop.yarn.webapp.util.WebAppUtils -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.scheduler.SplitInfo import org.apache.spark.util.Utils @@ -46,7 +46,8 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC sparkConf: SparkConf, preferredNodeLocations: Map[String, Set[SplitInfo]], uiAddress: String, - uiHistoryAddress: String) = { + uiHistoryAddress: String, + securityMgr: SecurityManager) = { amClient = AMRMClient.createAMRMClient() amClient.init(conf) amClient.start() @@ -55,7 +56,7 @@ private class YarnRMClientImpl(args: ApplicationMasterArguments) extends YarnRMC logInfo("Registering the ApplicationMaster") amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) new YarnAllocationHandler(conf, sparkConf, amClient, getAttemptId(), args, - preferredNodeLocations) + preferredNodeLocations, securityMgr) } override def shutdown(status: FinalApplicationStatus, diagnostics: String = "") =