diff --git a/.rat-excludes b/.rat-excludes index 17cf6d0ed1cf3..8954330bd10a7 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -39,3 +39,6 @@ work .*\.q golden test.out/* +.*iml +python/metastore/service.properties +python/metastore/db.lck diff --git a/assembly/pom.xml b/assembly/pom.xml index b5e752c6cd1f6..923bf47f7076a 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -163,6 +163,16 @@ + + hive + + + org.apache.spark + spark-hive_${scala.binary.version} + ${project.version} + + + spark-ganglia-lgpl @@ -208,7 +218,7 @@ org.codehaus.mojo buildnumber-maven-plugin - 1.1 + 1.2 validate diff --git a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala index 70c7474a936dc..70a99b33d753c 100644 --- a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala +++ b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala @@ -220,20 +220,23 @@ object Bagel extends Logging { */ private def comp[K: Manifest, V <: Vertex, M <: Message[K], C]( sc: SparkContext, - grouped: RDD[(K, (Seq[C], Seq[V]))], + grouped: RDD[(K, (Iterable[C], Iterable[V]))], compute: (V, Option[C]) => (V, Array[M]), storageLevel: StorageLevel ): (RDD[(K, (V, Array[M]))], Int, Int) = { var numMsgs = sc.accumulator(0) var numActiveVerts = sc.accumulator(0) - val processed = grouped.flatMapValues { - case (_, vs) if vs.size == 0 => None - case (c, vs) => + val processed = grouped.mapValues(x => (x._1.iterator, x._2.iterator)) + .flatMapValues { + case (_, vs) if !vs.hasNext => None + case (c, vs) => { val (newVert, newMsgs) = - compute(vs(0), c match { - case Seq(comb) => Some(comb) - case Seq() => None - }) + compute(vs.next, + c.hasNext match { + case true => Some(c.next) + case false => None + } + ) numMsgs += newMsgs.size if (newVert.active) { @@ -241,6 +244,7 @@ object Bagel extends Logging { } Some((newVert, newMsgs)) + } }.persist(storageLevel) // Force evaluation of processed RDD for accurate performance measurements diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala index 9c37fadb78d2f..69144e3e657bf 100644 --- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala @@ -28,9 +28,9 @@ class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializ class TestMessage(val targetId: String) extends Message[String] with Serializable class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeouts { - + var sc: SparkContext = _ - + after { if (sc != null) { sc.stop() diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh index bef42df71ce01..2a2bb376fd71f 100755 --- a/bin/compute-classpath.sh +++ b/bin/compute-classpath.sh @@ -30,21 +30,7 @@ FWDIR="$(cd `dirname $0`/..; pwd)" # Build up classpath CLASSPATH="$SPARK_CLASSPATH:$FWDIR/conf" -# Support for interacting with Hive. Since hive pulls in a lot of dependencies that might break -# existing Spark applications, it is not included in the standard spark assembly. Instead, we only -# include it in the classpath if the user has explicitly requested it by running "sbt hive/assembly" -# Hopefully we will find a way to avoid uber-jars entirely and deploy only the needed packages in -# the future. -if [ -f "$FWDIR"/sql/hive/target/scala-$SCALA_VERSION/spark-hive-assembly-*.jar ]; then - - # Datanucleus jars do not work if only included in the uberjar as plugin.xml metadata is lost. - DATANUCLEUSJARS=$(JARS=("$FWDIR/lib_managed/jars"/datanucleus-*.jar); IFS=:; echo "${JARS[*]}") - CLASSPATH=$CLASSPATH:$DATANUCLEUSJARS - - ASSEMBLY_DIR="$FWDIR/sql/hive/target/scala-$SCALA_VERSION/" -else - ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SCALA_VERSION/" -fi +ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SCALA_VERSION" # First check if we have a dependencies jar. If so, include binary classes with the deps jar if [ -f "$ASSEMBLY_DIR"/spark-assembly*hadoop*-deps.jar ]; then @@ -59,7 +45,7 @@ if [ -f "$ASSEMBLY_DIR"/spark-assembly*hadoop*-deps.jar ]; then CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/classes" CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/classes" - DEPS_ASSEMBLY_JAR=`ls "$ASSEMBLY_DIR"/spark*-assembly*hadoop*-deps.jar` + DEPS_ASSEMBLY_JAR=`ls "$ASSEMBLY_DIR"/spark-assembly*hadoop*-deps.jar` CLASSPATH="$CLASSPATH:$DEPS_ASSEMBLY_JAR" else # Else use spark-assembly jar from either RELEASE or assembly directory @@ -71,6 +57,23 @@ else CLASSPATH="$CLASSPATH:$ASSEMBLY_JAR" fi +# When Hive support is needed, Datanucleus jars must be included on the classpath. +# Datanucleus jars do not work if only included in the uber jar as plugin.xml metadata is lost. +# Both sbt and maven will populate "lib_managed/jars/" with the datanucleus jars when Spark is +# built with Hive, so first check if the datanucleus jars exist, and then ensure the current Spark +# assembly is built for Hive, before actually populating the CLASSPATH with the jars. +# Note that this check order is faster (by up to half a second) in the case where Hive is not used. +num_datanucleus_jars=$(ls "$FWDIR"/lib_managed/jars/ 2>/dev/null | grep "datanucleus-.*\\.jar" | wc -l) +if [ $num_datanucleus_jars -gt 0 ]; then + AN_ASSEMBLY_JAR=${ASSEMBLY_JAR:-$DEPS_ASSEMBLY_JAR} + num_hive_files=$(jar tvf "$AN_ASSEMBLY_JAR" org/apache/hadoop/hive/ql/exec 2>/dev/null | wc -l) + if [ $num_hive_files -gt 0 ]; then + echo "Spark assembly has been built with Hive, including Datanucleus jars on classpath" 1>&2 + DATANUCLEUSJARS=$(echo "$FWDIR/lib_managed/jars"/datanucleus-*.jar | tr " " :) + CLASSPATH=$CLASSPATH:$DATANUCLEUSJARS + fi +fi + # Add test classes if we're running from SBT or Maven with SPARK_TESTING set to 1 if [[ $SPARK_TESTING == 1 ]]; then CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/test-classes" diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh index 476dd826551fd..d425f9feaac54 100644 --- a/bin/load-spark-env.sh +++ b/bin/load-spark-env.sh @@ -30,6 +30,9 @@ if [ -z "$SPARK_ENV_LOADED" ]; then use_conf_dir=${SPARK_CONF_DIR:-"$parent_dir/conf"} if [ -f "${use_conf_dir}/spark-env.sh" ]; then + # Promote all variable declarations to environment (exported) variables + set -a . "${use_conf_dir}/spark-env.sh" + set +a fi fi diff --git a/bin/pyspark b/bin/pyspark index 67e1f61eeb1e5..cad982bc33477 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -55,7 +55,8 @@ if [ -n "$IPYTHON_OPTS" ]; then IPYTHON=1 fi -if [[ "$IPYTHON" = "1" ]] ; then +# Only use ipython if no command line arguments were provided [SPARK-1134] +if [[ "$IPYTHON" = "1" && $# = 0 ]] ; then exec ipython $IPYTHON_OPTS else exec "$PYSPARK_PYTHON" "$@" diff --git a/bin/spark-class b/bin/spark-class index 0dcf0e156cb52..1b0d309cc5b1c 100755 --- a/bin/spark-class +++ b/bin/spark-class @@ -47,9 +47,9 @@ DEFAULT_MEM=${SPARK_MEM:-512m} SPARK_DAEMON_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS -Dspark.akka.logLifecycleEvents=true" -# Add java opts and memory settings for master, worker, executors, and repl. +# Add java opts and memory settings for master, worker, history server, executors, and repl. case "$1" in - # Master and Worker use SPARK_DAEMON_JAVA_OPTS (and specific opts) + SPARK_DAEMON_MEMORY. + # Master, Worker, and HistoryServer use SPARK_DAEMON_JAVA_OPTS (and specific opts) + SPARK_DAEMON_MEMORY. 'org.apache.spark.deploy.master.Master') OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS $SPARK_MASTER_OPTS" OUR_JAVA_MEM=${SPARK_DAEMON_MEMORY:-$DEFAULT_MEM} @@ -58,6 +58,10 @@ case "$1" in OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS $SPARK_WORKER_OPTS" OUR_JAVA_MEM=${SPARK_DAEMON_MEMORY:-$DEFAULT_MEM} ;; + 'org.apache.spark.deploy.history.HistoryServer') + OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS $SPARK_HISTORY_OPTS" + OUR_JAVA_MEM=${SPARK_DAEMON_MEMORY:-$DEFAULT_MEM} + ;; # Executors use SPARK_JAVA_OPTS + SPARK_EXECUTOR_MEMORY. 'org.apache.spark.executor.CoarseGrainedExecutorBackend') @@ -154,5 +158,3 @@ if [ "$SPARK_PRINT_LAUNCH_COMMAND" == "1" ]; then fi exec "$RUNNER" -cp "$CLASSPATH" $JAVA_OPTS "$@" - - diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index f488cfdbeceb6..4302c1b6b7ff4 100755 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -45,14 +45,17 @@ if "x%OUR_JAVA_MEM%"=="x" set OUR_JAVA_MEM=512m set SPARK_DAEMON_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% -Dspark.akka.logLifecycleEvents=true -rem Add java opts and memory settings for master, worker, executors, and repl. -rem Master and Worker use SPARK_DAEMON_JAVA_OPTS (and specific opts) + SPARK_DAEMON_MEMORY. +rem Add java opts and memory settings for master, worker, history server, executors, and repl. +rem Master, Worker and HistoryServer use SPARK_DAEMON_JAVA_OPTS (and specific opts) + SPARK_DAEMON_MEMORY. if "%1"=="org.apache.spark.deploy.master.Master" ( set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% %SPARK_MASTER_OPTS% if not "x%SPARK_DAEMON_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DAEMON_MEMORY% ) else if "%1"=="org.apache.spark.deploy.worker.Worker" ( set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% %SPARK_WORKER_OPTS% if not "x%SPARK_DAEMON_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DAEMON_MEMORY% +) else if "%1"=="org.apache.spark.deploy.history.HistoryServer" ( + set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% %SPARK_HISTORY_OPTS% + if not "x%SPARK_DAEMON_MEMORY%"=="x" set OUR_JAVA_MEM=%SPARK_DAEMON_MEMORY% rem Executors use SPARK_JAVA_OPTS + SPARK_EXECUTOR_MEMORY. ) else if "%1"=="org.apache.spark.executor.CoarseGrainedExecutorBackend" ( diff --git a/bin/spark-shell b/bin/spark-shell index fac006cf492ed..ea12d256b23a1 100755 --- a/bin/spark-shell +++ b/bin/spark-shell @@ -34,7 +34,7 @@ set -o posix FWDIR="$(cd `dirname $0`/..; pwd)" SPARK_REPL_OPTS="${SPARK_REPL_OPTS:-""}" -DEFAULT_MASTER="local" +DEFAULT_MASTER="local[*]" MASTER=${MASTER:-""} info_log=0 @@ -64,7 +64,7 @@ ${txtbld}OPTIONS${txtrst}: is followed by m for megabytes or g for gigabytes, e.g. "1g". -dm --driver-memory : The memory used by the Spark Shell, the number is followed by m for megabytes or g for gigabytes, e.g. "1g". - -m --master : A full string that describes the Spark Master, defaults to "local" + -m --master : A full string that describes the Spark Master, defaults to "local[*]" e.g. "spark://localhost:7077". --log-conf : Enables logging of the supplied SparkConf as INFO at start of the Spark Context. @@ -127,7 +127,7 @@ function set_spark_log_conf(){ function set_spark_master(){ if ! [[ "$1" =~ $ARG_FLAG_PATTERN ]]; then - MASTER="$1" + export MASTER="$1" else out_error "wrong format for $2" fi @@ -145,7 +145,7 @@ function resolve_spark_master(){ fi if [ -z "$MASTER" ]; then - MASTER="$DEFAULT_MASTER" + export MASTER="$DEFAULT_MASTER" fi } diff --git a/core/pom.xml b/core/pom.xml index eb6cc4d3105e9..a1bdd8ec68aeb 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -117,12 +117,10 @@ com.twitter chill_${scala.binary.version} - 0.3.1 com.twitter chill-java - 0.3.1 commons-net @@ -150,7 +148,7 @@ json4s-jackson_${scala.binary.version} 3.2.6 @@ -159,10 +157,6 @@ - - it.unimi.dsi - fastutil - colt colt @@ -200,6 +194,53 @@ derby test + + org.tachyonproject + tachyon + 0.4.1-thrift + + + org.apache.hadoop + hadoop-client + + + org.apache.curator + curator-recipes + + + org.eclipse.jetty + jetty-jsp + + + org.eclipse.jetty + jetty-webapp + + + org.eclipse.jetty + jetty-server + + + org.eclipse.jetty + jetty-servlet + + + junit + junit + + + org.powermock + powermock-module-junit4 + + + org.powermock + powermock-api-mockito + + + org.apache.curator + curator-test + + + org.scalatest scalatest_${scala.binary.version} diff --git a/core/src/main/java/org/apache/spark/api/java/StorageLevels.java b/core/src/main/java/org/apache/spark/api/java/StorageLevels.java index 9f13b39909481..840a1bd93bfbb 100644 --- a/core/src/main/java/org/apache/spark/api/java/StorageLevels.java +++ b/core/src/main/java/org/apache/spark/api/java/StorageLevels.java @@ -23,17 +23,18 @@ * Expose some commonly useful storage level constants. */ public class StorageLevels { - public static final StorageLevel NONE = create(false, false, false, 1); - public static final StorageLevel DISK_ONLY = create(true, false, false, 1); - public static final StorageLevel DISK_ONLY_2 = create(true, false, false, 2); - public static final StorageLevel MEMORY_ONLY = create(false, true, true, 1); - public static final StorageLevel MEMORY_ONLY_2 = create(false, true, true, 2); - public static final StorageLevel MEMORY_ONLY_SER = create(false, true, false, 1); - public static final StorageLevel MEMORY_ONLY_SER_2 = create(false, true, false, 2); - public static final StorageLevel MEMORY_AND_DISK = create(true, true, true, 1); - public static final StorageLevel MEMORY_AND_DISK_2 = create(true, true, true, 2); - public static final StorageLevel MEMORY_AND_DISK_SER = create(true, true, false, 1); - public static final StorageLevel MEMORY_AND_DISK_SER_2 = create(true, true, false, 2); + public static final StorageLevel NONE = create(false, false, false, false, 1); + public static final StorageLevel DISK_ONLY = create(true, false, false, false, 1); + public static final StorageLevel DISK_ONLY_2 = create(true, false, false, false, 2); + public static final StorageLevel MEMORY_ONLY = create(false, true, false, true, 1); + public static final StorageLevel MEMORY_ONLY_2 = create(false, true, false, true, 2); + public static final StorageLevel MEMORY_ONLY_SER = create(false, true, false, false, 1); + public static final StorageLevel MEMORY_ONLY_SER_2 = create(false, true, false, false, 2); + public static final StorageLevel MEMORY_AND_DISK = create(true, true, false, true, 1); + public static final StorageLevel MEMORY_AND_DISK_2 = create(true, true, false, true, 2); + public static final StorageLevel MEMORY_AND_DISK_SER = create(true, true, false, false, 1); + public static final StorageLevel MEMORY_AND_DISK_SER_2 = create(true, true, false, false, 2); + public static final StorageLevel OFF_HEAP = create(false, false, true, false, 1); /** * Create a new StorageLevel object. @@ -42,7 +43,26 @@ public class StorageLevels { * @param deserialized saved as deserialized objects, if true * @param replication replication factor */ - public static StorageLevel create(boolean useDisk, boolean useMemory, boolean deserialized, int replication) { - return StorageLevel.apply(useDisk, useMemory, deserialized, replication); + @Deprecated + public static StorageLevel create(boolean useDisk, boolean useMemory, boolean deserialized, + int replication) { + return StorageLevel.apply(useDisk, useMemory, false, deserialized, replication); + } + + /** + * Create a new StorageLevel object. + * @param useDisk saved to disk, if true + * @param useMemory saved to memory, if true + * @param useOffHeap saved to Tachyon, if true + * @param deserialized saved as deserialized objects, if true + * @param replication replication factor + */ + public static StorageLevel create( + boolean useDisk, + boolean useMemory, + boolean useOffHeap, + boolean deserialized, + int replication) { + return StorageLevel.apply(useDisk, useMemory, useOffHeap, deserialized, replication); } } diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java index fa75842047c6a..23f5fdd43631b 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java @@ -24,4 +24,4 @@ */ public interface FlatMapFunction extends Serializable { public Iterable call(T t) throws Exception; -} \ No newline at end of file +} diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java index d1fdec072443d..c48e92f535ff5 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java @@ -24,4 +24,4 @@ */ public interface FlatMapFunction2 extends Serializable { public Iterable call(T1 t1, T2 t2) throws Exception; -} \ No newline at end of file +} diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index fe54c34ffb1da..599c3ac9b57c0 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -78,3 +78,12 @@ table.sortable thead { background-repeat: repeat-x; filter: progid:dximagetransform.microsoft.gradient(startColorstr='#FFA4EDFF', endColorstr='#FF94DDFF', GradientType=0); } + +span.kill-link { + margin-right: 2px; + color: gray; +} + +span.kill-link a { + color: gray; +} diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index ceead59b79ed6..59fdf659c9e11 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -17,15 +17,18 @@ package org.apache.spark +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap} /** + * :: DeveloperApi :: * A set of functions used to aggregate data. * * @param createCombiner function to create the initial value of the aggregation. * @param mergeValue function to merge a new value into the aggregation result. * @param mergeCombiners function to merge outputs from multiple mergeValue function. */ +@DeveloperApi case class Aggregator[K, V, C] ( createCombiner: V => C, mergeValue: (C, V) => C, diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala new file mode 100644 index 0000000000000..54e08d7866f75 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.lang.ref.{ReferenceQueue, WeakReference} + +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD + +/** + * Classes that represent cleaning tasks. + */ +private sealed trait CleanupTask +private case class CleanRDD(rddId: Int) extends CleanupTask +private case class CleanShuffle(shuffleId: Int) extends CleanupTask +private case class CleanBroadcast(broadcastId: Long) extends CleanupTask + +/** + * A WeakReference associated with a CleanupTask. + * + * When the referent object becomes only weakly reachable, the corresponding + * CleanupTaskWeakReference is automatically added to the given reference queue. + */ +private class CleanupTaskWeakReference( + val task: CleanupTask, + referent: AnyRef, + referenceQueue: ReferenceQueue[AnyRef]) + extends WeakReference(referent, referenceQueue) + +/** + * An asynchronous cleaner for RDD, shuffle, and broadcast state. + * + * This maintains a weak reference for each RDD, ShuffleDependency, and Broadcast of interest, + * to be processed when the associated object goes out of scope of the application. Actual + * cleanup is performed in a separate daemon thread. + */ +private[spark] class ContextCleaner(sc: SparkContext) extends Logging { + + private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference] + with SynchronizedBuffer[CleanupTaskWeakReference] + + private val referenceQueue = new ReferenceQueue[AnyRef] + + private val listeners = new ArrayBuffer[CleanerListener] + with SynchronizedBuffer[CleanerListener] + + private val cleaningThread = new Thread() { override def run() { keepCleaning() }} + + /** + * Whether the cleaning thread will block on cleanup tasks. + * This is set to true only for tests. + */ + private val blockOnCleanupTasks = sc.conf.getBoolean( + "spark.cleaner.referenceTracking.blocking", false) + + @volatile private var stopped = false + + /** Attach a listener object to get information of when objects are cleaned. */ + def attachListener(listener: CleanerListener) { + listeners += listener + } + + /** Start the cleaner. */ + def start() { + cleaningThread.setDaemon(true) + cleaningThread.setName("Spark Context Cleaner") + cleaningThread.start() + } + + /** Stop the cleaner. */ + def stop() { + stopped = true + } + + /** Register a RDD for cleanup when it is garbage collected. */ + def registerRDDForCleanup(rdd: RDD[_]) { + registerForCleanup(rdd, CleanRDD(rdd.id)) + } + + /** Register a ShuffleDependency for cleanup when it is garbage collected. */ + def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) { + registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId)) + } + + /** Register a Broadcast for cleanup when it is garbage collected. */ + def registerBroadcastForCleanup[T](broadcast: Broadcast[T]) { + registerForCleanup(broadcast, CleanBroadcast(broadcast.id)) + } + + /** Register an object for cleanup. */ + private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) { + referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue) + } + + /** Keep cleaning RDD, shuffle, and broadcast state. */ + private def keepCleaning() { + while (!stopped) { + try { + val reference = Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT)) + .map(_.asInstanceOf[CleanupTaskWeakReference]) + reference.map(_.task).foreach { task => + logDebug("Got cleaning task " + task) + referenceBuffer -= reference.get + task match { + case CleanRDD(rddId) => + doCleanupRDD(rddId, blocking = blockOnCleanupTasks) + case CleanShuffle(shuffleId) => + doCleanupShuffle(shuffleId, blocking = blockOnCleanupTasks) + case CleanBroadcast(broadcastId) => + doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks) + } + } + } catch { + case t: Throwable => logError("Error in cleaning thread", t) + } + } + } + + /** Perform RDD cleanup. */ + def doCleanupRDD(rddId: Int, blocking: Boolean) { + try { + logDebug("Cleaning RDD " + rddId) + sc.unpersistRDD(rddId, blocking) + listeners.foreach(_.rddCleaned(rddId)) + logInfo("Cleaned RDD " + rddId) + } catch { + case t: Throwable => logError("Error cleaning RDD " + rddId, t) + } + } + + /** Perform shuffle cleanup, asynchronously. */ + def doCleanupShuffle(shuffleId: Int, blocking: Boolean) { + try { + logDebug("Cleaning shuffle " + shuffleId) + mapOutputTrackerMaster.unregisterShuffle(shuffleId) + blockManagerMaster.removeShuffle(shuffleId, blocking) + listeners.foreach(_.shuffleCleaned(shuffleId)) + logInfo("Cleaned shuffle " + shuffleId) + } catch { + case t: Throwable => logError("Error cleaning shuffle " + shuffleId, t) + } + } + + /** Perform broadcast cleanup. */ + def doCleanupBroadcast(broadcastId: Long, blocking: Boolean) { + try { + logDebug("Cleaning broadcast " + broadcastId) + broadcastManager.unbroadcast(broadcastId, true, blocking) + listeners.foreach(_.broadcastCleaned(broadcastId)) + logInfo("Cleaned broadcast " + broadcastId) + } catch { + case t: Throwable => logError("Error cleaning broadcast " + broadcastId, t) + } + } + + private def blockManagerMaster = sc.env.blockManager.master + private def broadcastManager = sc.env.broadcastManager + private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + + // Used for testing. These methods explicitly blocks until cleanup is completed + // to ensure that more reliable testing. +} + +private object ContextCleaner { + private val REF_QUEUE_POLL_TIMEOUT = 100 +} + +/** + * Listener class used for testing when any item has been cleaned by the Cleaner class. + */ +private[spark] trait CleanerListener { + def rddCleaned(rddId: Int) + def shuffleCleaned(shuffleId: Int) + def broadcastCleaned(broadcastId: Long) +} diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 3132dcf745e19..2c31cc20211ff 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -17,19 +17,24 @@ package org.apache.spark +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer /** + * :: DeveloperApi :: * Base class for dependencies. */ +@DeveloperApi abstract class Dependency[T](val rdd: RDD[T]) extends Serializable /** + * :: DeveloperApi :: * Base class for dependencies where each partition of the parent RDD is used by at most one * partition of the child RDD. Narrow dependencies allow for pipelined execution. */ +@DeveloperApi abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { /** * Get the parent partitions for a child partition. @@ -41,6 +46,7 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { /** + * :: DeveloperApi :: * Represents a dependency on the output of a shuffle stage. * @param rdd the parent RDD * @param partitioner partitioner used to partition the shuffle output @@ -48,6 +54,7 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { * the default serializer, as specified by `spark.serializer` config option, will * be used. */ +@DeveloperApi class ShuffleDependency[K, V]( @transient rdd: RDD[_ <: Product2[K, V]], val partitioner: Partitioner, @@ -55,24 +62,30 @@ class ShuffleDependency[K, V]( extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { val shuffleId: Int = rdd.context.newShuffleId() + + rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) } /** + * :: DeveloperApi :: * Represents a one-to-one dependency between partitions of the parent and child RDDs. */ +@DeveloperApi class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) { override def getParents(partitionId: Int) = List(partitionId) } /** + * :: DeveloperApi :: * Represents a one-to-one dependency between ranges of partitions in the parent and child RDDs. * @param rdd the parent RDD * @param inStart the start of the range in the parent RDD * @param outStart the start of the range in the child RDD * @param length the length of the range */ +@DeveloperApi class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int) extends NarrowDependency[T](rdd) { diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index f2decd14ef6d9..1e4dec86a0530 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -21,13 +21,16 @@ import scala.concurrent._ import scala.concurrent.duration.Duration import scala.util.Try +import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{JobFailed, JobSucceeded, JobWaiter} /** + * :: Experimental :: * A future for the result of an action to support cancellation. This is an extension of the * Scala Future interface to support cancellation. */ +@Experimental trait FutureAction[T] extends Future[T] { // Note that we redefine methods of the Future trait here explicitly so we can specify a different // documentation (with reference to the word "action"). @@ -84,9 +87,11 @@ trait FutureAction[T] extends Future[T] { /** + * :: Experimental :: * A [[FutureAction]] holding the result of an action that triggers a single job. Examples include * count, collect, reduce. */ +@Experimental class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T) extends FutureAction[T] { @@ -141,17 +146,19 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: private def awaitResult(): Try[T] = { jobWaiter.awaitResult() match { case JobSucceeded => scala.util.Success(resultFunc) - case JobFailed(e: Exception, _) => scala.util.Failure(e) + case JobFailed(e: Exception) => scala.util.Failure(e) } } } /** + * :: Experimental :: * A [[FutureAction]] for actions that could trigger multiple Spark jobs. Examples include take, * takeSample. Cancellation works by setting the cancelled flag to true and interrupting the * action thread if it is being blocked by a job. */ +@Experimental class ComplexFutureAction[T] extends FutureAction[T] { // Pointer to the thread that is executing the action. It is set when the action is run. diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index 3d7692ea8a49e..a6e300d345786 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -24,13 +24,13 @@ import com.google.common.io.Files import org.apache.spark.util.Utils private[spark] class HttpFileServer(securityManager: SecurityManager) extends Logging { - + var baseDir : File = null var fileDir : File = null var jarDir : File = null var httpServer : HttpServer = null var serverUri : String = null - + def initialize() { baseDir = Utils.createTempDir() fileDir = new File(baseDir, "files") @@ -43,24 +43,24 @@ private[spark] class HttpFileServer(securityManager: SecurityManager) extends Lo serverUri = httpServer.uri logDebug("HTTP file server started at: " + serverUri) } - + def stop() { httpServer.stop() } - + def addFile(file: File) : String = { addFileToDir(file, fileDir) serverUri + "/files/" + file.getName } - + def addJar(file: File) : String = { addFileToDir(file, jarDir) serverUri + "/jars/" + file.getName } - + def addFileToDir(file: File, dir: File) : String = { Files.copy(file, new File(dir, file.getName)) dir + "/" + file.getName } - + } diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index cb5df25fa48df..7e9b517f901a2 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -83,19 +83,19 @@ private[spark] class HttpServer(resourceBase: File, securityManager: SecurityMan } } - /** + /** * Setup Jetty to the HashLoginService using a single user with our * shared secret. Configure it to use DIGEST-MD5 authentication so that the password * isn't passed in plaintext. */ private def setupSecurityHandler(securityMgr: SecurityManager): ConstraintSecurityHandler = { val constraint = new Constraint() - // use DIGEST-MD5 as the authentication mechanism + // use DIGEST-MD5 as the authentication mechanism constraint.setName(Constraint.__DIGEST_AUTH) constraint.setRoles(Array("user")) constraint.setAuthenticate(true) constraint.setDataConstraint(Constraint.DC_NONE) - + val cm = new ConstraintMapping() cm.setConstraint(constraint) cm.setPathSpec("/*") diff --git a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala index 9b1601d5b95fa..fd1802ba2f984 100644 --- a/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala +++ b/core/src/main/scala/org/apache/spark/InterruptibleIterator.scala @@ -21,7 +21,7 @@ package org.apache.spark * An iterator that wraps around an existing iterator to provide task killing functionality. * It works by checking the interrupted flag in [[TaskContext]]. */ -class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T]) +private[spark] class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T]) extends Iterator[T] { def hasNext: Boolean = !context.interrupted && delegate.hasNext diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 7423082e34f47..50d8e93e1f0d7 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -21,11 +21,19 @@ import org.apache.log4j.{LogManager, PropertyConfigurator} import org.slf4j.{Logger, LoggerFactory} import org.slf4j.impl.StaticLoggerBinder +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.util.Utils + /** + * :: DeveloperApi :: * Utility trait for classes that want to log data. Creates a SLF4J logger for the class and allows * logging messages at different levels using methods that only evaluate parameters lazily if the * log level is enabled. + * + * NOTE: DO NOT USE this class outside of Spark. It is intended as an internal utility. + * This will likely be changed or removed in future releases. */ +@DeveloperApi trait Logging { // Make the log field transient so that objects with Logging can // be serialized and used on another machine @@ -53,7 +61,7 @@ trait Logging { protected def logDebug(msg: => String) { if (log.isDebugEnabled) log.debug(msg) } - + protected def logTrace(msg: => String) { if (log.isTraceEnabled) log.trace(msg) } @@ -108,12 +116,11 @@ trait Logging { val log4jInitialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements if (!log4jInitialized && usingLog4j) { val defaultLogProps = "org/apache/spark/log4j-defaults.properties" - val classLoader = this.getClass.getClassLoader - Option(classLoader.getResource(defaultLogProps)) match { - case Some(url) => + Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match { + case Some(url) => PropertyConfigurator.configure(url) log.info(s"Using Spark's default log4j profile: $defaultLogProps") - case None => + case None => System.err.println(s"Spark was unable to load $defaultLogProps") } } @@ -128,4 +135,16 @@ trait Logging { private object Logging { @volatile private var initialized = false val initLock = new Object() + try { + // We use reflection here to handle the case where users remove the + // slf4j-to-jul bridge order to route their logs to JUL. + val bridgeClass = Class.forName("org.slf4j.bridge.SLF4JBridgeHandler") + bridgeClass.getMethod("removeHandlersForRootLogger").invoke(null) + val installed = bridgeClass.getMethod("isInstalled").invoke(null).asInstanceOf[Boolean] + if (!installed) { + bridgeClass.getMethod("install").invoke(null) + } + } catch { + case e: ClassNotFoundException => // can't log anything yet so just fail silently + } } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 80cbf951cb70e..ee82d9fa7874b 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -20,21 +20,21 @@ package org.apache.spark import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.HashSet +import scala.collection.mutable.{HashSet, HashMap, Map} import scala.concurrent.Await import akka.actor._ import akka.pattern.ask - import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.{AkkaUtils, MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} +import org.apache.spark.util._ private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage +/** Actor class for MapOutputTrackerMaster */ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf) extends Actor with Logging { val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) @@ -65,26 +65,41 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster } } -private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { - +/** + * Class that keeps track of the location of the map output of + * a stage. This is abstract because different versions of MapOutputTracker + * (driver and worker) use different HashMap to store its metadata. + */ +private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging { private val timeout = AkkaUtils.askTimeout(conf) - // Set to the MapOutputTrackerActor living on the driver + /** Set to the MapOutputTrackerActor living on the driver. */ var trackerActor: ActorRef = _ - protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] + /** + * This HashMap has different behavior for the master and the workers. + * + * On the master, it serves as the source of map outputs recorded from ShuffleMapTasks. + * On the workers, it simply serves as a cache, in which a miss triggers a fetch from the + * master's corresponding HashMap. + */ + protected val mapStatuses: Map[Int, Array[MapStatus]] - // Incremented every time a fetch fails so that client nodes know to clear - // their cache of map output locations if this happens. + /** + * Incremented every time a fetch fails so that client nodes know to clear + * their cache of map output locations if this happens. + */ protected var epoch: Long = 0 - protected val epochLock = new java.lang.Object + protected val epochLock = new AnyRef - private val metadataCleaner = - new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf) + /** Remembers which map output locations are currently being fetched on a worker. */ + private val fetching = new HashSet[Int] - // Send a message to the trackerActor and get its result within a default timeout, or - // throw a SparkException if this fails. - private def askTracker(message: Any): Any = { + /** + * Send a message to the trackerActor and get its result within a default timeout, or + * throw a SparkException if this fails. + */ + protected def askTracker(message: Any): Any = { try { val future = trackerActor.ask(message)(timeout) Await.result(future, timeout) @@ -94,17 +109,17 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { } } - // Send a one-way message to the trackerActor, to which we expect it to reply with true. - private def communicate(message: Any) { + /** Send a one-way message to the trackerActor, to which we expect it to reply with true. */ + protected def sendTracker(message: Any) { if (askTracker(message) != true) { throw new SparkException("Error reply received from MapOutputTracker") } } - // Remembers which map output locations are currently being fetched on a worker - private val fetching = new HashSet[Int] - - // Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle + /** + * Called from executors to get the server URIs and output sizes of the map outputs of + * a given shuffle. + */ def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { val statuses = mapStatuses.get(shuffleId).orNull if (statuses == null) { @@ -152,8 +167,7 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { fetchedStatuses.synchronized { return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) } - } - else { + } else { throw new FetchFailedException(null, shuffleId, -1, reduceId, new Exception("Missing all output locations for shuffle " + shuffleId)) } @@ -164,27 +178,18 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { } } - protected def cleanup(cleanupTime: Long) { - mapStatuses.clearOldValues(cleanupTime) - } - - def stop() { - communicate(StopMapOutputTracker) - mapStatuses.clear() - metadataCleaner.cancel() - trackerActor = null - } - - // Called to get current epoch number + /** Called to get current epoch number. */ def getEpoch: Long = { epochLock.synchronized { return epoch } } - // Called on workers to update the epoch number, potentially clearing old outputs - // because of a fetch failure. (Each worker task calls this with the latest epoch - // number on the master at the time it was created.) + /** + * Called from executors to update the epoch number, potentially clearing old outputs + * because of a fetch failure. Each worker task calls this with the latest epoch + * number on the master at the time it was created. + */ def updateEpoch(newEpoch: Long) { epochLock.synchronized { if (newEpoch > epoch) { @@ -194,17 +199,40 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { } } } + + /** Unregister shuffle data. */ + def unregisterShuffle(shuffleId: Int) { + mapStatuses.remove(shuffleId) + } + + /** Stop the tracker. */ + def stop() { } } +/** + * MapOutputTracker for the driver. This uses TimeStampedHashMap to keep track of map + * output information, which allows old output information based on a TTL. + */ private[spark] class MapOutputTrackerMaster(conf: SparkConf) extends MapOutputTracker(conf) { - // Cache a serialized version of the output statuses for each shuffle to send them out faster + /** Cache a serialized version of the output statuses for each shuffle to send them out faster */ private var cacheEpoch = epoch - private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]] + + /** + * Timestamp based HashMap for storing mapStatuses and cached serialized statuses in the master, + * so that statuses are dropped only by explicit de-registering or by TTL-based cleaning (if set). + * Other than these two scenarios, nothing should be dropped from this HashMap. + */ + protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]() + private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]() + + // For cleaning up TimeStampedHashMaps + private val metadataCleaner = + new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf) def registerShuffle(shuffleId: Int, numMaps: Int) { - if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) { + if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") } } @@ -216,6 +244,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) } } + /** Register multiple map output information for the given shuffle */ def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) { mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses) if (changeEpoch) { @@ -223,6 +252,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) } } + /** Unregister map output information of the given shuffle, mapper and block manager */ def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) { val arrayOpt = mapStatuses.get(shuffleId) if (arrayOpt.isDefined && arrayOpt.get != null) { @@ -238,6 +268,17 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) } } + /** Unregister shuffle data */ + override def unregisterShuffle(shuffleId: Int) { + mapStatuses.remove(shuffleId) + cachedSerializedStatuses.remove(shuffleId) + } + + /** Check if the given shuffle is being tracked */ + def containsShuffle(shuffleId: Int): Boolean = { + cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId) + } + def incrementEpoch() { epochLock.synchronized { epoch += 1 @@ -274,23 +315,26 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) bytes } - protected override def cleanup(cleanupTime: Long) { - super.cleanup(cleanupTime) - cachedSerializedStatuses.clearOldValues(cleanupTime) - } - override def stop() { - super.stop() + sendTracker(StopMapOutputTracker) + mapStatuses.clear() + trackerActor = null + metadataCleaner.cancel() cachedSerializedStatuses.clear() } - override def updateEpoch(newEpoch: Long) { - // This might be called on the MapOutputTrackerMaster if we're running in local mode. + private def cleanup(cleanupTime: Long) { + mapStatuses.clearOldValues(cleanupTime) + cachedSerializedStatuses.clearOldValues(cleanupTime) } +} - def has(shuffleId: Int): Boolean = { - cachedSerializedStatuses.get(shuffleId).isDefined || mapStatuses.contains(shuffleId) - } +/** + * MapOutputTracker for the workers, which fetches map output information from the driver's + * MapOutputTrackerMaster. + */ +private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { + protected val mapStatuses = new HashMap[Int, Array[MapStatus]] } private[spark] object MapOutputTracker { diff --git a/core/src/main/scala/org/apache/spark/Partition.scala b/core/src/main/scala/org/apache/spark/Partition.scala index 87914a061f5d7..27892dbd2a0bc 100644 --- a/core/src/main/scala/org/apache/spark/Partition.scala +++ b/core/src/main/scala/org/apache/spark/Partition.scala @@ -25,7 +25,7 @@ trait Partition extends Serializable { * Get the split's index within its parent RDD */ def index: Int - + // A better default implementation of HashCode override def hashCode(): Int = index } diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 2237ee3bb7aad..b52f2d4f416b2 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -25,93 +25,93 @@ import org.apache.hadoop.io.Text import org.apache.spark.deploy.SparkHadoopUtil -/** - * Spark class responsible for security. - * +/** + * Spark class responsible for security. + * * In general this class should be instantiated by the SparkEnv and most components - * should access it from that. There are some cases where the SparkEnv hasn't been + * should access it from that. There are some cases where the SparkEnv hasn't been * initialized yet and this class must be instantiated directly. - * + * * Spark currently supports authentication via a shared secret. * Authentication can be configured to be on via the 'spark.authenticate' configuration - * parameter. This parameter controls whether the Spark communication protocols do + * parameter. This parameter controls whether the Spark communication protocols do * authentication using the shared secret. This authentication is a basic handshake to * make sure both sides have the same shared secret and are allowed to communicate. - * If the shared secret is not identical they will not be allowed to communicate. - * - * The Spark UI can also be secured by using javax servlet filters. A user may want to - * secure the UI if it has data that other users should not be allowed to see. The javax - * servlet filter specified by the user can authenticate the user and then once the user - * is logged in, Spark can compare that user versus the view acls to make sure they are - * authorized to view the UI. The configs 'spark.ui.acls.enable' and 'spark.ui.view.acls' + * If the shared secret is not identical they will not be allowed to communicate. + * + * The Spark UI can also be secured by using javax servlet filters. A user may want to + * secure the UI if it has data that other users should not be allowed to see. The javax + * servlet filter specified by the user can authenticate the user and then once the user + * is logged in, Spark can compare that user versus the view acls to make sure they are + * authorized to view the UI. The configs 'spark.ui.acls.enable' and 'spark.ui.view.acls' * control the behavior of the acls. Note that the person who started the application * always has view access to the UI. * * Spark does not currently support encryption after authentication. - * + * * At this point spark has multiple communication protocols that need to be secured and * different underlying mechanisms are used depending on the protocol: * - * - Akka -> The only option here is to use the Akka Remote secure-cookie functionality. - * Akka remoting allows you to specify a secure cookie that will be exchanged - * and ensured to be identical in the connection handshake between the client - * and the server. If they are not identical then the client will be refused - * to connect to the server. There is no control of the underlying - * authentication mechanism so its not clear if the password is passed in + * - Akka -> The only option here is to use the Akka Remote secure-cookie functionality. + * Akka remoting allows you to specify a secure cookie that will be exchanged + * and ensured to be identical in the connection handshake between the client + * and the server. If they are not identical then the client will be refused + * to connect to the server. There is no control of the underlying + * authentication mechanism so its not clear if the password is passed in * plaintext or uses DIGEST-MD5 or some other mechanism. * Akka also has an option to turn on SSL, this option is not currently supported * but we could add a configuration option in the future. - * - * - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty - * for the HttpServer. Jetty supports multiple authentication mechanisms - - * Basic, Digest, Form, Spengo, etc. It also supports multiple different login + * + * - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty + * for the HttpServer. Jetty supports multiple authentication mechanisms - + * Basic, Digest, Form, Spengo, etc. It also supports multiple different login * services - Hash, JAAS, Spnego, JDBC, etc. Spark currently uses the HashLoginService - * to authenticate using DIGEST-MD5 via a single user and the shared secret. + * to authenticate using DIGEST-MD5 via a single user and the shared secret. * Since we are using DIGEST-MD5, the shared secret is not passed on the wire * in plaintext. * We currently do not support SSL (https), but Jetty can be configured to use it * so we could add a configuration option for this in the future. - * + * * The Spark HttpServer installs the HashLoginServer and configures it to DIGEST-MD5. - * Any clients must specify the user and password. There is a default + * Any clients must specify the user and password. There is a default * Authenticator installed in the SecurityManager to how it does the authentication * and in this case gets the user name and password from the request. * - * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously - * exchange messages. For this we use the Java SASL - * (Simple Authentication and Security Layer) API and again use DIGEST-MD5 + * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously + * exchange messages. For this we use the Java SASL + * (Simple Authentication and Security Layer) API and again use DIGEST-MD5 * as the authentication mechanism. This means the shared secret is not passed * over the wire in plaintext. * Note that SASL is pluggable as to what mechanism it uses. We currently use * DIGEST-MD5 but this could be changed to use Kerberos or other in the future. * Spark currently supports "auth" for the quality of protection, which means * the connection is not supporting integrity or privacy protection (encryption) - * after authentication. SASL also supports "auth-int" and "auth-conf" which + * after authentication. SASL also supports "auth-int" and "auth-conf" which * SPARK could be support in the future to allow the user to specify the quality - * of protection they want. If we support those, the messages will also have to + * of protection they want. If we support those, the messages will also have to * be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's. - * - * Since the connectionManager does asynchronous messages passing, the SASL + * + * Since the connectionManager does asynchronous messages passing, the SASL * authentication is a bit more complex. A ConnectionManager can be both a client * and a Server, so for a particular connection is has to determine what to do. - * A ConnectionId was added to be able to track connections and is used to + * A ConnectionId was added to be able to track connections and is used to * match up incoming messages with connections waiting for authentication. * If its acting as a client and trying to send a message to another ConnectionManager, * it blocks the thread calling sendMessage until the SASL negotiation has occurred. * The ConnectionManager tracks all the sendingConnections using the ConnectionId * and waits for the response from the server and does the handshake. * - * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters + * - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters * can be used. Yarn requires a specific AmIpFilter be installed for security to work * properly. For non-Yarn deployments, users can write a filter to go through a * companies normal login service. If an authentication filter is in place then the * SparkUI can be configured to check the logged in user against the list of users who * have view acls to see if that user is authorized. - * The filters can also be used for many different purposes. For instance filters + * The filters can also be used for many different purposes. For instance filters * could be used for logging, encryption, or compression. - * + * * The exact mechanisms used to generate/distributed the shared secret is deployment specific. - * + * * For Yarn deployments, the secret is automatically generated using the Akka remote * Crypt.generateSecureCookie() API. The secret is placed in the Hadoop UGI which gets passed * around via the Hadoop RPC mechanism. Hadoop RPC can be configured to support different levels @@ -121,7 +121,7 @@ import org.apache.spark.deploy.SparkHadoopUtil * to reduce the possibility of web based attacks through YARN. Hadoop can be configured to use * filters to do authentication. That authentication then happens via the ResourceManager Proxy * and Spark will use that to do authorization against the view acls. - * + * * For other Spark deployments, the shared secret must be specified via the * spark.authenticate.secret config. * All the nodes (Master and Workers) and the applications need to have the same shared secret. @@ -152,7 +152,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { " are ui acls enabled: " + uiAclsOn + " users with view permissions: " + viewAcls.toString()) // Set our own authenticator to properly negotiate user/password for HTTP connections. - // This is needed by the HTTP client fetching from the HttpServer. Put here so its + // This is needed by the HTTP client fetching from the HttpServer. Put here so its // only set once. if (authOn) { Authenticator.setDefault( @@ -214,12 +214,12 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging { def uiAclsEnabled(): Boolean = uiAclsOn /** - * Checks the given user against the view acl list to see if they have + * Checks the given user against the view acl list to see if they have * authorization to view the UI. If the UI acls must are disabled * via spark.ui.acls.enable, all users have view access. - * + * * @param user to see if is authorized - * @return true is the user has permission, otherwise false + * @return true is the user has permission, otherwise false */ def checkUIViewPermissions(user: String): Boolean = { if (uiAclsEnabled() && (user != null) && (!viewAcls.contains(user))) false else true diff --git a/core/src/main/scala/org/apache/spark/SerializableWritable.scala b/core/src/main/scala/org/apache/spark/SerializableWritable.scala index dff665cae6cb6..e50b9ac2291f9 100644 --- a/core/src/main/scala/org/apache/spark/SerializableWritable.scala +++ b/core/src/main/scala/org/apache/spark/SerializableWritable.scala @@ -23,6 +23,9 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.ObjectWritable import org.apache.hadoop.io.Writable +import org.apache.spark.annotation.DeveloperApi + +@DeveloperApi class SerializableWritable[T <: Writable](@transient var t: T) extends Serializable { def value = t override def toString = t.toString diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b23accbbb9410..456070fa7c5ef 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -19,14 +19,13 @@ package org.apache.spark import java.io._ import java.net.URI -import java.util.{Properties, UUID} import java.util.concurrent.atomic.AtomicInteger - +import java.util.{Properties, UUID} +import java.util.UUID.randomUUID import scala.collection.{Map, Set} import scala.collection.generic.Growable import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.reflect.{ClassTag, classTag} - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable} @@ -35,8 +34,10 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.mesos.MesosNativeLibrary +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} +import org.apache.spark.input.WholeTextFileInputFormat import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ import org.apache.spark.scheduler._ @@ -45,25 +46,38 @@ import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, Me import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils} +import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils} /** + * :: DeveloperApi :: * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark * cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster. * * @param config a Spark Config object describing the application configuration. Any settings in * this config overrides the default configs as well as system properties. - * @param preferredNodeLocationData used in YARN mode to select nodes to launch containers on. Can - * be generated using [[org.apache.spark.scheduler.InputFormatInfo.computePreferredLocations]] - * from a list of input files or InputFormats for the application. */ -class SparkContext( - config: SparkConf, - // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, - // etc) too. This is typically generated from InputFormatInfo.computePreferredLocations. It - // contains a map from hostname to a list of input format splits on the host. - val preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map()) - extends Logging { + +@DeveloperApi +class SparkContext(config: SparkConf) extends Logging { + + // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, + // etc) too. This is typically generated from InputFormatInfo.computePreferredLocations. It + // contains a map from hostname to a list of input format splits on the host. + private[spark] var preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map() + + /** + * :: DeveloperApi :: + * Alternative constructor for setting preferred locations where Spark will create executors. + * + * @param preferredNodeLocationData used in YARN mode to select nodes to launch containers on. Ca + * be generated using [[org.apache.spark.scheduler.InputFormatInfo.computePreferredLocations]] + * from a list of input files or InputFormats for the application. + */ + @DeveloperApi + def this(config: SparkConf, preferredNodeLocationData: Map[String, Set[SplitInfo]]) = { + this(config) + this.preferredNodeLocationData = preferredNodeLocationData + } /** * Alternative constructor that allows setting common Spark properties directly @@ -93,10 +107,45 @@ class SparkContext( environment: Map[String, String] = Map(), preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map()) = { - this(SparkContext.updatedConf(new SparkConf(), master, appName, sparkHome, jars, environment), - preferredNodeLocationData) + this(SparkContext.updatedConf(new SparkConf(), master, appName, sparkHome, jars, environment)) + this.preferredNodeLocationData = preferredNodeLocationData } + // NOTE: The below constructors could be consolidated using default arguments. Due to + // Scala bug SI-8479, however, this causes the compile step to fail when generating docs. + // Until we have a good workaround for that bug the constructors remain broken out. + + /** + * Alternative constructor that allows setting common Spark properties directly + * + * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). + * @param appName A name for your application, to display on the cluster web UI. + */ + private[spark] def this(master: String, appName: String) = + this(master, appName, null, Nil, Map(), Map()) + + /** + * Alternative constructor that allows setting common Spark properties directly + * + * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). + * @param appName A name for your application, to display on the cluster web UI. + * @param sparkHome Location where Spark is installed on cluster nodes. + */ + private[spark] def this(master: String, appName: String, sparkHome: String) = + this(master, appName, sparkHome, Nil, Map(), Map()) + + /** + * Alternative constructor that allows setting common Spark properties directly + * + * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). + * @param appName A name for your application, to display on the cluster web UI. + * @param sparkHome Location where Spark is installed on cluster nodes. + * @param jars Collection of JARs to send to the cluster. These can be paths on the local file + * system or HDFS, HTTP, HTTPS, or FTP URLs. + */ + private[spark] def this(master: String, appName: String, sparkHome: String, jars: Seq[String]) = + this(master, appName, sparkHome, jars, Map(), Map()) + private[spark] val conf = config.clone() /** @@ -129,6 +178,11 @@ class SparkContext( val master = conf.get("spark.master") val appName = conf.get("spark.app.name") + // Generate the random name for a temp folder in Tachyon + // Add a timestamp as the suffix here to make it more safe + val tachyonFolderName = "spark-" + randomUUID.toString() + conf.set("spark.tachyonStore.folderName", tachyonFolderName) + val isLocal = (master == "local" || master.startsWith("local[")) if (master == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true") @@ -152,28 +206,24 @@ class SparkContext( private[spark] val addedJars = HashMap[String, Long]() // Keeps track of all persisted RDDs - private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]] + private[spark] val persistentRdds = new TimeStampedWeakValueHashMap[Int, RDD[_]] private[spark] val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf) // Initialize the Spark UI, registering all associated listeners private[spark] val ui = new SparkUI(this) ui.bind() - ui.start() // Optionally log Spark events private[spark] val eventLogger: Option[EventLoggingListener] = { if (conf.getBoolean("spark.eventLog.enabled", false)) { val logger = new EventLoggingListener(appName, conf) + logger.start() listenerBus.addListener(logger) Some(logger) } else None } - // Information needed to replay logged events, if any - private[spark] val eventLoggingInfo: Option[EventLoggingInfo] = - eventLogger.map { logger => Some(logger.info) }.getOrElse(None) - // At this point, all relevant SparkListeners have been registered, so begin releasing events listenerBus.start() @@ -184,7 +234,7 @@ class SparkContext( jars.foreach(addJar) } - def warnSparkMem(value: String): String = { + private def warnSparkMem(value: String): String = { logWarning("Using SPARK_MEM to set amount of memory to use per executor process is " + "deprecated, please use spark.executor.memory instead.") value @@ -228,7 +278,17 @@ class SparkContext( @volatile private[spark] var dagScheduler = new DAGScheduler(this) dagScheduler.start() + private[spark] val cleaner: Option[ContextCleaner] = { + if (conf.getBoolean("spark.cleaner.referenceTracking", true)) { + Some(new ContextCleaner(this)) + } else { + None + } + } + cleaner.foreach(_.start()) + postEnvironmentUpdate() + postApplicationStart() /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ val hadoopConfiguration: Configuration = { @@ -371,6 +431,46 @@ class SparkContext( minSplits).map(pair => pair._2.toString) } + /** + * Read a directory of text files from HDFS, a local file system (available on all nodes), or any + * Hadoop-supported file system URI. Each file is read as a single record and returned in a + * key-value pair, where the key is the path of each file, the value is the content of each file. + * + *

For example, if you have the following files: + * {{{ + * hdfs://a-hdfs-path/part-00000 + * hdfs://a-hdfs-path/part-00001 + * ... + * hdfs://a-hdfs-path/part-nnnnn + * }}} + * + * Do `val rdd = sparkContext.wholeTextFile("hdfs://a-hdfs-path")`, + * + *

then `rdd` contains + * {{{ + * (a-hdfs-path/part-00000, its content) + * (a-hdfs-path/part-00001, its content) + * ... + * (a-hdfs-path/part-nnnnn, its content) + * }}} + * + * @note Small files are preferred, large file is also allowable, but may cause bad performance. + * + * @param minSplits A suggestion value of the minimal splitting number for input data. + */ + def wholeTextFiles(path: String, minSplits: Int = defaultMinSplits): RDD[(String, String)] = { + val job = new NewHadoopJob(hadoopConfiguration) + NewFileInputFormat.addInputPath(job, new Path(path)) + val updateConf = job.getConfiguration + new WholeTextFileRDD( + this, + classOf[WholeTextFileInputFormat], + classOf[String], + classOf[String], + updateConf, + minSplits) + } + /** * Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and other * necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable), @@ -606,6 +706,9 @@ class SparkContext( def union[T: ClassTag](first: RDD[T], rest: RDD[T]*): RDD[T] = new UnionRDD(this, Seq(first) ++ rest) + /** Get an RDD that has no partitions or elements. */ + def emptyRDD[T: ClassTag] = new EmptyRDD[T](this) + // Methods for creating shared variables /** @@ -641,7 +744,11 @@ class SparkContext( * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. * The variable will be sent to each cluster only once. */ - def broadcast[T](value: T): Broadcast[T] = env.broadcastManager.newBroadcast[T](value, isLocal) + def broadcast[T](value: T): Broadcast[T] = { + val bc = env.broadcastManager.newBroadcast[T](value, isLocal) + cleaner.foreach(_.registerBroadcastForCleanup(bc)) + bc + } /** * Add a file to be downloaded with this Spark job on every node. @@ -665,10 +772,18 @@ class SparkContext( postEnvironmentUpdate() } + /** + * :: DeveloperApi :: + * Register a listener to receive up-calls from events that happen during execution. + */ + @DeveloperApi def addSparkListener(listener: SparkListener) { listenerBus.addListener(listener) } + /** The version of Spark on which this application is running. */ + def version = SparkContext.SPARK_VERSION + /** * Return a map from the slave to the max memory available for caching and the remaining * memory available for caching. @@ -693,10 +808,6 @@ class SparkContext( */ def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap - def getStageInfo: Map[Stage, StageInfo] = { - dagScheduler.stageToInfos - } - /** * Return information about blocks stored in all of the slaves */ @@ -755,8 +866,7 @@ class SparkContext( /** * Unpersist an RDD from memory and/or disk storage */ - private[spark] def unpersistRDD(rdd: RDD[_], blocking: Boolean = true) { - val rddId = rdd.id + private[spark] def unpersistRDD(rddId: Int, blocking: Boolean = true) { env.blockManager.master.removeRdd(rddId, blocking) persistentRdds.remove(rddId) listenerBus.post(SparkListenerUnpersistRDD(rddId)) @@ -827,22 +937,24 @@ class SparkContext( /** Shut down the SparkContext. */ def stop() { + postApplicationEnd() ui.stop() - eventLogger.foreach(_.stop()) // Do this only if not stopped already - best case effort. // prevent NPE if stopped more than once. val dagSchedulerCopy = dagScheduler dagScheduler = null if (dagSchedulerCopy != null) { metadataCleaner.cancel() + cleaner.foreach(_.stop()) dagSchedulerCopy.stop() - listenerBus.stop() taskScheduler = null // TODO: Cache.stop()? env.stop() SparkEnv.set(null) ShuffleMapTask.clearCache() ResultTask.clearCache() + listenerBus.stop() + eventLogger.foreach(_.stop()) logInfo("Successfully stopped SparkContext") } else { logInfo("SparkContext already stopped") @@ -974,8 +1086,10 @@ class SparkContext( } /** + * :: DeveloperApi :: * Run a job that can return approximate results. */ + @DeveloperApi def runApproximateJob[T, U, R]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, @@ -993,6 +1107,7 @@ class SparkContext( /** * Submit a job for execution and return a FutureJob holding the result. */ + @Experimental def submitJob[T, U, R]( rdd: RDD[T], processPartition: Iterator[T] => U, @@ -1029,6 +1144,16 @@ class SparkContext( dagScheduler.cancelAllJobs() } + /** Cancel a given job if it's scheduled or running */ + private[spark] def cancelJob(jobId: Int) { + dagScheduler.cancelJob(jobId) + } + + /** Cancel a given stage and all jobs associated with it */ + private[spark] def cancelStage(stageId: Int) { + dagScheduler.cancelStage(stageId) + } + /** * Clean a closure to make it ready to serialized and send to tasks * (removes unreferenced variables in $outer's, updates REPL variables) @@ -1068,6 +1193,16 @@ class SparkContext( /** Register a new RDD, returning its RDD ID */ private[spark] def newRddId(): Int = nextRddId.getAndIncrement() + /** Post the application start event */ + private def postApplicationStart() { + listenerBus.post(SparkListenerApplicationStart(appName, startTime, sparkUser)) + } + + /** Post the application end event */ + private def postApplicationEnd() { + listenerBus.post(SparkListenerApplicationEnd(System.currentTimeMillis)) + } + /** Post the environment update event once the task scheduler is ready */ private def postEnvironmentUpdate() { if (taskScheduler != null) { @@ -1093,6 +1228,8 @@ class SparkContext( */ object SparkContext extends Logging { + private[spark] val SPARK_VERSION = "1.0.0" + private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description" private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id" @@ -1251,8 +1388,8 @@ object SparkContext extends Logging { /** Creates a task scheduler based on a given master URL. Extracted for testing. */ private def createTaskScheduler(sc: SparkContext, master: String): TaskScheduler = { - // Regular expression used for local[N] master format - val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r + // Regular expression used for local[N] and local[*] master formats + val LOCAL_N_REGEX = """local\[([0-9\*]+)\]""".r // Regular expression for local[N, maxRetries], used in tests with failing tasks val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r // Regular expression for simulating a Spark cluster of [N, cores, memory] locally @@ -1275,8 +1412,11 @@ object SparkContext extends Logging { scheduler case LOCAL_N_REGEX(threads) => + def localCpuCount = Runtime.getRuntime.availableProcessors() + // local[*] estimates the number of cores on the machine; local[N] uses exactly N threads. + val threadCount = if (threads == "*") localCpuCount else threads.toInt val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true) - val backend = new LocalBackend(scheduler, threads.toInt) + val backend = new LocalBackend(scheduler, threadCount) scheduler.initialize(backend) scheduler diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 5ceac28fe7afb..915315ed74436 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -25,6 +25,7 @@ import scala.util.Properties import akka.actor._ import com.google.common.collect.MapMaker +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.metrics.MetricsSystem @@ -35,13 +36,18 @@ import org.apache.spark.storage._ import org.apache.spark.util.{AkkaUtils, Utils} /** + * :: DeveloperApi :: * Holds all the runtime environment objects for a running Spark instance (either master or worker), * including the serializer, Akka actor system, block manager, map output tracker, etc. Currently * Spark code finds the SparkEnv through a thread-local variable, so each thread that accesses these * objects needs to have the right SparkEnv set. You can get the current environment with * SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set. + * + * NOTE: This is not intended for external use. This is exposed for Shark and may be made private + * in a future release. */ -class SparkEnv private[spark] ( +@DeveloperApi +class SparkEnv ( val executorId: String, val actorSystem: ActorSystem, val serializer: Serializer, @@ -180,12 +186,24 @@ object SparkEnv extends Logging { } } + val mapOutputTracker = if (isDriver) { + new MapOutputTrackerMaster(conf) + } else { + new MapOutputTrackerWorker(conf) + } + + // Have to assign trackerActor after initialization as MapOutputTrackerActor + // requires the MapOutputTracker itself + mapOutputTracker.trackerActor = registerOrLookup( + "MapOutputTracker", + new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) + val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf) val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, - serializer, conf, securityManager) + serializer, conf, securityManager, mapOutputTracker) val connectionManager = blockManager.connectionManager @@ -193,17 +211,6 @@ object SparkEnv extends Logging { val cacheManager = new CacheManager(blockManager) - // Have to assign trackerActor after initialization as MapOutputTrackerActor - // requires the MapOutputTracker itself - val mapOutputTracker = if (isDriver) { - new MapOutputTrackerMaster(conf) - } else { - new MapOutputTracker(conf) - } - mapOutputTracker.trackerActor = registerOrLookup( - "MapOutputTracker", - new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) - val shuffleFetcher = instantiateClass[ShuffleFetcher]( "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher") diff --git a/core/src/main/scala/org/apache/spark/SparkException.scala b/core/src/main/scala/org/apache/spark/SparkException.scala index d34e47e8cac22..4351ed74b67fc 100644 --- a/core/src/main/scala/org/apache/spark/SparkException.scala +++ b/core/src/main/scala/org/apache/spark/SparkException.scala @@ -20,5 +20,5 @@ package org.apache.spark class SparkException(message: String, cause: Throwable) extends Exception(message, cause) { - def this(message: String) = this(message, null) + def this(message: String) = this(message, null) } diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index b92ea01a877f7..f6703986bdf11 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -42,7 +42,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) private val now = new Date() private val conf = new SerializableWritable(jobConf) - + private var jobID = 0 private var splitID = 0 private var attemptID = 0 @@ -58,8 +58,8 @@ class SparkHadoopWriter(@transient jobConf: JobConf) def preSetup() { setIDs(0, 0, 0) HadoopRDD.addLocalConfiguration("", 0, 0, 0, conf.value) - - val jCtxt = getJobContext() + + val jCtxt = getJobContext() getOutputCommitter().setupJob(jCtxt) } @@ -74,7 +74,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) val numfmt = NumberFormat.getInstance() numfmt.setMinimumIntegerDigits(5) numfmt.setGroupingUsed(false) - + val outputName = "part-" + numfmt.format(splitID) val path = FileOutputFormat.getOutputPath(conf.value) val fs: FileSystem = { @@ -85,7 +85,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) } } - getOutputCommitter().setupTask(getTaskContext()) + getOutputCommitter().setupTask(getTaskContext()) writer = getOutputFormat().getRecordWriter(fs, conf.value, outputName, Reporter.NULL) } @@ -103,18 +103,18 @@ class SparkHadoopWriter(@transient jobConf: JobConf) def commit() { val taCtxt = getTaskContext() - val cmtr = getOutputCommitter() + val cmtr = getOutputCommitter() if (cmtr.needsTaskCommit(taCtxt)) { try { cmtr.commitTask(taCtxt) logInfo (taID + ": Committed") } catch { - case e: IOException => { + case e: IOException => { logError("Error committing the output of task: " + taID.value, e) cmtr.abortTask(taCtxt) throw e } - } + } } else { logWarning ("No need to commit output of task: " + taID.value) } @@ -144,7 +144,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) } private def getJobContext(): JobContext = { - if (jobContext == null) { + if (jobContext == null) { jobContext = newJobContext(conf.value, jID.value) } jobContext @@ -175,7 +175,7 @@ object SparkHadoopWriter { val jobtrackerID = formatter.format(time) new JobID(jobtrackerID, id) } - + def createPathFromString(path: String, conf: JobConf): Path = { if (path == null) { throw new IllegalArgumentException("Output path is null") diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala index a2a871cbd3c31..5b14c4291d91a 100644 --- a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala +++ b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala @@ -44,12 +44,12 @@ private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logg * configurable in the future. */ private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST), - null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, + null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS, new SparkSaslClientCallbackHandler(securityMgr)) /** * Used to initiate SASL handshake with server. - * @return response to challenge if needed + * @return response to challenge if needed */ def firstToken(): Array[Byte] = { synchronized { @@ -86,7 +86,7 @@ private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logg } /** - * Disposes of any system resources or security-sensitive information the + * Disposes of any system resources or security-sensitive information the * SaslClient might be using. */ def dispose() { @@ -110,7 +110,7 @@ private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logg private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends CallbackHandler { - private val userName: String = + private val userName: String = SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes()) private val secretKey = securityMgr.getSecretKey() private val userPassword: Array[Char] = @@ -138,7 +138,7 @@ private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logg rc.setText(rc.getDefaultText()) } case cb: RealmChoiceCallback => {} - case cb: Callback => throw + case cb: Callback => throw new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback") } } diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala index 11fcb2ae3a5c5..6161a6fb7ae85 100644 --- a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala +++ b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala @@ -64,7 +64,7 @@ private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Loggi } /** - * Disposes of any system resources or security-sensitive information the + * Disposes of any system resources or security-sensitive information the * SaslServer might be using. */ def dispose() { @@ -88,7 +88,7 @@ private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Loggi private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager) extends CallbackHandler { - private val userName: String = + private val userName: String = SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes()) override def handle(callbacks: Array[Callback]) { @@ -123,7 +123,7 @@ private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Loggi ac.setAuthorizedID(authzid) } } - case cb: Callback => throw + case cb: Callback => throw new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback") } } diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index be53ca2968cfb..dc5a19ecd738e 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -19,8 +19,14 @@ package org.apache.spark import scala.collection.mutable.ArrayBuffer +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +/** + * :: DeveloperApi :: + * Contextual information about a task which can be read or mutated during execution. + */ +@DeveloperApi class TaskContext( val stageId: Int, val partitionId: Int, diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index f1a753b6ab8a9..a3074916d13e7 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -17,29 +17,35 @@ package org.apache.spark +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId /** + * :: DeveloperApi :: * Various possible reasons why a task ended. The low-level TaskScheduler is supposed to retry * tasks several times for "ephemeral" failures, and only report back failures that require some * old stages to be resubmitted, such as shuffle map fetch failures. */ -private[spark] sealed trait TaskEndReason +@DeveloperApi +sealed trait TaskEndReason -private[spark] case object Success extends TaskEndReason +@DeveloperApi +case object Success extends TaskEndReason -private[spark] +@DeveloperApi case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it -private[spark] case class FetchFailed( +@DeveloperApi +case class FetchFailed( bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason -private[spark] case class ExceptionFailure( +@DeveloperApi +case class ExceptionFailure( className: String, description: String, stackTrace: Array[StackTraceElement], @@ -47,21 +53,28 @@ private[spark] case class ExceptionFailure( extends TaskEndReason /** + * :: DeveloperApi :: * The task finished successfully, but the result was lost from the executor's block manager before * it was fetched. */ -private[spark] case object TaskResultLost extends TaskEndReason +@DeveloperApi +case object TaskResultLost extends TaskEndReason -private[spark] case object TaskKilled extends TaskEndReason +@DeveloperApi +case object TaskKilled extends TaskEndReason /** + * :: DeveloperApi :: * The task failed because the executor that it was running on was lost. This may happen because * the task crashed the JVM. */ -private[spark] case object ExecutorLostFailure extends TaskEndReason +@DeveloperApi +case object ExecutorLostFailure extends TaskEndReason /** + * :: DeveloperApi :: * We don't know why the task ended -- for example, because of a ClassNotFound exception when * deserializing the task result. */ -private[spark] case object UnknownReason extends TaskEndReason +@DeveloperApi +case object UnknownReason extends TaskEndReason diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala new file mode 100644 index 0000000000000..f3f59e47c3e98 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.io.{File, FileInputStream, FileOutputStream} +import java.net.{URI, URL} +import java.util.jar.{JarEntry, JarOutputStream} + +import scala.collection.JavaConversions._ + +import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider} +import com.google.common.io.Files + +/** + * Utilities for tests. Included in main codebase since it's used by multiple + * projects. + * + * TODO: See if we can move this to the test codebase by specifying + * test dependencies between projects. + */ +private[spark] object TestUtils { + + /** + * Create a jar that defines classes with the given names. + * + * Note: if this is used during class loader tests, class names should be unique + * in order to avoid interference between tests. + */ + def createJarWithClasses(classNames: Seq[String], value: String = ""): URL = { + val tempDir = Files.createTempDir() + val files = for (name <- classNames) yield createCompiledClass(name, tempDir, value) + val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis())) + createJar(files, jarFile) + } + + + /** + * Create a jar file that contains this set of files. All files will be located at the root + * of the jar. + */ + def createJar(files: Seq[File], jarFile: File): URL = { + val jarFileStream = new FileOutputStream(jarFile) + val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest()) + + for (file <- files) { + val jarEntry = new JarEntry(file.getName) + jarStream.putNextEntry(jarEntry) + + val in = new FileInputStream(file) + val buffer = new Array[Byte](10240) + var nRead = 0 + while (nRead <= 0) { + nRead = in.read(buffer, 0, buffer.length) + jarStream.write(buffer, 0, nRead) + } + in.close() + } + jarStream.close() + jarFileStream.close() + + jarFile.toURI.toURL + } + + // Adapted from the JavaCompiler.java doc examples + private val SOURCE = JavaFileObject.Kind.SOURCE + private def createURI(name: String) = { + URI.create(s"string:///${name.replace(".", "/")}${SOURCE.extension}") + } + + private class JavaSourceFromString(val name: String, val code: String) + extends SimpleJavaFileObject(createURI(name), SOURCE) { + override def getCharContent(ignoreEncodingErrors: Boolean) = code + } + + /** Creates a compiled class with the given name. Class file will be placed in destDir. */ + def createCompiledClass(className: String, destDir: File, value: String = ""): File = { + val compiler = ToolProvider.getSystemJavaCompiler + val sourceFile = new JavaSourceFromString(className, + "public class " + className + " { @Override public String toString() { " + + "return \"" + value + "\";}}") + + // Calling this outputs a class file in pwd. It's easier to just rename the file than + // build a custom FileManager that controls the output location. + compiler.getTask(null, null, null, null, null, Seq(sourceFile)).call() + + val fileName = className + ".class" + val result = new File(fileName) + if (!result.exists()) throw new Exception("Compiled file not found: " + fileName) + val out = new File(destDir, fileName) + result.renameTo(out) + out + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/MatrixSVD.scala b/core/src/main/scala/org/apache/spark/annotation/AlphaComponent.java similarity index 68% rename from mllib/src/main/scala/org/apache/spark/mllib/linalg/MatrixSVD.scala rename to core/src/main/scala/org/apache/spark/annotation/AlphaComponent.java index 319f82b449096..af01fb7cfbd04 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/MatrixSVD.scala +++ b/core/src/main/scala/org/apache/spark/annotation/AlphaComponent.java @@ -15,15 +15,12 @@ * limitations under the License. */ -package org.apache.spark.mllib.linalg +package org.apache.spark.annotation; -/** - * Class that represents the SV decomposition of a matrix - * - * @param U such that A = USV^T - * @param S such that A = USV^T - * @param V such that A = USV^T - */ -case class MatrixSVD(val U: SparseMatrix, - val S: SparseMatrix, - val V: SparseMatrix) +import java.lang.annotation.*; + +/** A new component of Spark which may have unstable API's. */ +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, + ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) +public @interface AlphaComponent {} diff --git a/core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java b/core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java new file mode 100644 index 0000000000000..5d546e7a63985 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/annotation/DeveloperApi.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.annotation; + +import java.lang.annotation.*; + +/** + * A lower-level, unstable API intended for developers. + * + * Developer API's might change or be removed in minor versions of Spark. + */ +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, + ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) +public @interface DeveloperApi {} diff --git a/core/src/main/scala/org/apache/spark/annotation/Experimental.java b/core/src/main/scala/org/apache/spark/annotation/Experimental.java new file mode 100644 index 0000000000000..306b1418d8d0a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/annotation/Experimental.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.annotation; + +import java.lang.annotation.*; + +/** + * An experimental user-facing API. + * + * Experimental API's might change or be removed in minor versions of Spark, or be adopted as + * first-class Spark API's. + */ +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER, + ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE}) +public @interface Experimental {} diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index f816bb43a5b44..537f410b0ca26 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -23,6 +23,7 @@ import scala.reflect.ClassTag import org.apache.spark.Partitioner import org.apache.spark.SparkContext.doubleRDDToDoubleRDDFunctions +import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.RDD @@ -184,14 +185,26 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja def meanApprox(timeout: Long, confidence: JDouble): PartialResult[BoundedDouble] = srdd.meanApprox(timeout, confidence) - /** (Experimental) Approximate operation to return the mean within a timeout. */ + /** + * :: Experimental :: + * Approximate operation to return the mean within a timeout. + */ + @Experimental def meanApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.meanApprox(timeout) - /** (Experimental) Approximate operation to return the sum within a timeout. */ + /** + * :: Experimental :: + * Approximate operation to return the sum within a timeout. + */ + @Experimental def sumApprox(timeout: Long, confidence: JDouble): PartialResult[BoundedDouble] = srdd.sumApprox(timeout, confidence) - /** (Experimental) Approximate operation to return the sum within a timeout. */ + /** + * :: Experimental :: + * Approximate operation to return the sum within a timeout. + */ + @Experimental def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout) /** diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 9596dbaf75488..a41c7dbda2afc 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -18,6 +18,7 @@ package org.apache.spark.api.java import java.util.{Comparator, List => JList} +import java.lang.{Iterable => JIterable} import scala.collection.JavaConversions._ import scala.reflect.ClassTag @@ -26,11 +27,12 @@ import com.google.common.base.Optional import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{JobConf, OutputFormat} -import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, Job} +import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} import org.apache.spark.{HashPartitioner, Partitioner} import org.apache.spark.Partitioner._ import org.apache.spark.SparkContext.rddToPairRDDFunctions +import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, PairFunction} import org.apache.spark.partial.{BoundedDouble, PartialResult} @@ -200,16 +202,20 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey()) /** - * (Experimental) Approximate version of countByKey that can return a partial result if it does + * :: Experimental :: + * Approximate version of countByKey that can return a partial result if it does * not finish within a timeout. */ + @Experimental def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] = rdd.countByKeyApprox(timeout).map(mapAsJavaMap) /** - * (Experimental) Approximate version of countByKey that can return a partial result if it does + * :: Experimental :: + * Approximate version of countByKey that can return a partial result if it does * not finish within a timeout. */ + @Experimental def countByKeyApprox(timeout: Long, confidence: Double = 0.95) : PartialResult[java.util.Map[K, BoundedDouble]] = rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap) @@ -250,14 +256,14 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Allows controlling the * partitioning of the resulting key-value pair RDD by passing a Partitioner. */ - def groupByKey(partitioner: Partitioner): JavaPairRDD[K, JList[V]] = + def groupByKey(partitioner: Partitioner): JavaPairRDD[K, JIterable[V]] = fromRDD(groupByResultToJava(rdd.groupByKey(partitioner))) /** * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with into `numPartitions` partitions. */ - def groupByKey(numPartitions: Int): JavaPairRDD[K, JList[V]] = + def groupByKey(numPartitions: Int): JavaPairRDD[K, JIterable[V]] = fromRDD(groupByResultToJava(rdd.groupByKey(numPartitions))) /** @@ -367,7 +373,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with the existing partitioner/parallelism level. */ - def groupByKey(): JavaPairRDD[K, JList[V]] = + def groupByKey(): JavaPairRDD[K, JIterable[V]] = fromRDD(groupByResultToJava(rdd.groupByKey())) /** @@ -462,7 +468,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * list of values for that key in `this` as well as `other`. */ def cogroup[W](other: JavaPairRDD[K, W], partitioner: Partitioner) - : JavaPairRDD[K, (JList[V], JList[W])] = + : JavaPairRDD[K, (JIterable[V], JIterable[W])] = fromRDD(cogroupResultToJava(rdd.cogroup(other, partitioner))) /** @@ -470,14 +476,14 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * tuple with the list of values for that key in `this`, `other1` and `other2`. */ def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], - partitioner: Partitioner): JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] = + partitioner: Partitioner): JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] = fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, partitioner))) /** * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. */ - def cogroup[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] = + def cogroup[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JIterable[V], JIterable[W])] = fromRDD(cogroupResultToJava(rdd.cogroup(other))) /** @@ -485,7 +491,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * tuple with the list of values for that key in `this`, `other1` and `other2`. */ def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2]) - : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] = + : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] = fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2))) /** @@ -493,7 +499,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * list of values for that key in `this` as well as `other`. */ def cogroup[W](other: JavaPairRDD[K, W], numPartitions: Int) - : JavaPairRDD[K, (JList[V], JList[W])] = + : JavaPairRDD[K, (JIterable[V], JIterable[W])] = fromRDD(cogroupResultToJava(rdd.cogroup(other, numPartitions))) /** @@ -501,16 +507,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * tuple with the list of values for that key in `this`, `other1` and `other2`. */ def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], numPartitions: Int) - : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] = + : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] = fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numPartitions))) /** Alias for cogroup. */ - def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] = + def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JIterable[V], JIterable[W])] = fromRDD(cogroupResultToJava(rdd.groupWith(other))) /** Alias for cogroup. */ def groupWith[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2]) - : JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] = + : JavaPairRDD[K, (JIterable[V], JIterable[W1], JIterable[W2])] = fromRDD(cogroupResult2ToJava(rdd.groupWith(other1, other2))) /** @@ -695,21 +701,22 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) object JavaPairRDD { private[spark] - def groupByResultToJava[K: ClassTag, T](rdd: RDD[(K, Seq[T])]): RDD[(K, JList[T])] = { - rddToPairRDDFunctions(rdd).mapValues(seqAsJavaList) + def groupByResultToJava[K: ClassTag, T](rdd: RDD[(K, Iterable[T])]): RDD[(K, JIterable[T])] = { + rddToPairRDDFunctions(rdd).mapValues(asJavaIterable) } private[spark] def cogroupResultToJava[K: ClassTag, V, W]( - rdd: RDD[(K, (Seq[V], Seq[W]))]): RDD[(K, (JList[V], JList[W]))] = { - rddToPairRDDFunctions(rdd).mapValues(x => (seqAsJavaList(x._1), seqAsJavaList(x._2))) + rdd: RDD[(K, (Iterable[V], Iterable[W]))]): RDD[(K, (JIterable[V], JIterable[W]))] = { + rddToPairRDDFunctions(rdd).mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2))) } private[spark] def cogroupResult2ToJava[K: ClassTag, V, W1, W2]( - rdd: RDD[(K, (Seq[V], Seq[W1], Seq[W2]))]): RDD[(K, (JList[V], JList[W1], JList[W2]))] = { + rdd: RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))]) + : RDD[(K, (JIterable[V], JIterable[W1], JIterable[W2]))] = { rddToPairRDDFunctions(rdd) - .mapValues(x => (seqAsJavaList(x._1), seqAsJavaList(x._2), seqAsJavaList(x._3))) + .mapValues(x => (asJavaIterable(x._1), asJavaIterable(x._2), asJavaIterable(x._3))) } def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index e03b8e78d5f52..725c423a53e35 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -17,7 +17,8 @@ package org.apache.spark.api.java -import java.util.{Comparator, List => JList} +import java.util.{Comparator, List => JList, Iterator => JIterator} +import java.lang.{Iterable => JIterable} import scala.collection.JavaConversions._ import scala.reflect.ClassTag @@ -26,6 +27,7 @@ import com.google.common.base.Optional import org.apache.hadoop.io.compress.CompressionCodec import org.apache.spark.{Partition, SparkContext, TaskContext} +import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _} @@ -203,7 +205,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements * mapping to that key. */ - def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JList[T]] = { + def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JIterable[T]] = { implicit val ctagK: ClassTag[K] = fakeClassTag implicit val ctagV: ClassTag[JList[T]] = fakeClassTag JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(fakeClassTag))) @@ -213,7 +215,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements * mapping to that key. */ - def groupBy[K](f: JFunction[T, K], numPartitions: Int): JavaPairRDD[K, JList[T]] = { + def groupBy[K](f: JFunction[T, K], numPartitions: Int): JavaPairRDD[K, JIterable[T]] = { implicit val ctagK: ClassTag[K] = fakeClassTag implicit val ctagV: ClassTag[JList[T]] = fakeClassTag JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(fakeClassTag[K]))) @@ -280,6 +282,17 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { new java.util.ArrayList(arr) } + /** + * Return an iterator that contains all of the elements in this RDD. + * + * The iterator will consume as much memory as the largest partition in this RDD. + */ + def toLocalIterator(): JIterator[T] = { + import scala.collection.JavaConversions._ + rdd.toLocalIterator + } + + /** * Return an array that contains all of the elements in this RDD. * @deprecated As of Spark 1.0.0, toArray() is deprecated, use {@link #collect()} instead @@ -331,16 +344,20 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def count(): Long = rdd.count() /** - * (Experimental) Approximate version of count() that returns a potentially incomplete result + * :: Experimental :: + * Approximate version of count() that returns a potentially incomplete result * within a timeout, even if not all tasks have finished. */ + @Experimental def countApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] = rdd.countApprox(timeout, confidence) /** - * (Experimental) Approximate version of count() that returns a potentially incomplete result + * :: Experimental :: + * Approximate version of count() that returns a potentially incomplete result * within a timeout, even if not all tasks have finished. */ + @Experimental def countApprox(timeout: Long): PartialResult[BoundedDouble] = rdd.countApprox(timeout) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index e531a57aced31..7fbefe1cb0fb1 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -89,7 +89,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork */ def this(master: String, appName: String, sparkHome: String, jars: Array[String], environment: JMap[String, String]) = - this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment)) + this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment, Map())) private[spark] val env = sc.env @@ -154,6 +154,46 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork */ def textFile(path: String, minSplits: Int): JavaRDD[String] = sc.textFile(path, minSplits) + /** + * Read a directory of text files from HDFS, a local file system (available on all nodes), or any + * Hadoop-supported file system URI. Each file is read as a single record and returned in a + * key-value pair, where the key is the path of each file, the value is the content of each file. + * + *

For example, if you have the following files: + * {{{ + * hdfs://a-hdfs-path/part-00000 + * hdfs://a-hdfs-path/part-00001 + * ... + * hdfs://a-hdfs-path/part-nnnnn + * }}} + * + * Do `JavaPairRDD rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")`, + * + *

then `rdd` contains + * {{{ + * (a-hdfs-path/part-00000, its content) + * (a-hdfs-path/part-00001, its content) + * ... + * (a-hdfs-path/part-nnnnn, its content) + * }}} + * + * @note Small files are preferred, large file is also allowable, but may cause bad performance. + * + * @param minSplits A suggestion value of the minimal splitting number for input data. + */ + def wholeTextFiles(path: String, minSplits: Int): JavaPairRDD[String, String] = + new JavaPairRDD(sc.wholeTextFiles(path, minSplits)) + + /** + * Read a directory of text files from HDFS, a local file system (available on all nodes), or any + * Hadoop-supported file system URI. Each file is read as a single record and returned in a + * key-value pair, where the key is the path of each file, the value is the content of each file. + * + * @see `wholeTextFiles(path: String, minSplits: Int)`. + */ + def wholeTextFiles(path: String): JavaPairRDD[String, String] = + new JavaPairRDD(sc.wholeTextFiles(path)) + /** Get an RDD for a Hadoop SequenceFile with given key and value types. * * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala index ecbf18849ad48..22810cb1c662d 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.api.java import com.google.common.base.Optional -object JavaUtils { +private[spark] object JavaUtils { def optionToOptional[T](option: Option[T]): Optional[T] = option match { case Some(value) => Optional.of(value) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index b67286a4e3b75..32f1100406d74 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -19,6 +19,7 @@ package org.apache.spark.api.python import java.io._ import java.net._ +import java.nio.charset.Charset import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} import scala.collection.JavaConversions._ @@ -206,6 +207,7 @@ private object SpecialLengths { } private[spark] object PythonRDD { + val UTF8 = Charset.forName("UTF-8") def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): JavaRDD[Array[Byte]] = { @@ -266,7 +268,7 @@ private[spark] object PythonRDD { } def writeUTF(str: String, dataOut: DataOutputStream) { - val bytes = str.getBytes("UTF-8") + val bytes = str.getBytes(UTF8) dataOut.writeInt(bytes.length) dataOut.write(bytes) } @@ -286,7 +288,7 @@ private[spark] object PythonRDD { private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] { - override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") + override def call(arr: Array[Byte]) : String = new String(arr, PythonRDD.UTF8) } /** diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index e3c3a12d16f2a..738a3b1bed7f3 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -18,9 +18,8 @@ package org.apache.spark.broadcast import java.io.Serializable -import java.util.concurrent.atomic.AtomicLong -import org.apache.spark._ +import org.apache.spark.SparkException /** * A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable @@ -29,7 +28,8 @@ import org.apache.spark._ * attempts to distribute broadcast variables using efficient broadcast algorithms to reduce * communication cost. * - * Broadcast variables are created from a variable `v` by calling [[SparkContext#broadcast]]. + * Broadcast variables are created from a variable `v` by calling + * [[org.apache.spark.SparkContext#broadcast]]. * The broadcast variable is a wrapper around `v`, and its value can be accessed by calling the * `value` method. The interpreter session below shows this: * @@ -51,49 +51,80 @@ import org.apache.spark._ * @tparam T Type of the data contained in the broadcast variable. */ abstract class Broadcast[T](val id: Long) extends Serializable { - def value: T - // We cannot have an abstract readObject here due to some weird issues with - // readObject having to be 'private' in sub-classes. + /** + * Flag signifying whether the broadcast variable is valid + * (that is, not already destroyed) or not. + */ + @volatile private var _isValid = true - override def toString = "Broadcast(" + id + ")" -} - -private[spark] -class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager) - extends Logging with Serializable { - - private var initialized = false - private var broadcastFactory: BroadcastFactory = null - - initialize() - - // Called by SparkContext or Executor before using Broadcast - private def initialize() { - synchronized { - if (!initialized) { - val broadcastFactoryClass = conf.get( - "spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") - - broadcastFactory = - Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] + /** Get the broadcasted value. */ + def value: T = { + assertValid() + getValue() + } - // Initialize appropriate BroadcastFactory and BroadcastObject - broadcastFactory.initialize(isDriver, conf, securityManager) + /** + * Asynchronously delete cached copies of this broadcast on the executors. + * If the broadcast is used after this is called, it will need to be re-sent to each executor. + */ + def unpersist() { + unpersist(blocking = false) + } - initialized = true - } - } + /** + * Delete cached copies of this broadcast on the executors. If the broadcast is used after + * this is called, it will need to be re-sent to each executor. + * @param blocking Whether to block until unpersisting has completed + */ + def unpersist(blocking: Boolean) { + assertValid() + doUnpersist(blocking) } - def stop() { - broadcastFactory.stop() + /** + * Destroy all data and metadata related to this broadcast variable. Use this with caution; + * once a broadcast variable has been destroyed, it cannot be used again. + */ + private[spark] def destroy(blocking: Boolean) { + assertValid() + _isValid = false + doDestroy(blocking) } - private val nextBroadcastId = new AtomicLong(0) + /** + * Whether this Broadcast is actually usable. This should be false once persisted state is + * removed from the driver. + */ + private[spark] def isValid: Boolean = { + _isValid + } - def newBroadcast[T](value_ : T, isLocal: Boolean) = - broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) + /** + * Actually get the broadcasted value. Concrete implementations of Broadcast class must + * define their own way to get the value. + */ + private[spark] def getValue(): T + + /** + * Actually unpersist the broadcasted value on the executors. Concrete implementations of + * Broadcast class must define their own logic to unpersist their own data. + */ + private[spark] def doUnpersist(blocking: Boolean) + + /** + * Actually destroy all data and metadata related to this broadcast variable. + * Implementation of Broadcast class must define their own logic to destroy their own + * state. + */ + private[spark] def doDestroy(blocking: Boolean) + + /** Check if this broadcast is valid. If not valid, exception is thrown. */ + private[spark] def assertValid() { + if (!_isValid) { + throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString)) + } + } - def isDriver = _isDriver + override def toString = "Broadcast(" + id + ")" } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 6beecaeced5be..8c8ce9b1691ac 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -16,18 +16,22 @@ */ package org.apache.spark.broadcast -import org.apache.spark.SecurityManager +import org.apache.spark.SecurityManager import org.apache.spark.SparkConf +import org.apache.spark.annotation.DeveloperApi /** - * An interface for all the broadcast implementations in Spark (to allow + * :: DeveloperApi :: + * An interface for all the broadcast implementations in Spark (to allow * multiple broadcast implementations). SparkContext uses a user-specified * BroadcastFactory implementation to instantiate a particular broadcast for the * entire Spark job. */ +@DeveloperApi trait BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] + def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala new file mode 100644 index 0000000000000..cf62aca4d45e8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.broadcast + +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark._ + +private[spark] class BroadcastManager( + val isDriver: Boolean, + conf: SparkConf, + securityManager: SecurityManager) + extends Logging { + + private var initialized = false + private var broadcastFactory: BroadcastFactory = null + + initialize() + + // Called by SparkContext or Executor before using Broadcast + private def initialize() { + synchronized { + if (!initialized) { + val broadcastFactoryClass = + conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") + + broadcastFactory = + Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] + + // Initialize appropriate BroadcastFactory and BroadcastObject + broadcastFactory.initialize(isDriver, conf, securityManager) + + initialized = true + } + } + } + + def stop() { + broadcastFactory.stop() + } + + private val nextBroadcastId = new AtomicLong(0) + + def newBroadcast[T](value_ : T, isLocal: Boolean) = { + broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) + } + + def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) { + broadcastFactory.unbroadcast(id, removeFromDriver, blocking) + } +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index e8eb04bb10469..29372f16f2cac 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -17,34 +17,64 @@ package org.apache.spark.broadcast -import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream} +import java.io.{File, FileOutputStream, ObjectInputStream, ObjectOutputStream, OutputStream} +import java.io.{BufferedInputStream, BufferedOutputStream} import java.net.{URL, URLConnection, URI} import java.util.concurrent.TimeUnit -import it.unimi.dsi.fastutil.io.FastBufferedInputStream -import it.unimi.dsi.fastutil.io.FastBufferedOutputStream - -import org.apache.spark.{SparkConf, HttpServer, Logging, SecurityManager, SparkEnv} +import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf, SparkEnv} import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils} +/** + * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses HTTP server + * as a broadcast mechanism. The first time a HTTP broadcast variable (sent as part of a + * task) is deserialized in the executor, the broadcasted data is fetched from the driver + * (through a HTTP server running at the driver) and stored in the BlockManager of the + * executor to speed up future accesses. + */ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { - def value = value_ + def getValue = value_ - def blockId = BroadcastBlockId(id) + val blockId = BroadcastBlockId(id) + /* + * Broadcasted data is also stored in the BlockManager of the driver. The BlockManagerMaster + * does not need to be told about this block as not only need to know about this data block. + */ HttpBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + SparkEnv.get.blockManager.putSingle( + blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) } if (!isLocal) { HttpBroadcast.write(id, value_) } - // Called by JVM when deserializing an object + /** + * Remove all persisted state associated with this HTTP broadcast on the executors. + */ + def doUnpersist(blocking: Boolean) { + HttpBroadcast.unpersist(id, removeFromDriver = false, blocking) + } + + /** + * Remove all persisted state associated with this HTTP broadcast on the executors and driver. + */ + def doDestroy(blocking: Boolean) { + HttpBroadcast.unpersist(id, removeFromDriver = true, blocking) + } + + /** Used by the JVM when serializing this object. */ + private def writeObject(out: ObjectOutputStream) { + assertValid() + out.defaultWriteObject() + } + + /** Used by the JVM when deserializing this object. */ private def readObject(in: ObjectInputStream) { in.defaultReadObject() HttpBroadcast.synchronized { @@ -54,7 +84,13 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea logInfo("Started reading broadcast variable " + id) val start = System.nanoTime value_ = HttpBroadcast.read[T](id) - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + /* + * We cache broadcast data in the BlockManager so that subsequent tasks using it + * do not need to re-fetch. This data is only used locally and no other node + * needs to fetch this block, so we don't notify the master. + */ + SparkEnv.get.blockManager.putSingle( + blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) val time = (System.nanoTime - start) / 1e9 logInfo("Reading broadcast variable " + id + " took " + time + " s") } @@ -63,23 +99,8 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea } } -/** - * A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium. - */ -class HttpBroadcastFactory extends BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { - HttpBroadcast.initialize(isDriver, conf, securityMgr) - } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new HttpBroadcast[T](value_, isLocal, id) - - def stop() { HttpBroadcast.stop() } -} - -private object HttpBroadcast extends Logging { +private[spark] object HttpBroadcast extends Logging { private var initialized = false - private var broadcastDir: File = null private var compress: Boolean = false private var bufferSize: Int = 65536 @@ -89,11 +110,9 @@ private object HttpBroadcast extends Logging { // TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist private val files = new TimeStampedHashSet[String] - private var cleaner: MetadataCleaner = null - private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt - private var compressionCodec: CompressionCodec = null + private var cleaner: MetadataCleaner = null def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { synchronized { @@ -136,13 +155,15 @@ private object HttpBroadcast extends Logging { logInfo("Broadcast server started at " + serverUri) } + def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name) + def write(id: Long, value: Any) { - val file = new File(broadcastDir, BroadcastBlockId(id).name) + val file = getFile(id) val out: OutputStream = { if (compress) { compressionCodec.compressedOutputStream(new FileOutputStream(file)) } else { - new FastBufferedOutputStream(new FileOutputStream(file), bufferSize) + new BufferedOutputStream(new FileOutputStream(file), bufferSize) } } val ser = SparkEnv.get.serializer.newInstance() @@ -160,7 +181,7 @@ private object HttpBroadcast extends Logging { if (securityManager.isAuthenticationEnabled()) { logDebug("broadcast security enabled") val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager) - uc = newuri.toURL().openConnection() + uc = newuri.toURL.openConnection() uc.setAllowUserInteraction(false) } else { logDebug("broadcast not using security") @@ -169,11 +190,11 @@ private object HttpBroadcast extends Logging { val in = { uc.setReadTimeout(httpReadTimeout) - val inputStream = uc.getInputStream(); + val inputStream = uc.getInputStream if (compress) { compressionCodec.compressedInputStream(inputStream) } else { - new FastBufferedInputStream(inputStream, bufferSize) + new BufferedInputStream(inputStream, bufferSize) } } val ser = SparkEnv.get.serializer.newInstance() @@ -183,20 +204,48 @@ private object HttpBroadcast extends Logging { obj } - def cleanup(cleanupTime: Long) { + /** + * Remove all persisted blocks associated with this HTTP broadcast on the executors. + * If removeFromDriver is true, also remove these persisted blocks on the driver + * and delete the associated broadcast file. + */ + def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized { + SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) + if (removeFromDriver) { + val file = getFile(id) + files.remove(file.toString) + deleteBroadcastFile(file) + } + } + + /** + * Periodically clean up old broadcasts by removing the associated map entries and + * deleting the associated files. + */ + private def cleanup(cleanupTime: Long) { val iterator = files.internalMap.entrySet().iterator() while(iterator.hasNext) { val entry = iterator.next() val (file, time) = (entry.getKey, entry.getValue) if (time < cleanupTime) { - try { - iterator.remove() - new File(file.toString).delete() - logInfo("Deleted broadcast file '" + file + "'") - } catch { - case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e) + iterator.remove() + deleteBroadcastFile(new File(file.toString)) + } + } + } + + private def deleteBroadcastFile(file: File) { + try { + if (file.exists) { + if (file.delete()) { + logInfo("Deleted broadcast file: %s".format(file)) + } else { + logWarning("Could not delete broadcast file: %s".format(file)) } } + } catch { + case e: Exception => + logError("Exception while deleting broadcast file: %s".format(file), e) } } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala new file mode 100644 index 0000000000000..e3f6cdc6154dd --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.broadcast + +import org.apache.spark.{SecurityManager, SparkConf} + +/** + * A [[org.apache.spark.broadcast.BroadcastFactory]] implementation that uses a + * HTTP server as the broadcast mechanism. Refer to + * [[org.apache.spark.broadcast.HttpBroadcast]] for more details about this mechanism. + */ +class HttpBroadcastFactory extends BroadcastFactory { + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { + HttpBroadcast.initialize(isDriver, conf, securityMgr) + } + + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + new HttpBroadcast[T](value_, isLocal, id) + + def stop() { HttpBroadcast.stop() } + + /** + * Remove all persisted state associated with the HTTP broadcast with the given ID. + * @param removeFromDriver Whether to remove state from the driver + * @param blocking Whether to block until unbroadcasted + */ + def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) { + HttpBroadcast.unpersist(id, removeFromDriver, blocking) + } +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 2595c15104e87..2659274c5e98e 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -17,24 +17,43 @@ package org.apache.spark.broadcast -import java.io._ +import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream} import scala.math import scala.util.Random -import org.apache.spark._ -import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel} +import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} +import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.Utils +/** + * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like + * protocol to do a distributed transfer of the broadcasted data to the executors. + * The mechanism is as follows. The driver divides the serializes the broadcasted data, + * divides it into smaller chunks, and stores them in the BlockManager of the driver. + * These chunks are reported to the BlockManagerMaster so that all the executors can + * learn the location of those chunks. The first time the broadcast variable (sent as + * part of task) is deserialized at a executor, all the chunks are fetched using + * the BlockManager. When all the chunks are fetched (initially from the driver's + * BlockManager), they are combined and deserialized to recreate the broadcasted data. + * However, the chunks are also stored in the BlockManager and reported to the + * BlockManagerMaster. As more executors fetch the chunks, BlockManagerMaster learns + * multiple locations for each chunk. Hence, subsequent fetches of each chunk will be + * made to other executors who already have those chunks, resulting in a distributed + * fetching. This prevents the driver from being the bottleneck in sending out multiple + * copies of the broadcast data (one per executor) as done by the + * [[org.apache.spark.broadcast.HttpBroadcast]]. + */ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) -extends Broadcast[T](id) with Logging with Serializable { + extends Broadcast[T](id) with Logging with Serializable { - def value = value_ + def getValue = value_ - def broadcastId = BroadcastBlockId(id) + val broadcastId = BroadcastBlockId(id) TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) + SparkEnv.get.blockManager.putSingle( + broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) } @transient var arrayOfBlocks: Array[TorrentBlock] = null @@ -46,32 +65,52 @@ extends Broadcast[T](id) with Logging with Serializable { sendBroadcast() } - def sendBroadcast() { - var tInfo = TorrentBroadcast.blockifyObject(value_) + /** + * Remove all persisted state associated with this Torrent broadcast on the executors. + */ + def doUnpersist(blocking: Boolean) { + TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking) + } + + /** + * Remove all persisted state associated with this Torrent broadcast on the executors + * and driver. + */ + def doDestroy(blocking: Boolean) { + TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking) + } + def sendBroadcast() { + val tInfo = TorrentBroadcast.blockifyObject(value_) totalBlocks = tInfo.totalBlocks totalBytes = tInfo.totalBytes hasBlocks = tInfo.totalBlocks // Store meta-info - val metaId = BroadcastHelperBlockId(broadcastId, "meta") + val metaId = BroadcastBlockId(id, "meta") val metaInfo = TorrentInfo(null, totalBlocks, totalBytes) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( - metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true) + metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, tellMaster = true) } // Store individual pieces for (i <- 0 until totalBlocks) { - val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i) + val pieceId = BroadcastBlockId(id, "piece" + i) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( - pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true) + pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true) } } } - // Called by JVM when deserializing an object + /** Used by the JVM when serializing this object. */ + private def writeObject(out: ObjectOutputStream) { + assertValid() + out.defaultWriteObject() + } + + /** Used by the JVM when deserializing this object. */ private def readObject(in: ObjectInputStream) { in.defaultReadObject() TorrentBroadcast.synchronized { @@ -86,18 +125,22 @@ extends Broadcast[T](id) with Logging with Serializable { // Initialize @transient variables that will receive garbage values from the master. resetWorkerVariables() - if (receiveBroadcast(id)) { + if (receiveBroadcast()) { value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - // Store the merged copy in cache so that the next worker doesn't need to rebuild it. - // This creates a tradeoff between memory usage and latency. - // Storing copy doubles the memory footprint; not storing doubles deserialization cost. + /* Store the merged copy in cache so that the next worker doesn't need to rebuild it. + * This creates a trade-off between memory usage and latency. Storing copy doubles + * the memory footprint; not storing doubles deserialization cost. Also, + * this does not need to be reported to BlockManagerMaster since other executors + * does not need to access this block (they only need to fetch the chunks, + * which are reported). + */ SparkEnv.get.blockManager.putSingle( - broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) + broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) // Remove arrayOfBlocks from memory once value_ is on local cache resetWorkerVariables() - } else { + } else { logError("Reading broadcast variable " + id + " failed") } @@ -114,9 +157,10 @@ extends Broadcast[T](id) with Logging with Serializable { hasBlocks = 0 } - def receiveBroadcast(variableID: Long): Boolean = { - // Receive meta-info - val metaId = BroadcastHelperBlockId(broadcastId, "meta") + def receiveBroadcast(): Boolean = { + // Receive meta-info about the size of broadcast data, + // the number of chunks it is divided into, etc. + val metaId = BroadcastBlockId(id, "meta") var attemptId = 10 while (attemptId > 0 && totalBlocks == -1) { TorrentBroadcast.synchronized { @@ -138,17 +182,21 @@ extends Broadcast[T](id) with Logging with Serializable { return false } - // Receive actual blocks + /* + * Fetch actual chunks of data. Note that all these chunks are stored in + * the BlockManager and reported to the master, so that other executors + * can find out and pull the chunks from this executor. + */ val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList) for (pid <- recvOrder) { - val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid) + val pieceId = BroadcastBlockId(id, "piece" + pid) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.getSingle(pieceId) match { case Some(x) => arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] hasBlocks += 1 SparkEnv.get.blockManager.putSingle( - pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true) + pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, tellMaster = true) case None => throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) @@ -156,16 +204,16 @@ extends Broadcast[T](id) with Logging with Serializable { } } - (hasBlocks == totalBlocks) + hasBlocks == totalBlocks } } -private object TorrentBroadcast -extends Logging { - +private[spark] object TorrentBroadcast extends Logging { + private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 private var initialized = false private var conf: SparkConf = null + def initialize(_isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests synchronized { @@ -179,39 +227,37 @@ extends Logging { initialized = false } - lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 - def blockifyObject[T](obj: T): TorrentInfo = { val byteArray = Utils.serialize[T](obj) val bais = new ByteArrayInputStream(byteArray) - var blockNum = (byteArray.length / BLOCK_SIZE) + var blockNum = byteArray.length / BLOCK_SIZE if (byteArray.length % BLOCK_SIZE != 0) { blockNum += 1 } - var retVal = new Array[TorrentBlock](blockNum) - var blockID = 0 + val blocks = new Array[TorrentBlock](blockNum) + var blockId = 0 for (i <- 0 until (byteArray.length, BLOCK_SIZE)) { val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i) - var tempByteArray = new Array[Byte](thisBlockSize) - val hasRead = bais.read(tempByteArray, 0, thisBlockSize) + val tempByteArray = new Array[Byte](thisBlockSize) + bais.read(tempByteArray, 0, thisBlockSize) - retVal(blockID) = new TorrentBlock(blockID, tempByteArray) - blockID += 1 + blocks(blockId) = new TorrentBlock(blockId, tempByteArray) + blockId += 1 } bais.close() - val tInfo = TorrentInfo(retVal, blockNum, byteArray.length) - tInfo.hasBlocks = blockNum - - tInfo + val info = TorrentInfo(blocks, blockNum, byteArray.length) + info.hasBlocks = blockNum + info } - def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock], - totalBytes: Int, - totalBlocks: Int): T = { + def unBlockifyObject[T]( + arrayOfBlocks: Array[TorrentBlock], + totalBytes: Int, + totalBlocks: Int): T = { val retByteArray = new Array[Byte](totalBytes) for (i <- 0 until totalBlocks) { System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, @@ -220,6 +266,13 @@ extends Logging { Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader) } + /** + * Remove all persisted blocks associated with this torrent broadcast on the executors. + * If removeFromDriver is true, also remove these persisted blocks on the driver. + */ + def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = synchronized { + SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) + } } private[spark] case class TorrentBlock( @@ -228,25 +281,10 @@ private[spark] case class TorrentBlock( extends Serializable private[spark] case class TorrentInfo( - @transient arrayOfBlocks : Array[TorrentBlock], + @transient arrayOfBlocks: Array[TorrentBlock], totalBlocks: Int, totalBytes: Int) extends Serializable { @transient var hasBlocks = 0 } - -/** - * A [[BroadcastFactory]] that creates a torrent-based implementation of broadcast. - */ -class TorrentBroadcastFactory extends BroadcastFactory { - - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { - TorrentBroadcast.initialize(isDriver, conf) - } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new TorrentBroadcast[T](value_, isLocal, id) - - def stop() { TorrentBroadcast.stop() } -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala new file mode 100644 index 0000000000000..d216b58718148 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.broadcast + +import org.apache.spark.{SecurityManager, SparkConf} + +/** + * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like + * protocol to do a distributed transfer of the broadcasted data to the executors. Refer to + * [[org.apache.spark.broadcast.TorrentBroadcast]] for more details. + */ +class TorrentBroadcastFactory extends BroadcastFactory { + + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { + TorrentBroadcast.initialize(isDriver, conf) + } + + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + new TorrentBroadcast[T](value_, isLocal, id) + + def stop() { TorrentBroadcast.stop() } + + /** + * Remove all persisted state associated with the torrent broadcast with the given ID. + * @param removeFromDriver Whether to remove state from the driver. + * @param blocking Whether to block until unbroadcasted + */ + def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) { + TorrentBroadcast.unpersist(id, removeFromDriver, blocking) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala index 15fa8a7679874..86305d2ea8a09 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala @@ -17,8 +17,6 @@ package org.apache.spark.deploy -import org.apache.spark.scheduler.EventLoggingInfo - private[spark] class ApplicationDescription( val name: String, val maxCores: Option[Int], @@ -26,7 +24,7 @@ private[spark] class ApplicationDescription( val command: Command, val sparkHome: Option[String], var appUiUrl: String, - val eventLogInfo: Option[EventLoggingInfo] = None) + val eventLogDir: Option[String] = None) extends Serializable { val user = System.getProperty("user.name", "") diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index c07838f798799..5da9615c9e9af 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -43,7 +43,7 @@ private[spark] class ClientArguments(args: Array[String]) { // kill parameters var driverId: String = "" - + parse(args.toList) def parse(args: List[String]): Unit = args match { diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 83ce14a0a806a..a7368f9f3dfbe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -86,6 +86,10 @@ private[deploy] object DeployMessages { case class KillDriver(driverId: String) extends DeployMessage + // Worker internal + + case object WorkDirCleanup // Sent to Worker actor periodically for cleaning up app folders + // AppClient to Master case class RegisterApplication(appDescription: ApplicationDescription) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 1fa799190409f..e05fbfe321495 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -79,20 +79,23 @@ object SparkSubmit { printErrorAndExit("master must start with yarn, mesos, spark, or local") } - // Because "yarn-standalone" and "yarn-client" encapsulate both the master + // Because "yarn-cluster" and "yarn-client" encapsulate both the master // and deploy mode, we have some logic to infer the master and deploy mode // from each other if only one is specified, or exit early if they are at odds. - if (appArgs.deployMode == null && appArgs.master == "yarn-standalone") { + if (appArgs.deployMode == null && + (appArgs.master == "yarn-standalone" || appArgs.master == "yarn-cluster")) { appArgs.deployMode = "cluster" } if (appArgs.deployMode == "cluster" && appArgs.master == "yarn-client") { printErrorAndExit("Deploy mode \"cluster\" and master \"yarn-client\" are not compatible") } - if (appArgs.deployMode == "client" && appArgs.master == "yarn-standalone") { - printErrorAndExit("Deploy mode \"client\" and master \"yarn-standalone\" are not compatible") + if (appArgs.deployMode == "client" && + (appArgs.master == "yarn-standalone" || appArgs.master == "yarn-cluster")) { + printErrorAndExit("Deploy mode \"client\" and master \"" + appArgs.master + + "\" are not compatible") } if (appArgs.deployMode == "cluster" && appArgs.master.startsWith("yarn")) { - appArgs.master = "yarn-standalone" + appArgs.master = "yarn-cluster" } if (appArgs.deployMode != "cluster" && appArgs.master.startsWith("yarn")) { appArgs.master = "yarn-client" diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 9c8f54ea6f77a..834b3df2f164b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -171,7 +171,7 @@ private[spark] class SparkSubmitArguments(args: Array[String]) { outStream.println("Unknown/unsupported param " + unknownParam) } outStream.println( - """Usage: spark-submit [options] + """Usage: spark-submit [options] |Options: | --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local. | --deploy-mode DEPLOY_MODE Mode to deploy the app in, either 'client' or 'cluster'. diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala new file mode 100644 index 0000000000000..180c853ce3096 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.history + +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.spark.ui.{WebUIPage, UIUtils} + +private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { + + def render(request: HttpServletRequest): Seq[Node] = { + val appRows = parent.appIdToInfo.values.toSeq.sortBy { app => -app.lastUpdated } + val appTable = UIUtils.listingTable(appHeader, appRow, appRows) + val content = +

+
+
    +
  • Event Log Location: {parent.baseLogDir}
  • +
+ { + if (parent.appIdToInfo.size > 0) { +

+ Showing {parent.appIdToInfo.size}/{parent.getNumApplications} + Completed Application{if (parent.getNumApplications > 1) "s" else ""} +

++ + appTable + } else { +

No Completed Applications Found

+ } + } +
+
+ UIUtils.basicSparkPage(content, "History Server") + } + + private val appHeader = Seq( + "App Name", + "Started", + "Completed", + "Duration", + "Spark User", + "Log Directory", + "Last Updated") + + private def appRow(info: ApplicationHistoryInfo): Seq[Node] = { + val appName = if (info.started) info.name else info.logDirPath.getName + val uiAddress = parent.getAddress + info.ui.basePath + val startTime = if (info.started) UIUtils.formatDate(info.startTime) else "Not started" + val endTime = if (info.completed) UIUtils.formatDate(info.endTime) else "Not completed" + val difference = if (info.started && info.completed) info.endTime - info.startTime else -1L + val duration = if (difference > 0) UIUtils.formatDuration(difference) else "---" + val sparkUser = if (info.started) info.sparkUser else "Unknown user" + val logDirectory = info.logDirPath.getName + val lastUpdated = UIUtils.formatDate(info.lastUpdated) + + {appName} + {startTime} + {endTime} + {duration} + {sparkUser} + {logDirectory} + {lastUpdated} + + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala new file mode 100644 index 0000000000000..cf64700f9098c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.history + +import scala.collection.mutable + +import org.apache.hadoop.fs.{FileStatus, Path} + +import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.scheduler._ +import org.apache.spark.ui.{WebUI, SparkUI} +import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.util.Utils + +/** + * A web server that renders SparkUIs of completed applications. + * + * For the standalone mode, MasterWebUI already achieves this functionality. Thus, the + * main use case of the HistoryServer is in other deploy modes (e.g. Yarn or Mesos). + * + * The logging directory structure is as follows: Within the given base directory, each + * application's event logs are maintained in the application's own sub-directory. This + * is the same structure as maintained in the event log write code path in + * EventLoggingListener. + * + * @param baseLogDir The base directory in which event logs are found + */ +class HistoryServer( + val baseLogDir: String, + securityManager: SecurityManager, + conf: SparkConf) + extends WebUI(securityManager, HistoryServer.WEB_UI_PORT, conf) with Logging { + + import HistoryServer._ + + private val fileSystem = Utils.getHadoopFileSystem(baseLogDir) + private val localHost = Utils.localHostName() + private val publicHost = Option(System.getenv("SPARK_PUBLIC_DNS")).getOrElse(localHost) + + // A timestamp of when the disk was last accessed to check for log updates + private var lastLogCheckTime = -1L + + // Number of completed applications found in this directory + private var numCompletedApplications = 0 + + @volatile private var stopped = false + + /** + * A background thread that periodically checks for event log updates on disk. + * + * If a log check is invoked manually in the middle of a period, this thread re-adjusts the + * time at which it performs the next log check to maintain the same period as before. + * + * TODO: Add a mechanism to update manually. + */ + private val logCheckingThread = new Thread { + override def run() { + while (!stopped) { + val now = System.currentTimeMillis + if (now - lastLogCheckTime > UPDATE_INTERVAL_MS) { + checkForLogs() + Thread.sleep(UPDATE_INTERVAL_MS) + } else { + // If the user has manually checked for logs recently, wait until + // UPDATE_INTERVAL_MS after the last check time + Thread.sleep(lastLogCheckTime + UPDATE_INTERVAL_MS - now) + } + } + } + } + + // A mapping of application ID to its history information, which includes the rendered UI + val appIdToInfo = mutable.HashMap[String, ApplicationHistoryInfo]() + + initialize() + + /** + * Initialize the history server. + * + * This starts a background thread that periodically synchronizes information displayed on + * this UI with the event logs in the provided base directory. + */ + def initialize() { + attachPage(new HistoryPage(this)) + attachHandler(createStaticHandler(STATIC_RESOURCE_DIR, "/static")) + logCheckingThread.start() + } + + /** + * Check for any updates to event logs in the base directory. This is only effective once + * the server has been bound. + * + * If a new completed application is found, the server renders the associated SparkUI + * from the application's event logs, attaches this UI to itself, and stores metadata + * information for this application. + * + * If the logs for an existing completed application are no longer found, the server + * removes all associated information and detaches the SparkUI. + */ + def checkForLogs() = synchronized { + if (serverInfo.isDefined) { + lastLogCheckTime = System.currentTimeMillis + logDebug("Checking for logs. Time is now %d.".format(lastLogCheckTime)) + try { + val logStatus = fileSystem.listStatus(new Path(baseLogDir)) + val logDirs = if (logStatus != null) logStatus.filter(_.isDir).toSeq else Seq[FileStatus]() + val logInfos = logDirs + .sortBy { dir => getModificationTime(dir) } + .map { dir => (dir, EventLoggingListener.parseLoggingInfo(dir.getPath, fileSystem)) } + .filter { case (dir, info) => info.applicationComplete } + + // Logging information for applications that should be retained + val retainedLogInfos = logInfos.takeRight(RETAINED_APPLICATIONS) + val retainedAppIds = retainedLogInfos.map { case (dir, _) => dir.getPath.getName } + + // Remove any applications that should no longer be retained + appIdToInfo.foreach { case (appId, info) => + if (!retainedAppIds.contains(appId)) { + detachSparkUI(info.ui) + appIdToInfo.remove(appId) + } + } + + // Render the application's UI if it is not already there + retainedLogInfos.foreach { case (dir, info) => + val appId = dir.getPath.getName + if (!appIdToInfo.contains(appId)) { + renderSparkUI(dir, info) + } + } + + // Track the total number of completed applications observed this round + numCompletedApplications = logInfos.size + + } catch { + case t: Throwable => logError("Exception in checking for event log updates", t) + } + } else { + logWarning("Attempted to check for event log updates before binding the server.") + } + } + + /** + * Render a new SparkUI from the event logs if the associated application is completed. + * + * HistoryServer looks for a special file that indicates application completion in the given + * directory. If this file exists, the associated application is regarded to be completed, in + * which case the server proceeds to render the SparkUI. Otherwise, the server does nothing. + */ + private def renderSparkUI(logDir: FileStatus, logInfo: EventLoggingInfo) { + val path = logDir.getPath + val appId = path.getName + val replayBus = new ReplayListenerBus(logInfo.logPaths, fileSystem, logInfo.compressionCodec) + val appListener = new ApplicationEventListener + replayBus.addListener(appListener) + val ui = new SparkUI(conf, replayBus, appId, "/history/" + appId) + + // Do not call ui.bind() to avoid creating a new server for each application + replayBus.replay() + if (appListener.applicationStarted) { + attachSparkUI(ui) + val appName = appListener.appName + val sparkUser = appListener.sparkUser + val startTime = appListener.startTime + val endTime = appListener.endTime + val lastUpdated = getModificationTime(logDir) + ui.setAppName(appName + " (completed)") + appIdToInfo(appId) = ApplicationHistoryInfo(appId, appName, startTime, endTime, + lastUpdated, sparkUser, path, ui) + } + } + + /** Stop the server and close the file system. */ + override def stop() { + super.stop() + stopped = true + fileSystem.close() + } + + /** Attach a reconstructed UI to this server. Only valid after bind(). */ + private def attachSparkUI(ui: SparkUI) { + assert(serverInfo.isDefined, "HistoryServer must be bound before attaching SparkUIs") + ui.getHandlers.foreach(attachHandler) + } + + /** Detach a reconstructed UI from this server. Only valid after bind(). */ + private def detachSparkUI(ui: SparkUI) { + assert(serverInfo.isDefined, "HistoryServer must be bound before detaching SparkUIs") + ui.getHandlers.foreach(detachHandler) + } + + /** Return the address of this server. */ + def getAddress: String = "http://" + publicHost + ":" + boundPort + + /** Return the number of completed applications found, whether or not the UI is rendered. */ + def getNumApplications: Int = numCompletedApplications + + /** Return when this directory was last modified. */ + private def getModificationTime(dir: FileStatus): Long = { + try { + val logFiles = fileSystem.listStatus(dir.getPath) + if (logFiles != null && !logFiles.isEmpty) { + logFiles.map(_.getModificationTime).max + } else { + dir.getModificationTime + } + } catch { + case t: Throwable => + logError("Exception in accessing modification time of %s".format(dir.getPath), t) + -1L + } + } +} + +/** + * The recommended way of starting and stopping a HistoryServer is through the scripts + * start-history-server.sh and stop-history-server.sh. The path to a base log directory + * is must be specified, while the requested UI port is optional. For example: + * + * ./sbin/spark-history-server.sh /tmp/spark-events + * ./sbin/spark-history-server.sh hdfs://1.2.3.4:9000/spark-events + * + * This launches the HistoryServer as a Spark daemon. + */ +object HistoryServer { + private val conf = new SparkConf + + // Interval between each check for event log updates + val UPDATE_INTERVAL_MS = conf.getInt("spark.history.updateInterval", 10) * 1000 + + // How many applications to retain + val RETAINED_APPLICATIONS = conf.getInt("spark.history.retainedApplications", 250) + + // The port to which the web UI is bound + val WEB_UI_PORT = conf.getInt("spark.history.ui.port", 18080) + + val STATIC_RESOURCE_DIR = SparkUI.STATIC_RESOURCE_DIR + + def main(argStrings: Array[String]) { + val args = new HistoryServerArguments(argStrings) + val securityManager = new SecurityManager(conf) + val server = new HistoryServer(args.logDir, securityManager, conf) + server.bind() + + // Wait until the end of the world... or if the HistoryServer process is manually stopped + while(true) { Thread.sleep(Int.MaxValue) } + server.stop() + } +} + + +private[spark] case class ApplicationHistoryInfo( + id: String, + name: String, + startTime: Long, + endTime: Long, + lastUpdated: Long, + sparkUser: String, + logDirPath: Path, + ui: SparkUI) { + def started = startTime != -1 + def completed = endTime != -1 +} diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala new file mode 100644 index 0000000000000..943c061743dbd --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.history + +import java.net.URI + +import org.apache.hadoop.fs.Path + +import org.apache.spark.util.Utils + +/** + * Command-line parser for the master. + */ +private[spark] class HistoryServerArguments(args: Array[String]) { + var logDir = "" + + parse(args.toList) + + private def parse(args: List[String]): Unit = { + args match { + case ("--dir" | "-d") :: value :: tail => + logDir = value + parse(tail) + + case ("--help" | "-h") :: tail => + printUsageAndExit(0) + + case Nil => + + case _ => + printUsageAndExit(1) + } + validateLogDir() + } + + private def validateLogDir() { + if (logDir == "") { + System.err.println("Logging directory must be specified.") + printUsageAndExit(1) + } + val fileSystem = Utils.getHadoopFileSystem(new URI(logDir)) + val path = new Path(logDir) + if (!fileSystem.exists(path)) { + System.err.println("Logging directory specified does not exist: %s".format(logDir)) + printUsageAndExit(1) + } + if (!fileSystem.getFileStatus(path).isDir) { + System.err.println("Logging directory specified is not a directory: %s".format(logDir)) + printUsageAndExit(1) + } + } + + private def printUsageAndExit(exitCode: Int) { + System.err.println( + "Usage: HistoryServer [options]\n" + + "\n" + + "Options:\n" + + " -d DIR, --dir DIR Location of event log files") + System.exit(exitCode) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 95bd62e88db2b..6c58e741df001 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -29,6 +29,7 @@ import akka.actor._ import akka.pattern.ask import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import akka.serialization.SerializationExtension +import org.apache.hadoop.fs.FileSystem import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState} @@ -37,7 +38,7 @@ import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.MasterMessages._ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.scheduler.ReplayListenerBus +import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.ui.SparkUI import org.apache.spark.util.{AkkaUtils, Utils} @@ -45,7 +46,8 @@ private[spark] class Master( host: String, port: Int, webUiPort: Int, - val securityMgr: SecurityManager) extends Actor with Logging { + val securityMgr: SecurityManager) + extends Actor with Logging { import context.dispatcher // to use Akka's scheduler.schedule() @@ -71,6 +73,7 @@ private[spark] class Master( var nextAppNumber = 0 val appIdToUI = new HashMap[String, SparkUI] + val fileSystemsUsed = new HashSet[FileSystem] val drivers = new HashSet[DriverInfo] val completedDrivers = new ArrayBuffer[DriverInfo] @@ -149,6 +152,7 @@ private[spark] class Master( override def postStop() { webUi.stop() + fileSystemsUsed.foreach(_.close()) masterMetricsSystem.stop() applicationMetricsSystem.stop() persistenceEngine.close() @@ -621,7 +625,7 @@ private[spark] class Master( if (completedApps.size >= RETAINED_APPLICATIONS) { val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1) completedApps.take(toRemove).foreach( a => { - appIdToUI.remove(a.id).foreach { ui => webUi.detachUI(ui) } + appIdToUI.remove(a.id).foreach { ui => webUi.detachSparkUI(ui) } applicationMetricsSystem.removeSource(a.appSource) }) completedApps.trimStart(toRemove) @@ -630,11 +634,7 @@ private[spark] class Master( waitingApps -= app // If application events are logged, use them to rebuild the UI - startPersistedSparkUI(app).map { ui => - app.desc.appUiUrl = ui.basePath - appIdToUI(app.id) = ui - webUi.attachUI(ui) - }.getOrElse { + if (!rebuildSparkUI(app)) { // Avoid broken links if the UI is not reconstructed app.desc.appUiUrl = "" } @@ -654,30 +654,34 @@ private[spark] class Master( } /** - * Start a new SparkUI rendered from persisted storage. If this is unsuccessful for any reason, - * return None. Otherwise return the reconstructed UI. + * Rebuild a new SparkUI from the given application's event logs. + * Return whether this is successful. */ - def startPersistedSparkUI(app: ApplicationInfo): Option[SparkUI] = { + def rebuildSparkUI(app: ApplicationInfo): Boolean = { val appName = app.desc.name - val eventLogInfo = app.desc.eventLogInfo.getOrElse { return None } - val eventLogDir = eventLogInfo.logDir - val eventCompressionCodec = eventLogInfo.compressionCodec - val appConf = new SparkConf - eventCompressionCodec.foreach { codec => - appConf.set("spark.eventLog.compress", "true") - appConf.set("spark.io.compression.codec", codec) - } - val replayerBus = new ReplayListenerBus(appConf) - val ui = new SparkUI( - appConf, - replayerBus, - "%s (finished)".format(appName), - "/history/%s".format(app.id)) - - // Do not call ui.bind() to avoid creating a new server for each application - ui.start() - val success = replayerBus.replay(eventLogDir) - if (success) Some(ui) else None + val eventLogDir = app.desc.eventLogDir.getOrElse { return false } + val fileSystem = Utils.getHadoopFileSystem(eventLogDir) + val eventLogInfo = EventLoggingListener.parseLoggingInfo(eventLogDir, fileSystem) + val eventLogPaths = eventLogInfo.logPaths + val compressionCodec = eventLogInfo.compressionCodec + if (!eventLogPaths.isEmpty) { + try { + val replayBus = new ReplayListenerBus(eventLogPaths, fileSystem, compressionCodec) + val ui = new SparkUI( + new SparkConf, replayBus, appName + " (completed)", "/history/" + app.id) + replayBus.replay() + app.desc.appUiUrl = ui.basePath + appIdToUI(app.id) = ui + webUi.attachSparkUI(ui) + return true + } catch { + case t: Throwable => + logError("Exception in replaying log for application %s (%s)".format(appName, app.id), t) + } + } else { + logWarning("Application %s (%s) has no valid logs: %s".format(appName, app.id, eventLogDir)) + } + false } /** Generate a new app ID given a app's submission date */ diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index cb092cb5d576b..b5cd4d2ea963f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -28,15 +28,16 @@ import org.json4s.JValue import org.apache.spark.deploy.JsonProtocol import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master.ExecutorInfo -import org.apache.spark.ui.UIUtils +import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils -private[spark] class ApplicationPage(parent: MasterWebUI) { - val master = parent.masterActorRef - val timeout = parent.timeout +private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") { + + private val master = parent.masterActorRef + private val timeout = parent.timeout /** Executor details for a particular application */ - def renderJson(request: HttpServletRequest): JValue = { + override def renderJson(request: HttpServletRequest): JValue = { val appId = request.getParameter("appId") val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] val state = Await.result(stateFuture, timeout) @@ -96,7 +97,7 @@ private[spark] class ApplicationPage(parent: MasterWebUI) { UIUtils.basicSparkPage(content, "Application: " + app.desc.name) } - def executorRow(executor: ExecutorInfo): Seq[Node] = { + private def executorRow(executor: ExecutorInfo): Seq[Node] = { {executor.id} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala similarity index 91% rename from core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala rename to core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 8c1d6c7cce450..7ca3b08a28728 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -25,17 +25,17 @@ import scala.xml.Node import akka.pattern.ask import org.json4s.JValue -import org.apache.spark.deploy.{JsonProtocol} +import org.apache.spark.deploy.JsonProtocol import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} -import org.apache.spark.ui.{WebUI, UIUtils} +import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils -private[spark] class IndexPage(parent: MasterWebUI) { - val master = parent.masterActorRef - val timeout = parent.timeout +private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { + private val master = parent.masterActorRef + private val timeout = parent.timeout - def renderJson(request: HttpServletRequest): JValue = { + override def renderJson(request: HttpServletRequest): JValue = { val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] val state = Await.result(stateFuture, timeout) JsonProtocol.writeMasterState(state) @@ -139,7 +139,7 @@ private[spark] class IndexPage(parent: MasterWebUI) { UIUtils.basicSparkPage(content, "Spark Master at " + state.uri) } - def workerRow(worker: WorkerInfo): Seq[Node] = { + private def workerRow(worker: WorkerInfo): Seq[Node] = { {worker.id} @@ -154,8 +154,7 @@ private[spark] class IndexPage(parent: MasterWebUI) { } - - def appRow(app: ApplicationInfo): Seq[Node] = { + private def appRow(app: ApplicationInfo): Seq[Node] = { {app.id} @@ -169,14 +168,14 @@ private[spark] class IndexPage(parent: MasterWebUI) { {Utils.megabytesToString(app.desc.memoryPerSlave)} - {WebUI.formatDate(app.submitDate)} + {UIUtils.formatDate(app.submitDate)} {app.desc.user} {app.state.toString} - {WebUI.formatDuration(app.duration)} + {UIUtils.formatDuration(app.duration)} } - def driverRow(driver: DriverInfo): Seq[Node] = { + private def driverRow(driver: DriverInfo): Seq[Node] = { {driver.id} {driver.submitDate} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index bd75b2dfd0e07..a18b39fc95d64 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -17,13 +17,9 @@ package org.apache.spark.deploy.master.ui -import javax.servlet.http.HttpServletRequest - -import org.eclipse.jetty.servlet.ServletContextHandler - import org.apache.spark.Logging import org.apache.spark.deploy.master.Master -import org.apache.spark.ui.{ServerInfo, SparkUI} +import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.{AkkaUtils, Utils} @@ -31,72 +27,33 @@ import org.apache.spark.util.{AkkaUtils, Utils} * Web UI server for the standalone master. */ private[spark] -class MasterWebUI(val master: Master, requestedPort: Int) extends Logging { +class MasterWebUI(val master: Master, requestedPort: Int) + extends WebUI(master.securityMgr, requestedPort, master.conf) with Logging { + val masterActorRef = master.self val timeout = AkkaUtils.askTimeout(master.conf) - private val host = Utils.localHostName() - private val port = requestedPort - private val applicationPage = new ApplicationPage(this) - private val indexPage = new IndexPage(this) - private var serverInfo: Option[ServerInfo] = None - - private val handlers: Seq[ServletContextHandler] = { - master.masterMetricsSystem.getServletHandlers ++ - master.applicationMetricsSystem.getServletHandlers ++ - Seq[ServletContextHandler]( - createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static"), - createServletHandler("/app/json", - (request: HttpServletRequest) => applicationPage.renderJson(request), master.securityMgr), - createServletHandler("/app", - (request: HttpServletRequest) => applicationPage.render(request), master.securityMgr), - createServletHandler("/json", - (request: HttpServletRequest) => indexPage.renderJson(request), master.securityMgr), - createServletHandler("/", - (request: HttpServletRequest) => indexPage.render(request), master.securityMgr) - ) - } + initialize() - def bind() { - try { - serverInfo = Some(startJettyServer(host, port, handlers, master.conf)) - logInfo("Started Master web UI at http://%s:%d".format(host, boundPort)) - } catch { - case e: Exception => - logError("Failed to create Master JettyUtils", e) - System.exit(1) - } + /** Initialize all components of the server. */ + def initialize() { + attachPage(new ApplicationPage(this)) + attachPage(new MasterPage(this)) + attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static")) + master.masterMetricsSystem.getServletHandlers.foreach(attachHandler) + master.applicationMetricsSystem.getServletHandlers.foreach(attachHandler) } - def boundPort: Int = serverInfo.map(_.boundPort).getOrElse(-1) - /** Attach a reconstructed UI to this Master UI. Only valid after bind(). */ - def attachUI(ui: SparkUI) { + def attachSparkUI(ui: SparkUI) { assert(serverInfo.isDefined, "Master UI must be bound to a server before attaching SparkUIs") - val rootHandler = serverInfo.get.rootHandler - for (handler <- ui.handlers) { - rootHandler.addHandler(handler) - if (!handler.isStarted) { - handler.start() - } - } + ui.getHandlers.foreach(attachHandler) } /** Detach a reconstructed UI from this Master UI. Only valid after bind(). */ - def detachUI(ui: SparkUI) { + def detachSparkUI(ui: SparkUI) { assert(serverInfo.isDefined, "Master UI must be bound to a server before detaching SparkUIs") - val rootHandler = serverInfo.get.rootHandler - for (handler <- ui.handlers) { - if (handler.isStarted) { - handler.stop() - } - rootHandler.removeHandler(handler) - } - } - - def stop() { - assert(serverInfo.isDefined, "Attempted to stop a Master UI that was not bound to a server!") - serverInfo.get.server.stop() + ui.getHandlers.foreach(detachHandler) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 8a71ddda4cb5e..52c164ca3c574 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -64,6 +64,12 @@ private[spark] class Worker( val REGISTRATION_TIMEOUT = 20.seconds val REGISTRATION_RETRIES = 3 + val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", true) + // How often worker will clean up old app folders + val CLEANUP_INTERVAL_MILLIS = conf.getLong("spark.worker.cleanup.interval", 60 * 30) * 1000 + // TTL for app folders/data; after TTL expires it will be cleaned up + val APP_DATA_RETENTION_SECS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) + // Index into masterUrls that we're currently trying to register with. var masterIndex = 0 @@ -122,8 +128,8 @@ private[spark] class Worker( host, port, cores, Utils.megabytesToString(memory))) logInfo("Spark home: " + sparkHome) createWorkDir() - webUi = new WorkerWebUI(this, workDir, Some(webUiPort)) context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + webUi = new WorkerWebUI(this, workDir, Some(webUiPort)) webUi.bind() registerWithMaster() @@ -179,12 +185,28 @@ private[spark] class Worker( registered = true changeMaster(masterUrl, masterWebUiUrl) context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat) + if (CLEANUP_ENABLED) { + context.system.scheduler.schedule(CLEANUP_INTERVAL_MILLIS millis, + CLEANUP_INTERVAL_MILLIS millis, self, WorkDirCleanup) + } case SendHeartbeat => masterLock.synchronized { if (connected) { master ! Heartbeat(workerId) } } + case WorkDirCleanup => + // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker actor + val cleanupFuture = concurrent.future { + logInfo("Cleaning up oldest application directories in " + workDir + " ...") + Utils.findOldFiles(workDir, APP_DATA_RETENTION_SECS) + .foreach(Utils.deleteRecursively) + } + cleanupFuture onFailure { + case e: Throwable => + logError("App dir cleanup failed: " + e.getMessage, e) + } + case MasterChanged(masterUrl, masterWebUiUrl) => logInfo("Master has changed, new master is at " + masterUrl) changeMaster(masterUrl, masterWebUiUrl) @@ -331,7 +353,6 @@ private[spark] class Worker( } private[spark] object Worker { - def main(argStrings: Array[String]) { val args = new WorkerArguments(argStrings) val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index d35d5be73ff97..3836bf219ed3e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -32,8 +32,8 @@ private[spark] class WorkerArguments(args: Array[String]) { var memory = inferDefaultMemory() var masters: Array[String] = null var workDir: String = null - - // Check for settings in environment variables + + // Check for settings in environment variables if (System.getenv("SPARK_WORKER_PORT") != null) { port = System.getenv("SPARK_WORKER_PORT").toInt } @@ -49,7 +49,7 @@ private[spark] class WorkerArguments(args: Array[String]) { if (System.getenv("SPARK_WORKER_DIR") != null) { workDir = System.getenv("SPARK_WORKER_DIR") } - + parse(args.toList) def parse(args: List[String]): Unit = args match { @@ -78,7 +78,7 @@ private[spark] class WorkerArguments(args: Array[String]) { case ("--work-dir" | "-d") :: value :: tail => workDir = value parse(tail) - + case "--webui-port" :: IntParam(value) :: tail => webUiPort = value parse(tail) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala new file mode 100644 index 0000000000000..fec1207948628 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.worker.ui + +import java.io.File +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.spark.ui.{WebUIPage, UIUtils} +import org.apache.spark.util.Utils + +private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") { + private val worker = parent.worker + private val workDir = parent.workDir + + def renderLog(request: HttpServletRequest): String = { + val defaultBytes = 100 * 1024 + + val appId = Option(request.getParameter("appId")) + val executorId = Option(request.getParameter("executorId")) + val driverId = Option(request.getParameter("driverId")) + val logType = request.getParameter("logType") + val offset = Option(request.getParameter("offset")).map(_.toLong) + val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) + + val path = (appId, executorId, driverId) match { + case (Some(a), Some(e), None) => + s"${workDir.getPath}/$appId/$executorId/$logType" + case (None, None, Some(d)) => + s"${workDir.getPath}/$driverId/$logType" + case _ => + throw new Exception("Request must specify either application or driver identifiers") + } + + val (startByte, endByte) = getByteRange(path, offset, byteLength) + val file = new File(path) + val logLength = file.length + + val pre = s"==== Bytes $startByte-$endByte of $logLength of $path ====\n" + pre + Utils.offsetBytes(path, startByte, endByte) + } + + def render(request: HttpServletRequest): Seq[Node] = { + val defaultBytes = 100 * 1024 + val appId = Option(request.getParameter("appId")) + val executorId = Option(request.getParameter("executorId")) + val driverId = Option(request.getParameter("driverId")) + val logType = request.getParameter("logType") + val offset = Option(request.getParameter("offset")).map(_.toLong) + val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) + + val (path, params) = (appId, executorId, driverId) match { + case (Some(a), Some(e), None) => + (s"${workDir.getPath}/$a/$e/$logType", s"appId=$a&executorId=$e") + case (None, None, Some(d)) => + (s"${workDir.getPath}/$d/$logType", s"driverId=$d") + case _ => + throw new Exception("Request must specify either application or driver identifiers") + } + + val (startByte, endByte) = getByteRange(path, offset, byteLength) + val file = new File(path) + val logLength = file.length + val logText = {Utils.offsetBytes(path, startByte, endByte)} + val linkToMaster =

Back to Master

+ val range = Bytes {startByte.toString} - {endByte.toString} of {logLength} + + val backButton = + if (startByte > 0) { + + + + } + else { + + } + + val nextButton = + if (endByte < logLength) { + + + + } + else { + + } + + val content = + + + {linkToMaster} +
+
{backButton}
+
{range}
+
{nextButton}
+
+
+
+
{logText}
+
+ + + UIUtils.basicSparkPage(content, logType + " log page for " + appId) + } + + /** Determine the byte range for a log or log page. */ + private def getByteRange(path: String, offset: Option[Long], byteLength: Int): (Long, Long) = { + val defaultBytes = 100 * 1024 + val maxBytes = 1024 * 1024 + val file = new File(path) + val logLength = file.length() + val getOffset = offset.getOrElse(logLength - defaultBytes) + val startByte = + if (getOffset < 0) 0L + else if (getOffset > logLength) logLength + else getOffset + val logPageLength = math.min(byteLength, maxBytes) + val endByte = math.min(startByte + logPageLength, logLength) + (startByte, endByte) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala similarity index 96% rename from core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala rename to core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 85200ab0e102d..d4513118ced05 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -28,15 +28,15 @@ import org.apache.spark.deploy.JsonProtocol import org.apache.spark.deploy.DeployMessages.{RequestWorkerState, WorkerStateResponse} import org.apache.spark.deploy.master.DriverState import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} -import org.apache.spark.ui.UIUtils +import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils -private[spark] class IndexPage(parent: WorkerWebUI) { +private[spark] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { val workerActor = parent.worker.self val worker = parent.worker val timeout = parent.timeout - def renderJson(request: HttpServletRequest): JValue = { + override def renderJson(request: HttpServletRequest): JValue = { val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse] val workerState = Await.result(stateFuture, timeout) JsonProtocol.writeWorkerState(workerState) @@ -137,7 +137,7 @@ private[spark] class IndexPage(parent: WorkerWebUI) { .format(executor.appId, executor.execId)}>stdout stderr - + } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index de76a5d5eb7bc..0ad2edba2227f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -20,180 +20,44 @@ package org.apache.spark.deploy.worker.ui import java.io.File import javax.servlet.http.HttpServletRequest -import org.eclipse.jetty.servlet.ServletContextHandler - -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.worker.Worker -import org.apache.spark.ui.{JettyUtils, ServerInfo, SparkUI, UIUtils} +import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.AkkaUtils /** * Web UI server for the standalone worker. */ private[spark] -class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[Int] = None) - extends Logging { +class WorkerWebUI( + val worker: Worker, + val workDir: File, + port: Option[Int] = None) + extends WebUI(worker.securityMgr, WorkerWebUI.getUIPort(port, worker.conf), worker.conf) + with Logging { val timeout = AkkaUtils.askTimeout(worker.conf) - private val host = Utils.localHostName() - private val port = requestedPort.getOrElse( - worker.conf.get("worker.ui.port", WorkerWebUI.DEFAULT_PORT).toInt) - private val indexPage = new IndexPage(this) - private var serverInfo: Option[ServerInfo] = None - - private val handlers: Seq[ServletContextHandler] = { - worker.metricsSystem.getServletHandlers ++ - Seq[ServletContextHandler]( - createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static"), - createServletHandler("/log", - (request: HttpServletRequest) => log(request), worker.securityMgr), - createServletHandler("/logPage", - (request: HttpServletRequest) => logPage(request), worker.securityMgr), - createServletHandler("/json", - (request: HttpServletRequest) => indexPage.renderJson(request), worker.securityMgr), - createServletHandler("/", - (request: HttpServletRequest) => indexPage.render(request), worker.securityMgr) - ) - } - - def bind() { - try { - serverInfo = Some(JettyUtils.startJettyServer(host, port, handlers, worker.conf)) - logInfo("Started Worker web UI at http://%s:%d".format(host, boundPort)) - } catch { - case e: Exception => - logError("Failed to create Worker JettyUtils", e) - System.exit(1) - } - } - - def boundPort: Int = serverInfo.map(_.boundPort).getOrElse(-1) - - private def log(request: HttpServletRequest): String = { - val defaultBytes = 100 * 1024 - - val appId = Option(request.getParameter("appId")) - val executorId = Option(request.getParameter("executorId")) - val driverId = Option(request.getParameter("driverId")) - val logType = request.getParameter("logType") - val offset = Option(request.getParameter("offset")).map(_.toLong) - val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) - - val path = (appId, executorId, driverId) match { - case (Some(a), Some(e), None) => - s"${workDir.getPath}/$appId/$executorId/$logType" - case (None, None, Some(d)) => - s"${workDir.getPath}/$driverId/$logType" - case _ => - throw new Exception("Request must specify either application or driver identifiers") - } - - val (startByte, endByte) = getByteRange(path, offset, byteLength) - val file = new File(path) - val logLength = file.length - - val pre = s"==== Bytes $startByte-$endByte of $logLength of $path ====\n" - pre + Utils.offsetBytes(path, startByte, endByte) - } - - private def logPage(request: HttpServletRequest): Seq[scala.xml.Node] = { - val defaultBytes = 100 * 1024 - val appId = Option(request.getParameter("appId")) - val executorId = Option(request.getParameter("executorId")) - val driverId = Option(request.getParameter("driverId")) - val logType = request.getParameter("logType") - val offset = Option(request.getParameter("offset")).map(_.toLong) - val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes) - - val (path, params) = (appId, executorId, driverId) match { - case (Some(a), Some(e), None) => - (s"${workDir.getPath}/$a/$e/$logType", s"appId=$a&executorId=$e") - case (None, None, Some(d)) => - (s"${workDir.getPath}/$d/$logType", s"driverId=$d") - case _ => - throw new Exception("Request must specify either application or driver identifiers") - } - - val (startByte, endByte) = getByteRange(path, offset, byteLength) - val file = new File(path) - val logLength = file.length - val logText = {Utils.offsetBytes(path, startByte, endByte)} - val linkToMaster =

Back to Master

- val range = Bytes {startByte.toString} - {endByte.toString} of {logLength} - - val backButton = - if (startByte > 0) { - - - - } - else { - - } - - val nextButton = - if (endByte < logLength) { - - - - } - else { - - } - - val content = - - - {linkToMaster} -
-
{backButton}
-
{range}
-
{nextButton}
-
-
-
-
{logText}
-
- - - UIUtils.basicSparkPage(content, logType + " log page for " + appId) - } - - /** Determine the byte range for a log or log page. */ - private def getByteRange(path: String, offset: Option[Long], byteLength: Int): (Long, Long) = { - val defaultBytes = 100 * 1024 - val maxBytes = 1024 * 1024 - val file = new File(path) - val logLength = file.length() - val getOffset = offset.getOrElse(logLength - defaultBytes) - val startByte = - if (getOffset < 0) 0L - else if (getOffset > logLength) logLength - else getOffset - val logPageLength = math.min(byteLength, maxBytes) - val endByte = math.min(startByte + logPageLength, logLength) - (startByte, endByte) - } - - def stop() { - assert(serverInfo.isDefined, "Attempted to stop a Worker UI that was not bound to a server!") - serverInfo.get.server.stop() + initialize() + + /** Initialize all components of the server. */ + def initialize() { + val logPage = new LogPage(this) + attachPage(logPage) + attachPage(new WorkerPage(this)) + attachHandler(createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static")) + attachHandler(createServletHandler("/log", + (request: HttpServletRequest) => logPage.renderLog(request), worker.securityMgr)) + worker.metricsSystem.getServletHandlers.foreach(attachHandler) } } private[spark] object WorkerWebUI { + val DEFAULT_PORT = 8081 val STATIC_RESOURCE_BASE = SparkUI.STATIC_RESOURCE_DIR - val DEFAULT_PORT="8081" + + def getUIPort(requestedPort: Option[Int], conf: SparkConf): Int = { + requestedPort.getOrElse(conf.getInt("worker.ui.port", WorkerWebUI.DEFAULT_PORT)) + } } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 3486092a140fb..6327ac01663f6 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -53,7 +53,8 @@ private[spark] class CoarseGrainedExecutorBackend( case RegisteredExecutor(sparkProperties) => logInfo("Successfully registered with driver") // Make this host instead of hostPort ? - executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties) + executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties, + false) case RegisterExecutorFailed(message) => logError("Slave registration failed: " + message) @@ -105,7 +106,8 @@ private[spark] object CoarseGrainedExecutorBackend { // set it val sparkHostPort = hostname + ":" + boundPort actorSystem.actorOf( - Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId, sparkHostPort, cores), + Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId, + sparkHostPort, cores), name = "Executor") workerUrl.foreach{ url => actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher") diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index aecb069e4202b..f89b2bffd1676 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -291,15 +291,19 @@ private[spark] class Executor( * Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes * created by the interpreter to the search path */ - private def createClassLoader(): ExecutorURLClassLoader = { - val loader = Thread.currentThread().getContextClassLoader + private def createClassLoader(): MutableURLClassLoader = { + val currentLoader = Utils.getContextOrSparkClassLoader // For each of the jars in the jarSet, add them to the class loader. // We assume each of the files has already been fetched. val urls = currentJars.keySet.map { uri => new File(uri.split("/").last).toURI.toURL }.toArray - new ExecutorURLClassLoader(urls, loader) + val userClassPathFirst = conf.getBoolean("spark.files.userClassPathFirst", false) + userClassPathFirst match { + case true => new ChildExecutorURLClassLoader(urls, currentLoader) + case false => new ExecutorURLClassLoader(urls, currentLoader) + } } /** @@ -310,11 +314,14 @@ private[spark] class Executor( val classUri = conf.get("spark.repl.class.uri", null) if (classUri != null) { logInfo("Using REPL class URI: " + classUri) + val userClassPathFirst: java.lang.Boolean = + conf.getBoolean("spark.files.userClassPathFirst", false) try { val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader") .asInstanceOf[Class[_ <: ClassLoader]] - val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader]) - constructor.newInstance(classUri, parent) + val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader], + classOf[Boolean]) + constructor.newInstance(classUri, parent, userClassPathFirst) } catch { case _: ClassNotFoundException => logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!") diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala index 210f3dbeebaca..38be2c58b333f 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala @@ -34,13 +34,19 @@ object ExecutorExitCode { logging the exception. */ val UNCAUGHT_EXCEPTION_TWICE = 51 - /** The default uncaught exception handler was reached, and the uncaught exception was an + /** The default uncaught exception handler was reached, and the uncaught exception was an OutOfMemoryError. */ val OOM = 52 /** DiskStore failed to create a local temporary directory after many attempts. */ val DISK_STORE_FAILED_TO_CREATE_DIR = 53 + /** TachyonStore failed to initialize after many attempts. */ + val TACHYON_STORE_FAILED_TO_INITIALIZE = 54 + + /** TachyonStore failed to create a local temporary directory after many attempts. */ + val TACHYON_STORE_FAILED_TO_CREATE_DIR = 55 + def explainExitCode(exitCode: Int): String = { exitCode match { case UNCAUGHT_EXCEPTION => "Uncaught exception" @@ -48,7 +54,10 @@ object ExecutorExitCode { case OOM => "OutOfMemoryError" case DISK_STORE_FAILED_TO_CREATE_DIR => "Failed to create local directory (bad spark.local.dir?)" - case _ => + case TACHYON_STORE_FAILED_TO_INITIALIZE => "TachyonStore failed to initialize." + case TACHYON_STORE_FAILED_TO_CREATE_DIR => + "TachyonStore failed to create a local temporary directory." + case _ => "Unknown executor exit code (" + exitCode + ")" + ( if (exitCode > 128) { " (died from signal " + (exitCode - 128) + "?)" diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala index 127f5e90f3e1a..0ed52cfe9df61 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.fs.FileSystem import org.apache.spark.metrics.source.Source -class ExecutorSource(val executor: Executor, executorId: String) extends Source { +private[spark] class ExecutorSource(val executor: Executor, executorId: String) extends Source { private def fileStats(scheme: String) : Option[FileSystem.Statistics] = FileSystem.getAllStatistics().filter(s => s.getScheme.equals(scheme)).headOption diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala index f9bfe8ed2f5ba..218ed7b5d2d39 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala @@ -19,13 +19,56 @@ package org.apache.spark.executor import java.net.{URLClassLoader, URL} +import org.apache.spark.util.ParentClassLoader + /** * The addURL method in URLClassLoader is protected. We subclass it to make this accessible. + * We also make changes so user classes can come before the default classes. */ + +private[spark] trait MutableURLClassLoader extends ClassLoader { + def addURL(url: URL) + def getURLs: Array[URL] +} + +private[spark] class ChildExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader) + extends MutableURLClassLoader { + + private object userClassLoader extends URLClassLoader(urls, null){ + override def addURL(url: URL) { + super.addURL(url) + } + override def findClass(name: String): Class[_] = { + super.findClass(name) + } + } + + private val parentClassLoader = new ParentClassLoader(parent) + + override def findClass(name: String): Class[_] = { + try { + userClassLoader.findClass(name) + } catch { + case e: ClassNotFoundException => { + parentClassLoader.loadClass(name) + } + } + } + + def addURL(url: URL) { + userClassLoader.addURL(url) + } + + def getURLs() = { + userClassLoader.getURLs() + } +} + private[spark] class ExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader) - extends URLClassLoader(urls, parent) { + extends URLClassLoader(urls, parent) with MutableURLClassLoader { override def addURL(url: URL) { super.addURL(url) } } + diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 88625e79a5c68..e4f02a4be0b97 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,8 +17,14 @@ package org.apache.spark.executor +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.storage.{BlockId, BlockStatus} +/** + * :: DeveloperApi :: + * Metrics tracked during the execution of a task. + */ +@DeveloperApi class TaskMetrics extends Serializable { /** * Host's name the task runs on @@ -77,11 +83,16 @@ class TaskMetrics extends Serializable { var updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = None } -object TaskMetrics { - private[spark] def empty(): TaskMetrics = new TaskMetrics +private[spark] object TaskMetrics { + def empty(): TaskMetrics = new TaskMetrics } +/** + * :: DeveloperApi :: + * Metrics pertaining to shuffle data read in a given task. + */ +@DeveloperApi class ShuffleReadMetrics extends Serializable { /** * Absolute time when this task finished reading shuffle data @@ -116,6 +127,11 @@ class ShuffleReadMetrics extends Serializable { var remoteBytesRead: Long = _ } +/** + * :: DeveloperApi :: + * Metrics pertaining to shuffle data written in a given task. + */ +@DeveloperApi class ShuffleWriteMetrics extends Serializable { /** * Number of bytes written for the shuffle by this task diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala new file mode 100644 index 0000000000000..80d055a89573b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.input + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.InputSplit +import org.apache.hadoop.mapreduce.JobContext +import org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat +import org.apache.hadoop.mapreduce.RecordReader +import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.hadoop.mapreduce.lib.input.CombineFileRecordReader +import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit + +/** + * A [[org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat CombineFileInputFormat]] for + * reading whole text files. Each file is read as key-value pair, where the key is the file path and + * the value is the entire content of file. + */ + +private[spark] class WholeTextFileInputFormat extends CombineFileInputFormat[String, String] { + override protected def isSplitable(context: JobContext, file: Path): Boolean = false + + override def createRecordReader( + split: InputSplit, + context: TaskAttemptContext): RecordReader[String, String] = { + + new CombineFileRecordReader[String, String]( + split.asInstanceOf[CombineFileSplit], + context, + classOf[WholeTextFileRecordReader]) + } + + /** + * Allow minSplits set by end-user in order to keep compatibility with old Hadoop API. + */ + def setMaxSplitSize(context: JobContext, minSplits: Int) { + val files = listStatus(context) + val totalLen = files.map { file => + if (file.isDir) 0L else file.getLen + }.sum + val maxSplitSize = Math.ceil(totalLen * 1.0 / (if (minSplits == 0) 1 else minSplits)).toLong + super.setMaxSplitSize(maxSplitSize) + } +} diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala new file mode 100644 index 0000000000000..c3dabd2e79995 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.input + +import com.google.common.io.{ByteStreams, Closeables} + +import org.apache.hadoop.io.Text +import org.apache.hadoop.mapreduce.InputSplit +import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit +import org.apache.hadoop.mapreduce.RecordReader +import org.apache.hadoop.mapreduce.TaskAttemptContext + +/** + * A [[org.apache.hadoop.mapreduce.RecordReader RecordReader]] for reading a single whole text file + * out in a key-value pair, where the key is the file path and the value is the entire content of + * the file. + */ +private[spark] class WholeTextFileRecordReader( + split: CombineFileSplit, + context: TaskAttemptContext, + index: Integer) + extends RecordReader[String, String] { + + private val path = split.getPath(index) + private val fs = path.getFileSystem(context.getConfiguration) + + // True means the current file has been processed, then skip it. + private var processed = false + + private val key = path.toString + private var value: String = null + + override def initialize(split: InputSplit, context: TaskAttemptContext) = {} + + override def close() = {} + + override def getProgress = if (processed) 1.0f else 0.0f + + override def getCurrentKey = key + + override def getCurrentValue = value + + override def nextKeyValue = { + if (!processed) { + val fileIn = fs.open(path) + val innerBuffer = ByteStreams.toByteArray(fileIn) + + value = new Text(innerBuffer).toString + Closeables.close(fileIn, false) + + processed = true + true + } else { + false + } + } +} diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 059e58824c39b..e1a5ee316bb69 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -23,11 +23,18 @@ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import org.xerial.snappy.{SnappyInputStream, SnappyOutputStream} import org.apache.spark.SparkConf +import org.apache.spark.annotation.DeveloperApi /** + * :: DeveloperApi :: * CompressionCodec allows the customization of choosing different compression implementations * to be used in block storage. + * + * Note: The wire protocol for a codec is not guaranteed compatible across versions of Spark. + * This is intended for use as an internal compression utility within a single + * Spark application. */ +@DeveloperApi trait CompressionCodec { def compressedOutputStream(s: OutputStream): OutputStream @@ -52,8 +59,14 @@ private[spark] object CompressionCodec { /** + * :: DeveloperApi :: * LZF implementation of [[org.apache.spark.io.CompressionCodec]]. + * + * Note: The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. */ +@DeveloperApi class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { override def compressedOutputStream(s: OutputStream): OutputStream = { @@ -65,9 +78,15 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { /** + * :: DeveloperApi :: * Snappy implementation of [[org.apache.spark.io.CompressionCodec]]. * Block size can be configured by spark.io.compression.snappy.block.size. + * + * Note: The wire protocol for this codec is not guaranteed to be compatible across versions + * of Spark. This is intended for use as an internal compression utility within a single Spark + * application. */ +@DeveloperApi class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { override def compressedOutputStream(s: OutputStream): OutputStream = { diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala index 3e3e18c3537d0..1b7a5d1f1980a 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import scala.util.matching.Regex import org.apache.spark.Logging +import org.apache.spark.util.Utils private[spark] class MetricsConfig(val configFile: Option[String]) extends Logging { @@ -50,7 +51,7 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Loggi try { is = configFile match { case Some(f) => new FileInputStream(f) - case None => getClass.getClassLoader.getResourceAsStream(METRICS_CONF) + case None => Utils.getSparkClassLoader.getResourceAsStream(METRICS_CONF) } if (is != null) { diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala index 64eac73605388..05852f1f98993 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/ConsoleSink.scala @@ -25,7 +25,7 @@ import com.codahale.metrics.{ConsoleReporter, MetricRegistry} import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem -class ConsoleSink(val property: Properties, val registry: MetricRegistry, +private[spark] class ConsoleSink(val property: Properties, val registry: MetricRegistry, securityMgr: SecurityManager) extends Sink { val CONSOLE_DEFAULT_PERIOD = 10 val CONSOLE_DEFAULT_UNIT = "SECONDS" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala index 544848d4150b6..542dce65366b2 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/CsvSink.scala @@ -26,7 +26,7 @@ import com.codahale.metrics.{CsvReporter, MetricRegistry} import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem -class CsvSink(val property: Properties, val registry: MetricRegistry, +private[spark] class CsvSink(val property: Properties, val registry: MetricRegistry, securityMgr: SecurityManager) extends Sink { val CSV_KEY_PERIOD = "period" val CSV_KEY_UNIT = "unit" @@ -45,7 +45,7 @@ class CsvSink(val property: Properties, val registry: MetricRegistry, case Some(s) => TimeUnit.valueOf(s.toUpperCase()) case None => TimeUnit.valueOf(CSV_DEFAULT_UNIT) } - + MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) val pollDir = Option(property.getProperty(CSV_KEY_DIR)) match { diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index 7f0a2fd16fa99..aeb4ad44a0647 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -27,7 +27,7 @@ import com.codahale.metrics.graphite.{Graphite, GraphiteReporter} import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem -class GraphiteSink(val property: Properties, val registry: MetricRegistry, +private[spark] class GraphiteSink(val property: Properties, val registry: MetricRegistry, securityMgr: SecurityManager) extends Sink { val GRAPHITE_DEFAULT_PERIOD = 10 val GRAPHITE_DEFAULT_UNIT = "SECONDS" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala index 3b5edd5c376f0..ed27234b4e760 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala @@ -22,7 +22,7 @@ import java.util.Properties import com.codahale.metrics.{JmxReporter, MetricRegistry} import org.apache.spark.SecurityManager -class JmxSink(val property: Properties, val registry: MetricRegistry, +private[spark] class JmxSink(val property: Properties, val registry: MetricRegistry, securityMgr: SecurityManager) extends Sink { val reporter: JmxReporter = JmxReporter.forRegistry(registry).build() diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 854b52c510e3d..571539ba5e467 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -30,7 +30,7 @@ import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.SecurityManager import org.apache.spark.ui.JettyUtils._ -class MetricsServlet(val property: Properties, val registry: MetricRegistry, +private[spark] class MetricsServlet(val property: Properties, val registry: MetricRegistry, securityMgr: SecurityManager) extends Sink { val SERVLET_KEY_PATH = "path" val SERVLET_KEY_SAMPLE = "sample" diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala index 3a739aa563eae..6f2b5a06027ea 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Sink.scala @@ -17,7 +17,7 @@ package org.apache.spark.metrics.sink -trait Sink { +private[spark] trait Sink { def start: Unit def stop: Unit } diff --git a/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala b/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala index 75cb2b8973aa1..f865f9648a91e 100644 --- a/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala +++ b/core/src/main/scala/org/apache/spark/metrics/source/JvmSource.scala @@ -20,7 +20,7 @@ package org.apache.spark.metrics.source import com.codahale.metrics.MetricRegistry import com.codahale.metrics.jvm.{GarbageCollectorMetricSet, MemoryUsageGaugeSet} -class JvmSource extends Source { +private[spark] class JvmSource extends Source { val sourceName = "jvm" val metricRegistry = new MetricRegistry() diff --git a/core/src/main/scala/org/apache/spark/metrics/source/Source.scala b/core/src/main/scala/org/apache/spark/metrics/source/Source.scala index 3fee55cc6dcd5..1dda2cd83b2a9 100644 --- a/core/src/main/scala/org/apache/spark/metrics/source/Source.scala +++ b/core/src/main/scala/org/apache/spark/metrics/source/Source.scala @@ -19,7 +19,7 @@ package org.apache.spark.metrics.source import com.codahale.metrics.MetricRegistry -trait Source { +private[spark] trait Source { def sourceName: String def metricRegistry: MetricRegistry } diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala index 2f7576c53b482..3ffaaab23d0f5 100644 --- a/core/src/main/scala/org/apache/spark/network/Connection.scala +++ b/core/src/main/scala/org/apache/spark/network/Connection.scala @@ -248,14 +248,14 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector, } } - // outbox is used as a lock - ensure that it is always used as a leaf (since methods which + // outbox is used as a lock - ensure that it is always used as a leaf (since methods which // lock it are invoked in context of other locks) private val outbox = new Outbox() /* - This is orthogonal to whether we have pending bytes to write or not - and satisfies a slightly - different purpose. This flag is to see if we need to force reregister for write even when we + This is orthogonal to whether we have pending bytes to write or not - and satisfies a slightly + different purpose. This flag is to see if we need to force reregister for write even when we do not have any pending bytes to write to socket. - This can happen due to a race between adding pending buffers, and checking for existing of + This can happen due to a race between adding pending buffers, and checking for existing of data as detailed in https://github.com/mesos/spark/pull/791 */ private var needForceReregister = false diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionId.scala b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala index ffaab677d411a..d579c165a1917 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionId.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionId.scala @@ -18,7 +18,7 @@ package org.apache.spark.network private[spark] case class ConnectionId(connectionManagerId: ConnectionManagerId, uniqId: Int) { - override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId + override def toString = connectionManagerId.host + "_" + connectionManagerId.port + "_" + uniqId } private[spark] object ConnectionId { @@ -26,9 +26,9 @@ private[spark] object ConnectionId { def createConnectionIdFromString(connectionIdString: String): ConnectionId = { val res = connectionIdString.split("_").map(_.trim()) if (res.size != 3) { - throw new Exception("Error converting ConnectionId string: " + connectionIdString + + throw new Exception("Error converting ConnectionId string: " + connectionIdString + " to a ConnectionId Object") } new ConnectionId(new ConnectionManagerId(res(0), res(1).toInt), res(2).toInt) - } + } } diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 6b0a972f0bbe0..cfee41c61362e 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -17,7 +17,6 @@ package org.apache.spark.network -import java.net._ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ @@ -80,7 +79,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, private val serverChannel = ServerSocketChannel.open() // used to track the SendingConnections waiting to do SASL negotiation - private val connectionsAwaitingSasl = new HashMap[ConnectionId, SendingConnection] + private val connectionsAwaitingSasl = new HashMap[ConnectionId, SendingConnection] with SynchronizedMap[ConnectionId, SendingConnection] private val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection] @@ -142,7 +141,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, } finally { writeRunnableStarted.synchronized { writeRunnableStarted -= key - val needReregister = register || conn.resetForceReregister() + val needReregister = register || conn.resetForceReregister() if (needReregister && conn.changeInterestForWrite()) { conn.registerInterest() } @@ -510,7 +509,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, private def handleClientAuthentication( waitingConn: SendingConnection, - securityMsg: SecurityMessage, + securityMsg: SecurityMessage, connectionId : ConnectionId) { if (waitingConn.isSaslComplete()) { logDebug("Client sasl completed for id: " + waitingConn.connectionId) @@ -531,7 +530,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, } return } - var securityMsgResp = SecurityMessage.fromResponse(replyToken, + var securityMsgResp = SecurityMessage.fromResponse(replyToken, securityMsg.getConnectionId.toString()) var message = securityMsgResp.toBufferMessage if (message == null) throw new Exception("Error creating security message") @@ -547,7 +546,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, } private def handleServerAuthentication( - connection: Connection, + connection: Connection, securityMsg: SecurityMessage, connectionId: ConnectionId) { if (!connection.isSaslComplete()) { @@ -562,7 +561,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, } replyToken = connection.sparkSaslServer.response(securityMsg.getToken) if (connection.isSaslComplete()) { - logDebug("Server sasl completed: " + connection.connectionId) + logDebug("Server sasl completed: " + connection.connectionId) } else { logDebug("Server sasl not completed: " + connection.connectionId) } @@ -572,7 +571,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, var message = securityMsgResp.toBufferMessage if (message == null) throw new Exception("Error creating security Message") sendSecurityMessage(connection.getRemoteConnectionManagerId(), message) - } + } } catch { case e: Exception => { logError("Error in server auth negotiation: " + e) @@ -582,7 +581,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, } } } else { - logDebug("connection already established for this connection id: " + connection.connectionId) + logDebug("connection already established for this connection id: " + connection.connectionId) } } @@ -610,8 +609,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, return true } else { if (!conn.isSaslComplete()) { - // We could handle this better and tell the client we need to do authentication - // negotiation, but for now just ignore them. + // We could handle this better and tell the client we need to do authentication + // negotiation, but for now just ignore them. logError("message sent that is not security negotiation message on connection " + "not authenticated yet, ignoring it!!") return true @@ -710,11 +709,11 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, } } } else { - logDebug("Sasl already established ") + logDebug("Sasl already established ") } } - // allow us to add messages to the inbox for doing sasl negotiating + // allow us to add messages to the inbox for doing sasl negotiating private def sendSecurityMessage(connManagerId: ConnectionManagerId, message: Message) { def startNewConnection(): SendingConnection = { val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port) @@ -773,7 +772,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, if (((clock.getTime() - startTime) >= (authTimeout * 1000)) && (!connection.isSaslComplete())) { // took to long to authenticate the connection, something probably went wrong - throw new Exception("Took to long for authentication to " + connectionManagerId + + throw new Exception("Took to long for authentication to " + connectionManagerId + ", waited " + authTimeout + "seconds, failing.") } } @@ -795,7 +794,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf, } } case None => { - logError("no messageStatus for failed message id: " + message.id) + logError("no messageStatus for failed message id: " + message.id) } } } diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala index 9d9b9dbdd5331..4894ecd41f6eb 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala @@ -37,11 +37,11 @@ private[spark] object ConnectionManagerTest extends Logging{ "[size of msg in MB (integer)] [count] [await time in seconds)] ") System.exit(1) } - + if (args(0).startsWith("local")) { println("This runs only on a mesos cluster") } - + val sc = new SparkContext(args(0), "ConnectionManagerTest") val slavesFile = Source.fromFile(args(1)) val slaves = slavesFile.mkString.split("\n") @@ -50,7 +50,7 @@ private[spark] object ConnectionManagerTest extends Logging{ /* println("Slaves") */ /* slaves.foreach(println) */ val tasknum = if (args.length > 2) args(2).toInt else slaves.length - val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024 + val size = ( if (args.length > 3) (args(3).toInt) else 10 ) * 1024 * 1024 val count = if (args.length > 4) args(4).toInt else 3 val awaitTime = (if (args.length > 5) args(5).toInt else 600 ).second println("Running " + count + " rounds of test: " + "parallel tasks = " + tasknum + ", " + @@ -64,16 +64,16 @@ private[spark] object ConnectionManagerTest extends Logging{ (0 until count).foreach(i => { val resultStrs = sc.parallelize(0 until tasknum, tasknum).map(i => { val connManager = SparkEnv.get.connectionManager - val thisConnManagerId = connManager.id - connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + val thisConnManagerId = connManager.id + connManager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { logInfo("Received [" + msg + "] from [" + id + "]") None }) val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) buffer.flip - - val startTime = System.currentTimeMillis + + val startTime = System.currentTimeMillis val futures = slaveConnManagerIds.filter(_ != thisConnManagerId).map{ slaveConnManagerId => { val bufferMessage = Message.createBufferMessage(buffer.duplicate) @@ -84,7 +84,7 @@ private[spark] object ConnectionManagerTest extends Logging{ val results = futures.map(f => Await.result(f, awaitTime)) val finishTime = System.currentTimeMillis Thread.sleep(5000) - + val mb = size * results.size / 1024.0 / 1024.0 val ms = finishTime - startTime val resultStr = thisConnManagerId + " Sent " + mb + " MB in " + ms + " ms at " + (mb / ms * @@ -92,11 +92,11 @@ private[spark] object ConnectionManagerTest extends Logging{ logInfo(resultStr) resultStr }).collect() - - println("---------------------") - println("Run " + i) + + println("---------------------") + println("Run " + i) resultStrs.foreach(println) - println("---------------------") + println("---------------------") }) } } diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala index 2b41c403b2e0a..9dc51e0d401f8 100644 --- a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala +++ b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala @@ -18,7 +18,7 @@ package org.apache.spark.network import java.nio.ByteBuffer -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf} private[spark] object ReceiverTest { def main(args: Array[String]) { diff --git a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala index 0d9f743b3624b..a1dfc4094cca7 100644 --- a/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala +++ b/core/src/main/scala/org/apache/spark/network/SecurityMessage.scala @@ -26,33 +26,33 @@ import org.apache.spark._ import org.apache.spark.network._ /** - * SecurityMessage is class that contains the connectionId and sasl token + * SecurityMessage is class that contains the connectionId and sasl token * used in SASL negotiation. SecurityMessage has routines for converting * it to and from a BufferMessage so that it can be sent by the ConnectionManager * and easily consumed by users when received. * The api was modeled after BlockMessage. * - * The connectionId is the connectionId of the client side. Since + * The connectionId is the connectionId of the client side. Since * message passing is asynchronous and its possible for the server side (receiving) - * to get multiple different types of messages on the same connection the connectionId - * is used to know which connnection the security message is intended for. - * + * to get multiple different types of messages on the same connection the connectionId + * is used to know which connnection the security message is intended for. + * * For instance, lets say we are node_0. We need to send data to node_1. The node_0 side * is acting as a client and connecting to node_1. SASL negotiation has to occur - * between node_0 and node_1 before node_1 trusts node_0 so node_0 sends a security message. - * node_1 receives the message from node_0 but before it can process it and send a response, - * some thread on node_1 decides it needs to send data to node_0 so it connects to node_0 - * and sends a security message of its own to authenticate as a client. Now node_0 gets - * the message and it needs to decide if this message is in response to it being a client - * (from the first send) or if its just node_1 trying to connect to it to send data. This + * between node_0 and node_1 before node_1 trusts node_0 so node_0 sends a security message. + * node_1 receives the message from node_0 but before it can process it and send a response, + * some thread on node_1 decides it needs to send data to node_0 so it connects to node_0 + * and sends a security message of its own to authenticate as a client. Now node_0 gets + * the message and it needs to decide if this message is in response to it being a client + * (from the first send) or if its just node_1 trying to connect to it to send data. This * is where the connectionId field is used. node_0 can lookup the connectionId to see if * it is in response to it being a client or if its in response to someone sending other data. - * + * * The format of a SecurityMessage as its sent is: * - Length of the ConnectionId - * - ConnectionId + * - ConnectionId * - Length of the token - * - Token + * - Token */ private[spark] class SecurityMessage() extends Logging { @@ -61,13 +61,13 @@ private[spark] class SecurityMessage() extends Logging { def set(byteArr: Array[Byte], newconnectionId: String) { if (byteArr == null) { - token = new Array[Byte](0) + token = new Array[Byte](0) } else { token = byteArr } connectionId = newconnectionId } - + /** * Read the given buffer and set the members of this class. */ @@ -91,17 +91,17 @@ private[spark] class SecurityMessage() extends Logging { buffer.clear() set(buffer) } - + def getConnectionId: String = { return connectionId } - + def getToken: Array[Byte] = { return token } - + /** - * Create a BufferMessage that can be sent by the ConnectionManager containing + * Create a BufferMessage that can be sent by the ConnectionManager containing * the security information from this class. * @return BufferMessage */ @@ -110,12 +110,12 @@ private[spark] class SecurityMessage() extends Logging { val buffers = new ArrayBuffer[ByteBuffer]() // 4 bytes for the length of the connectionId - // connectionId is of type char so multiple the length by 2 to get number of bytes + // connectionId is of type char so multiple the length by 2 to get number of bytes // 4 bytes for the length of token // token is a byte buffer so just take the length var buffer = ByteBuffer.allocate(4 + connectionId.length() * 2 + 4 + token.length) buffer.putInt(connectionId.length()) - connectionId.foreach((x: Char) => buffer.putChar(x)) + connectionId.foreach((x: Char) => buffer.putChar(x)) buffer.putInt(token.length) if (token.length > 0) { @@ -123,7 +123,7 @@ private[spark] class SecurityMessage() extends Logging { } buffer.flip() buffers += buffer - + var message = Message.createBufferMessage(buffers) logDebug("message total size is : " + message.size) message.isSecurityNeg = true @@ -136,7 +136,7 @@ private[spark] class SecurityMessage() extends Logging { } private[spark] object SecurityMessage { - + /** * Convert the given BufferMessage to a SecurityMessage by parsing the contents * of the BufferMessage and populating the SecurityMessage fields. diff --git a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala index 4164e81d3a8ae..136c1912045aa 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/FileHeader.scala @@ -36,8 +36,8 @@ private[spark] class FileHeader ( if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) { buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes) } else { - throw new Exception("too long header " + buf.readableBytes) - logInfo("too long header") + throw new Exception("too long header " + buf.readableBytes) + logInfo("too long header") } buf } diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 2625a7f6a575a..59bbb1171f239 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -32,7 +32,16 @@ package org.apache * * Java programmers should reference the [[spark.api.java]] package * for Spark programming APIs in Java. + * + * Classes and methods marked with + * Experimental are user-facing features which have not been officially adopted by the + * Spark project. These are subject to change or removal in minor releases. + * + * Classes and methods marked with + * Developer API are intended for advanced users want to extend Spark through lower + * level interfaces. These are subject to changes or removal in minor releases. */ + package object spark { // For package docs only } diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala index 5f4450859cc9b..aed0353344427 100644 --- a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala +++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala @@ -17,9 +17,13 @@ package org.apache.spark.partial +import org.apache.spark.annotation.Experimental + /** - * A Double with error bars on it. + * :: Experimental :: + * A Double value with error bars and associated confidence. */ +@Experimental class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, val high: Double) { override def toString(): String = "[%.3f, %.3f]".format(low, high) } diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala index 40b70baabcad9..8bb78123e3c9c 100644 --- a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala @@ -22,36 +22,33 @@ import java.util.{HashMap => JHashMap} import scala.collection.JavaConversions.mapAsScalaMap import scala.collection.Map import scala.collection.mutable.HashMap +import scala.reflect.ClassTag import cern.jet.stat.Probability -import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} + +import org.apache.spark.util.collection.OpenHashMap /** * An ApproximateEvaluator for counts by key. Returns a map of key to confidence interval. */ -private[spark] class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Double) - extends ApproximateEvaluator[OLMap[T], Map[T, BoundedDouble]] { +private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, confidence: Double) + extends ApproximateEvaluator[OpenHashMap[T,Long], Map[T, BoundedDouble]] { var outputsMerged = 0 - var sums = new OLMap[T] // Sum of counts for each key + var sums = new OpenHashMap[T,Long]() // Sum of counts for each key - override def merge(outputId: Int, taskResult: OLMap[T]) { + override def merge(outputId: Int, taskResult: OpenHashMap[T,Long]) { outputsMerged += 1 - val iter = taskResult.object2LongEntrySet.fastIterator() - while (iter.hasNext) { - val entry = iter.next() - sums.put(entry.getKey, sums.getLong(entry.getKey) + entry.getLongValue) + taskResult.foreach { case (key, value) => + sums.changeValue(key, value, _ + value) } } override def currentResult(): Map[T, BoundedDouble] = { if (outputsMerged == totalOutputs) { val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.object2LongEntrySet.fastIterator() - while (iter.hasNext) { - val entry = iter.next() - val sum = entry.getLongValue() - result(entry.getKey) = new BoundedDouble(sum, 1.0, sum, sum) + sums.foreach { case (key, sum) => + result(key) = new BoundedDouble(sum, 1.0, sum, sum) } result } else if (outputsMerged == 0) { @@ -60,16 +57,13 @@ private[spark] class GroupedCountEvaluator[T](totalOutputs: Int, confidence: Dou val p = outputsMerged.toDouble / totalOutputs val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2) val result = new JHashMap[T, BoundedDouble](sums.size) - val iter = sums.object2LongEntrySet.fastIterator() - while (iter.hasNext) { - val entry = iter.next() - val sum = entry.getLongValue + sums.foreach { case (key, sum) => val mean = (sum + 1 - p) / p val variance = (sum + 1) * (1 - p) / (p * p) val stdev = math.sqrt(variance) val low = mean - confFactor * stdev val high = mean + confFactor * stdev - result(entry.getKey) = new BoundedDouble(mean, confidence, low, high) + result(key) = new BoundedDouble(mean, confidence, low, high) } result } diff --git a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala index 812368e04ac0d..cadd0c7ed19ba 100644 --- a/core/src/main/scala/org/apache/spark/partial/PartialResult.scala +++ b/core/src/main/scala/org/apache/spark/partial/PartialResult.scala @@ -17,6 +17,9 @@ package org.apache.spark.partial +import org.apache.spark.annotation.Experimental + +@Experimental class PartialResult[R](initialVal: R, isFinal: Boolean) { private var finalValue: Option[R] = if (isFinal) Some(initialVal) else None private var failure: Option[Exception] = None @@ -41,7 +44,7 @@ class PartialResult[R](initialVal: R, isFinal: Boolean) { } } - /** + /** * Set a handler to be called when this PartialResult completes. Only one completion handler * is supported per PartialResult. */ @@ -57,7 +60,7 @@ class PartialResult[R](initialVal: R, isFinal: Boolean) { return this } - /** + /** * Set a handler to be called if this PartialResult's job fails. Only one failure handler * is supported per PartialResult. */ diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index d1c74a5063510..aed951a40b40c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -24,11 +24,14 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.reflect.ClassTag import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} +import org.apache.spark.annotation.Experimental /** + * :: Experimental :: * A set of asynchronous RDD actions available through an implicit conversion. * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions. */ +@Experimental class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging { /** diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 9aa454a5c8b88..c6e79557f08a1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -23,6 +23,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext} import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency} +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap} import org.apache.spark.serializer.Serializer @@ -51,12 +52,17 @@ private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep] } /** + * :: DeveloperApi :: * A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a * tuple with the list of values for that key. * + * Note: This is an internal API. We recommend users use RDD.coGroup(...) instead of + * instantiating this directly. + * @param rdds parent RDDs. - * @param part partitioner used to partition the shuffle output. + * @param part partitioner used to partition the shuffle output */ +@DeveloperApi class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner) extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) { diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 4e82b51313bf0..44401a663440c 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -32,7 +32,7 @@ import org.apache.spark._ * @param parentsIndices list of indices in the parent that have been coalesced into this partition * @param preferredLocation the preferred location for this partition */ -case class CoalescedRDDPartition( +private[spark] case class CoalescedRDDPartition( index: Int, @transient rdd: RDD[_], parentsIndices: Array[Int], @@ -70,7 +70,7 @@ case class CoalescedRDDPartition( * @param maxPartitions number of desired partitions in the coalesced RDD * @param balanceSlack used to trade-off balance and locality. 1.0 is all locality, 0 is all balance */ -class CoalescedRDD[T: ClassTag]( +private[spark] class CoalescedRDD[T: ClassTag]( @transient var prev: RDD[T], maxPartitions: Int, balanceSlack: Double = 0.10) diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index a7b6b3b5146ce..9ca971c8a4c27 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -17,6 +17,7 @@ package org.apache.spark.rdd +import org.apache.spark.annotation.Experimental import org.apache.spark.{TaskContext, Logging} import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.MeanEvaluator @@ -51,7 +52,7 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { /** Compute the standard deviation of this RDD's elements. */ def stdev(): Double = stats().stdev - /** + /** * Compute the sample standard deviation of this RDD's elements (which corrects for bias in * estimating the standard deviation by dividing by N-1 instead of N). */ @@ -63,14 +64,22 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { */ def sampleVariance(): Double = stats().sampleVariance - /** (Experimental) Approximate operation to return the mean within a timeout. */ + /** + * :: Experimental :: + * Approximate operation to return the mean within a timeout. + */ + @Experimental def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) val evaluator = new MeanEvaluator(self.partitions.size, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) } - /** (Experimental) Approximate operation to return the sum within a timeout. */ + /** + * :: Experimental :: + * Approximate operation to return the sum within a timeout. + */ + @Experimental def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns) val evaluator = new SumEvaluator(self.partitions.size, confidence) @@ -114,13 +123,13 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * e.g. for the array * [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50] * e.g 1<=x<10 , 10<=x<20, 20<=x<50 - * And on the input of 1 and 50 we would have a histogram of 1, 0, 0 - * + * And on the input of 1 and 50 we would have a histogram of 1, 0, 0 + * * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets * to true. * buckets must be sorted and not contain any duplicates. - * buckets array must be at least two elements + * buckets array must be at least two elements * All NaN entries are treated the same. If you have a NaN bucket it must be * the maximum value of the last position and all NaN entries will be counted * in that bucket. diff --git a/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala b/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala index a84e5f9fd8ef8..a2d7e344cf1b2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala @@ -22,9 +22,9 @@ import scala.reflect.ClassTag import org.apache.spark.{Partition, SparkContext, TaskContext} /** - * An RDD that is empty, i.e. has no element in it. + * An RDD that has no partitions and no elements. */ -class EmptyRDD[T: ClassTag](sc: SparkContext) extends RDD[T](sc, Nil) { +private[spark] class EmptyRDD[T: ClassTag](sc: SparkContext) extends RDD[T](sc, Nil) { override def getPartitions: Array[Partition] = Array.empty diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 3af008bd72378..6811e1abb8b70 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -35,6 +35,7 @@ import org.apache.hadoop.mapred.TaskID import org.apache.hadoop.util.ReflectionUtils import org.apache.spark._ +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.util.NextIterator @@ -70,9 +71,13 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp } /** + * :: DeveloperApi :: * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, * sources in HBase, or S3), using the older MapReduce API (`org.apache.hadoop.mapred`). * + * Note: Instantiating this class directly is not recommended, please use + * [[org.apache.spark.SparkContext.hadoopRDD()]] + * * @param sc The SparkContext to associate the RDD with. * @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed * variabe references an instance of JobConf, then that JobConf will be used for the Hadoop job. @@ -84,6 +89,7 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp * @param valueClass Class of the value associated with the inputFormatClass. * @param minSplits Minimum number of Hadoop Splits (HadoopRDD partitions) to generate. */ +@DeveloperApi class HadoopRDD[K, V]( sc: SparkContext, broadcastedConf: Broadcast[SerializableWritable[Configuration]], diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index 1b503743ac117..a76a070b5b863 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -27,7 +27,7 @@ import org.apache.spark.util.NextIterator private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition { override def index = idx } - +// TODO: Expose a jdbcRDD function in SparkContext and mark this as semi-private /** * An RDD that executes an SQL query on a JDBC connection and reads results. * For usage example, see test case JdbcRDDSuite. diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 461a749eac48b..8684b645bc361 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -24,10 +24,18 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ -import org.apache.spark.{InterruptibleIterator, Logging, Partition, SerializableWritable, SparkContext, TaskContext} - -private[spark] -class NewHadoopPartition(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable) +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.input.WholeTextFileInputFormat +import org.apache.spark.InterruptibleIterator +import org.apache.spark.Logging +import org.apache.spark.Partition +import org.apache.spark.SerializableWritable +import org.apache.spark.{SparkContext, TaskContext} + +private[spark] class NewHadoopPartition( + rddId: Int, + val index: Int, + @transient rawSplit: InputSplit with Writable) extends Partition { val serializableHadoopSplit = new SerializableWritable(rawSplit) @@ -36,15 +44,20 @@ class NewHadoopPartition(rddId: Int, val index: Int, @transient rawSplit: InputS } /** + * :: DeveloperApi :: * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, * sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`). * + * Note: Instantiating this class directly is not recommended, please use + * [[org.apache.spark.SparkContext.newAPIHadoopRDD()]] + * * @param sc The SparkContext to associate the RDD with. * @param inputFormatClass Storage format of the data to be read. * @param keyClass Class of the key associated with the inputFormatClass. * @param valueClass Class of the value associated with the inputFormatClass. * @param conf The Hadoop configuration. */ +@DeveloperApi class NewHadoopRDD[K, V]( sc : SparkContext, inputFormatClass: Class[_ <: InputFormat[K, V]], @@ -59,17 +72,19 @@ class NewHadoopRDD[K, V]( private val confBroadcast = sc.broadcast(new SerializableWritable(conf)) // private val serializableConf = new SerializableWritable(conf) - private val jobtrackerId: String = { + private val jobTrackerId: String = { val formatter = new SimpleDateFormat("yyyyMMddHHmm") formatter.format(new Date()) } - @transient private val jobId = new JobID(jobtrackerId, id) + @transient protected val jobId = new JobID(jobTrackerId, id) override def getPartitions: Array[Partition] = { val inputFormat = inputFormatClass.newInstance - if (inputFormat.isInstanceOf[Configurable]) { - inputFormat.asInstanceOf[Configurable].setConf(conf) + inputFormat match { + case configurable: Configurable => + configurable.setConf(conf) + case _ => } val jobContext = newJobContext(conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray @@ -85,11 +100,13 @@ class NewHadoopRDD[K, V]( val split = theSplit.asInstanceOf[NewHadoopPartition] logInfo("Input split: " + split.serializableHadoopSplit) val conf = confBroadcast.value.value - val attemptId = newTaskAttemptID(jobtrackerId, id, isMap = true, split.index, 0) + val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) val format = inputFormatClass.newInstance - if (format.isInstanceOf[Configurable]) { - format.asInstanceOf[Configurable].setConf(conf) + format match { + case configurable: Configurable => + configurable.setConf(conf) + case _ => } val reader = format.createRecordReader( split.serializableHadoopSplit.value, hadoopAttemptContext) @@ -135,3 +152,30 @@ class NewHadoopRDD[K, V]( def getConf: Configuration = confBroadcast.value.value } +private[spark] class WholeTextFileRDD( + sc : SparkContext, + inputFormatClass: Class[_ <: WholeTextFileInputFormat], + keyClass: Class[String], + valueClass: Class[String], + @transient conf: Configuration, + minSplits: Int) + extends NewHadoopRDD[String, String](sc, inputFormatClass, keyClass, valueClass, conf) { + + override def getPartitions: Array[Partition] = { + val inputFormat = inputFormatClass.newInstance + inputFormat match { + case configurable: Configurable => + configurable.setConf(conf) + case _ => + } + val jobContext = newJobContext(conf, jobId) + inputFormat.setMaxSplitSize(jobContext, minSplits) + val rawSplits = inputFormat.getSplits(jobContext).toArray + val result = new Array[Partition](rawSplits.size) + for (i <- 0 until rawSplits.size) { + result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) + } + result + } +} + diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 14386ff5b9127..343e4325c0ef0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -39,6 +39,7 @@ RecordWriter => NewRecordWriter, SparkHadoopMapReduceUtil} import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} import org.apache.spark._ +import org.apache.spark.annotation.Experimental import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.SparkHadoopWriter import org.apache.spark.Partitioner.defaultPartitioner @@ -201,9 +202,11 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) def countByKey(): Map[K, Long] = self.map(_._1).countByValue() /** - * (Experimental) Approximate version of countByKey that can return a partial result if it does + * :: Experimental :: + * Approximate version of countByKey that can return a partial result if it does * not finish within a timeout. */ + @Experimental def countByKeyApprox(timeout: Long, confidence: Double = 0.95) : PartialResult[Map[K, BoundedDouble]] = { self.map(_._1).countByValueApprox(timeout, confidence) @@ -261,7 +264,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Allows controlling the * partitioning of the resulting key-value pair RDD by passing a Partitioner. */ - def groupByKey(partitioner: Partitioner): RDD[(K, Seq[V])] = { + def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])] = { // groupByKey shouldn't use map side combine because map side combine does not // reduce the amount of data shuffled and requires all map side data be inserted // into a hash table, leading to more objects in the old gen. @@ -270,14 +273,14 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) def mergeCombiners(c1: ArrayBuffer[V], c2: ArrayBuffer[V]) = c1 ++ c2 val bufs = combineByKey[ArrayBuffer[V]]( createCombiner _, mergeValue _, mergeCombiners _, partitioner, mapSideCombine=false) - bufs.asInstanceOf[RDD[(K, Seq[V])]] + bufs.mapValues(_.toIterable) } /** * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with into `numPartitions` partitions. */ - def groupByKey(numPartitions: Int): RDD[(K, Seq[V])] = { + def groupByKey(numPartitions: Int): RDD[(K, Iterable[V])] = { groupByKey(new HashPartitioner(numPartitions)) } @@ -298,7 +301,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) */ def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = { this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => - for (v <- vs.iterator; w <- ws.iterator) yield (v, w) + for (v <- vs; w <- ws) yield (v, w) } } @@ -311,9 +314,9 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) def leftOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, Option[W]))] = { this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => if (ws.isEmpty) { - vs.iterator.map(v => (v, None)) + vs.map(v => (v, None)) } else { - for (v <- vs.iterator; w <- ws.iterator) yield (v, Some(w)) + for (v <- vs; w <- ws) yield (v, Some(w)) } } } @@ -328,9 +331,9 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) : RDD[(K, (Option[V], W))] = { this.cogroup(other, partitioner).flatMapValues { case (vs, ws) => if (vs.isEmpty) { - ws.iterator.map(w => (None, w)) + ws.map(w => (None, w)) } else { - for (v <- vs.iterator; w <- ws.iterator) yield (Some(v), w) + for (v <- vs; w <- ws) yield (Some(v), w) } } } @@ -358,7 +361,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) * Group the values for each key in the RDD into a single sequence. Hash-partitions the * resulting RDD with the existing partitioner/parallelism level. */ - def groupByKey(): RDD[(K, Seq[V])] = { + def groupByKey(): RDD[(K, Iterable[V])] = { groupByKey(defaultPartitioner(self)) } @@ -453,7 +456,8 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. */ - def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (Seq[V], Seq[W]))] = { + def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner) + : RDD[(K, (Iterable[V], Iterable[W]))] = { if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) { throw new SparkException("Default partitioner cannot partition array keys.") } @@ -468,13 +472,15 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) * tuple with the list of values for that key in `this`, `other1` and `other2`. */ def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner) - : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = { if (partitioner.isInstanceOf[HashPartitioner] && getKeyClass().isArray) { throw new SparkException("Default partitioner cannot partition array keys.") } val cg = new CoGroupedRDD[K](Seq(self, other1, other2), partitioner) cg.mapValues { case Seq(vs, w1s, w2s) => - (vs.asInstanceOf[Seq[V]], w1s.asInstanceOf[Seq[W1]], w2s.asInstanceOf[Seq[W2]]) + (vs.asInstanceOf[Seq[V]], + w1s.asInstanceOf[Seq[W1]], + w2s.asInstanceOf[Seq[W2]]) } } @@ -482,7 +488,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. */ - def cogroup[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = { + def cogroup[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))] = { cogroup(other, defaultPartitioner(self, other)) } @@ -491,7 +497,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) * tuple with the list of values for that key in `this`, `other1` and `other2`. */ def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)]) - : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = { cogroup(other1, other2, defaultPartitioner(self, other1, other2)) } @@ -499,7 +505,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) * For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the * list of values for that key in `this` as well as `other`. */ - def cogroup[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Seq[V], Seq[W]))] = { + def cogroup[W](other: RDD[(K, W)], numPartitions: Int): RDD[(K, (Iterable[V], Iterable[W]))] = { cogroup(other, new HashPartitioner(numPartitions)) } @@ -508,18 +514,18 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) * tuple with the list of values for that key in `this`, `other1` and `other2`. */ def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], numPartitions: Int) - : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = { cogroup(other1, other2, new HashPartitioner(numPartitions)) } /** Alias for cogroup. */ - def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = { + def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Iterable[V], Iterable[W]))] = { cogroup(other, defaultPartitioner(self, other)) } /** Alias for cogroup. */ def groupWith[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)]) - : RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = { + : RDD[(K, (Iterable[V], Iterable[W1], Iterable[W2]))] = { cogroup(other1, other2, defaultPartitioner(self, other1, other2)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index b0440ca7f32cf..f781a8d776f2a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -20,8 +20,10 @@ package org.apache.spark.rdd import scala.reflect.ClassTag import org.apache.spark.{NarrowDependency, Partition, TaskContext} +import org.apache.spark.annotation.DeveloperApi -class PartitionPruningRDDPartition(idx: Int, val parentSplit: Partition) extends Partition { +private[spark] class PartitionPruningRDDPartition(idx: Int, val parentSplit: Partition) + extends Partition { override val index = idx } @@ -30,7 +32,7 @@ class PartitionPruningRDDPartition(idx: Int, val parentSplit: Partition) extends * Represents a dependency between the PartitionPruningRDD and its parent. In this * case, the child RDD contains a subset of partitions of the parents'. */ -class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean) +private[spark] class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean) extends NarrowDependency[T](rdd) { @transient @@ -45,11 +47,13 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo /** + * :: DeveloperApi :: * A RDD used to prune RDD partitions/partitions so we can avoid launching tasks on * all partitions. An example use case: If we know the RDD is partitioned by range, * and the execution DAG has a filter on the key, we can avoid launching tasks * on partitions that don't have the range covering the key. */ +@DeveloperApi class PartitionPruningRDD[T: ClassTag]( @transient prev: RDD[T], @transient partitionFilterFunc: Int => Boolean) @@ -63,6 +67,7 @@ class PartitionPruningRDD[T: ClassTag]( } +@DeveloperApi object PartitionPruningRDD { /** diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala index a84357b38414e..0c2cd7a24783b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala @@ -33,7 +33,7 @@ class PartitionerAwareUnionRDDPartition( val idx: Int ) extends Partition { var parents = rdds.map(_.partitions(idx)).toArray - + override val index = idx override def hashCode(): Int = idx diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala index ce4c0d382baab..b4e3bb5d75e17 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala @@ -42,7 +42,7 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long) * @tparam T input RDD item type * @tparam U sampled RDD item type */ -class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag]( +private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag]( prev: RDD[T], sampler: RandomSampler[T, U], @transient seed: Long = System.nanoTime) diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index 4250a9d02f764..e441d4a40ccd2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -17,6 +17,9 @@ package org.apache.spark.rdd +import java.io.File +import java.io.FilenameFilter +import java.io.IOException import java.io.PrintWriter import java.util.StringTokenizer @@ -27,18 +30,20 @@ import scala.io.Source import scala.reflect.ClassTag import org.apache.spark.{Partition, SparkEnv, TaskContext} +import org.apache.spark.util.Utils /** * An RDD that pipes the contents of each parent partition through an external command * (printing them one per line) and returns the output as a collection of strings. */ -class PipedRDD[T: ClassTag]( +private[spark] class PipedRDD[T: ClassTag]( prev: RDD[T], command: Seq[String], envVars: Map[String, String], printPipeContext: (String => Unit) => Unit, - printRDDElement: (T, String => Unit) => Unit) + printRDDElement: (T, String => Unit) => Unit, + separateWorkingDir: Boolean) extends RDD[String](prev) { // Similar to Runtime.exec(), if we are given a single string, split it into words @@ -48,12 +53,24 @@ class PipedRDD[T: ClassTag]( command: String, envVars: Map[String, String] = Map(), printPipeContext: (String => Unit) => Unit = null, - printRDDElement: (T, String => Unit) => Unit = null) = - this(prev, PipedRDD.tokenize(command), envVars, printPipeContext, printRDDElement) + printRDDElement: (T, String => Unit) => Unit = null, + separateWorkingDir: Boolean = false) = + this(prev, PipedRDD.tokenize(command), envVars, printPipeContext, printRDDElement, + separateWorkingDir) override def getPartitions: Array[Partition] = firstParent[T].partitions + /** + * A FilenameFilter that accepts anything that isn't equal to the name passed in. + * @param name of file or directory to leave out + */ + class NotEqualsFileNameFilter(filterName: String) extends FilenameFilter { + def accept(dir: File, name: String): Boolean = { + !name.equals(filterName) + } + } + override def compute(split: Partition, context: TaskContext): Iterator[String] = { val pb = new ProcessBuilder(command) // Add the environmental variables to the process. @@ -67,6 +84,38 @@ class PipedRDD[T: ClassTag]( currentEnvVars.putAll(hadoopSplit.getPipeEnvVars()) } + // When spark.worker.separated.working.directory option is turned on, each + // task will be run in separate directory. This should be resolve file + // access conflict issue + val taskDirectory = "./tasks/" + java.util.UUID.randomUUID.toString + var workInTaskDirectory = false + logDebug("taskDirectory = " + taskDirectory) + if (separateWorkingDir == true) { + val currentDir = new File(".") + logDebug("currentDir = " + currentDir.getAbsolutePath()) + val taskDirFile = new File(taskDirectory) + taskDirFile.mkdirs() + + try { + val tasksDirFilter = new NotEqualsFileNameFilter("tasks") + + // Need to add symlinks to jars, files, and directories. On Yarn we could have + // directories and other files not known to the SparkContext that were added via the + // Hadoop distributed cache. We also don't want to symlink to the /tasks directories we + // are creating here. + for (file <- currentDir.list(tasksDirFilter)) { + val fileWithDir = new File(currentDir, file) + Utils.symlink(new File(fileWithDir.getAbsolutePath()), + new File(taskDirectory + "/" + fileWithDir.getName())) + } + pb.directory(taskDirFile) + workInTaskDirectory = true + } catch { + case e: Exception => logError("Unable to setup task working directory: " + e.getMessage + + " (" + taskDirectory + ")") + } + } + val proc = pb.start() val env = SparkEnv.get @@ -112,6 +161,15 @@ class PipedRDD[T: ClassTag]( if (exitStatus != 0) { throw new Exception("Subprocess exited with status " + exitStatus) } + + // cleanup task working directory if used + if (workInTaskDirectory == true) { + scala.util.control.Exception.ignoring(classOf[IOException]) { + Utils.deleteRecursively(new File(taskDirectory)) + } + logDebug("Removed task working directory " + taskDirectory) + } + false } } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index ce2b8ac27206b..891efccf23b6a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -20,12 +20,10 @@ package org.apache.spark.rdd import java.util.Random import scala.collection.Map -import scala.collection.JavaConversions.mapAsScalaMap import scala.collection.mutable.ArrayBuffer import scala.reflect.{classTag, ClassTag} import com.clearspring.analytics.stream.cardinality.HyperLogLog -import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} import org.apache.hadoop.io.BytesWritable import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.io.NullWritable @@ -35,6 +33,7 @@ import org.apache.hadoop.mapred.TextOutputFormat import org.apache.spark._ import org.apache.spark.Partitioner._ import org.apache.spark.SparkContext._ +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.CountEvaluator @@ -42,6 +41,7 @@ import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{BoundedPriorityQueue, SerializableHyperLogLog, Utils} +import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler} /** @@ -86,22 +86,34 @@ abstract class RDD[T: ClassTag]( // Methods that should be implemented by subclasses of RDD // ======================================================================= - /** Implemented by subclasses to compute a given partition. */ + /** + * :: DeveloperApi :: + * Implemented by subclasses to compute a given partition. + */ + @DeveloperApi def compute(split: Partition, context: TaskContext): Iterator[T] /** + * :: DeveloperApi :: * Implemented by subclasses to return the set of partitions in this RDD. This method will only * be called once, so it is safe to implement a time-consuming computation in it. */ + @DeveloperApi protected def getPartitions: Array[Partition] /** + * :: DeveloperApi :: * Implemented by subclasses to return how this RDD depends on parent RDDs. This method will only * be called once, so it is safe to implement a time-consuming computation in it. */ + @DeveloperApi protected def getDependencies: Seq[Dependency[_]] = deps - /** Optionally overridden by subclasses to specify placement preferences. */ + /** + * :: DeveloperApi :: + * Optionally overridden by subclasses to specify placement preferences. + */ + @DeveloperApi protected def getPreferredLocations(split: Partition): Seq[String] = Nil /** Optionally overridden by subclasses to specify how they are partitioned. */ @@ -138,6 +150,8 @@ abstract class RDD[T: ClassTag]( "Cannot change storage level of an RDD after it was already assigned a level") } sc.persistRDD(this) + // Register the RDD with the ContextCleaner for automatic GC-based cleanup + sc.cleaner.foreach(_.registerRDDForCleanup(this)) storageLevel = newLevel this } @@ -156,7 +170,7 @@ abstract class RDD[T: ClassTag]( */ def unpersist(blocking: Boolean = true): RDD[T] = { logInfo("Removing RDD " + id + " from persistence list") - sc.unpersistRDD(this, blocking) + sc.unpersistRDD(id, blocking) storageLevel = StorageLevel.NONE this } @@ -436,20 +450,20 @@ abstract class RDD[T: ClassTag]( /** * Return an RDD of grouped items. */ - def groupBy[K: ClassTag](f: T => K): RDD[(K, Seq[T])] = + def groupBy[K: ClassTag](f: T => K): RDD[(K, Iterable[T])] = groupBy[K](f, defaultPartitioner(this)) /** * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements * mapping to that key. */ - def groupBy[K: ClassTag](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] = + def groupBy[K: ClassTag](f: T => K, numPartitions: Int): RDD[(K, Iterable[T])] = groupBy(f, new HashPartitioner(numPartitions)) /** * Return an RDD of grouped items. */ - def groupBy[K: ClassTag](f: T => K, p: Partitioner): RDD[(K, Seq[T])] = { + def groupBy[K: ClassTag](f: T => K, p: Partitioner): RDD[(K, Iterable[T])] = { val cleanF = sc.clean(f) this.map(t => (cleanF(t), t)).groupByKey(p) } @@ -481,16 +495,19 @@ abstract class RDD[T: ClassTag]( * instead of constructing a huge String to concat all the elements: * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = * for (e <- record._2){f(e)} + * @param separateWorkingDir Use separate working directories for each task. * @return the result RDD */ def pipe( command: Seq[String], env: Map[String, String] = Map(), printPipeContext: (String => Unit) => Unit = null, - printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = { + printRDDElement: (T, String => Unit) => Unit = null, + separateWorkingDir: Boolean = false): RDD[String] = { new PipedRDD(this, command, env, if (printPipeContext ne null) sc.clean(printPipeContext) else null, - if (printRDDElement ne null) sc.clean(printRDDElement) else null) + if (printRDDElement ne null) sc.clean(printRDDElement) else null, + separateWorkingDir) } /** @@ -513,9 +530,11 @@ abstract class RDD[T: ClassTag]( } /** + * :: DeveloperApi :: * Return a new RDD by applying a function to each partition of this RDD. This is a variant of * mapPartitions that also passes the TaskContext into the closure. */ + @DeveloperApi def mapPartitionsWithContext[U: ClassTag]( f: (TaskContext, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { @@ -658,6 +677,18 @@ abstract class RDD[T: ClassTag]( Array.concat(results: _*) } + /** + * Return an iterator that contains all of the elements in this RDD. + * + * The iterator will consume as much memory as the largest partition in this RDD. + */ + def toLocalIterator: Iterator[T] = { + def collectPartition(p: Int): Array[T] = { + sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p), allowLocal = false).head + } + (0 until partitions.length).iterator.flatMap(i => collectPartition(i)) + } + /** * Return an array that contains all of the elements in this RDD. */ @@ -775,9 +806,11 @@ abstract class RDD[T: ClassTag]( def count(): Long = sc.runJob(this, Utils.getIteratorSize _).sum /** - * (Experimental) Approximate version of count() that returns a potentially incomplete result + * :: Experimental :: + * Approximate version of count() that returns a potentially incomplete result * within a timeout, even if not all tasks have finished. */ + @Experimental def countApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = { val countElements: (TaskContext, Iterator[T]) => Long = { (ctx, iter) => var result = 0L @@ -800,29 +833,31 @@ abstract class RDD[T: ClassTag]( throw new SparkException("countByValue() does not support arrays") } // TODO: This should perhaps be distributed by default. - def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = { - val map = new OLMap[T] - while (iter.hasNext) { - val v = iter.next() - map.put(v, map.getLong(v) + 1L) + def countPartition(iter: Iterator[T]): Iterator[OpenHashMap[T,Long]] = { + val map = new OpenHashMap[T,Long] + iter.foreach { + t => map.changeValue(t, 1L, _ + 1L) } Iterator(map) } - def mergeMaps(m1: OLMap[T], m2: OLMap[T]): OLMap[T] = { - val iter = m2.object2LongEntrySet.fastIterator() - while (iter.hasNext) { - val entry = iter.next() - m1.put(entry.getKey, m1.getLong(entry.getKey) + entry.getLongValue) + def mergeMaps(m1: OpenHashMap[T,Long], m2: OpenHashMap[T,Long]): OpenHashMap[T,Long] = { + m2.foreach { case (key, value) => + m1.changeValue(key, value, _ + value) } m1 } val myResult = mapPartitions(countPartition).reduce(mergeMaps) - myResult.asInstanceOf[java.util.Map[T, Long]] // Will be wrapped as a Scala mutable Map + // Convert to a Scala mutable map + val mutableResult = scala.collection.mutable.Map[T,Long]() + myResult.foreach { case (k, v) => mutableResult.put(k, v) } + mutableResult } /** - * (Experimental) Approximate version of countByValue(). + * :: Experimental :: + * Approximate version of countByValue(). */ + @Experimental def countByValueApprox( timeout: Long, confidence: Double = 0.95 @@ -830,11 +865,10 @@ abstract class RDD[T: ClassTag]( if (elementClassTag.runtimeClass.isArray) { throw new SparkException("countByValueApprox() does not support arrays") } - val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) => - val map = new OLMap[T] - while (iter.hasNext) { - val v = iter.next() - map.put(v, map.getLong(v) + 1L) + val countPartition: (TaskContext, Iterator[T]) => OpenHashMap[T,Long] = { (ctx, iter) => + val map = new OpenHashMap[T,Long] + iter.foreach { + t => map.changeValue(t, 1L, _ + 1L) } map } @@ -843,6 +877,7 @@ abstract class RDD[T: ClassTag]( } /** + * :: Experimental :: * Return approximate number of distinct elements in the RDD. * * The accuracy of approximation can be controlled through the relative standard deviation @@ -850,6 +885,7 @@ abstract class RDD[T: ClassTag]( * more accurate counts but increase the memory footprint and vise versa. The default value of * relativeSD is 0.05. */ + @Experimental def countApproxDistinct(relativeSD: Double = 0.05): Long = { val zeroCounter = new SerializableHyperLogLog(new HyperLogLog(relativeSD)) aggregate(zeroCounter)(_.add(_), _.merge(_)).value.cardinality() @@ -1126,5 +1162,4 @@ abstract class RDD[T: ClassTag]( def toJavaRDD() : JavaRDD[T] = { new JavaRDD(this)(elementClassTag) } - } diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala index 4ceea557f569c..b097c30f8c231 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala @@ -33,7 +33,7 @@ class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition } @deprecated("Replaced by PartitionwiseSampledRDD", "1.0.0") -class SampledRDD[T: ClassTag]( +private[spark] class SampledRDD[T: ClassTag]( prev: RDD[T], withReplacement: Boolean, frac: Double, diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 02660ea6a45c5..802b0bdfb2d59 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -20,6 +20,7 @@ package org.apache.spark.rdd import scala.reflect.ClassTag import org.apache.spark.{Dependency, Partition, Partitioner, ShuffleDependency, SparkEnv, TaskContext} +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.Serializer private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { @@ -28,12 +29,14 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { } /** + * :: DeveloperApi :: * The resulting RDD from a shuffle (e.g. repartitioning of data). * @param prev the parent RDD. * @param part the partitioner used to partition the RDD * @tparam K the key class. * @tparam V the value class. */ +@DeveloperApi class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag]( @transient var prev: RDD[P], part: Partitioner) diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index a447030752096..21c6e07d69f90 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -23,6 +23,7 @@ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext} +import org.apache.spark.annotation.DeveloperApi private[spark] class UnionPartition[T: ClassTag](idx: Int, rdd: RDD[T], splitIndex: Int) extends Partition { @@ -43,6 +44,7 @@ private[spark] class UnionPartition[T: ClassTag](idx: Int, rdd: RDD[T], splitInd } } +@DeveloperApi class UnionRDD[T: ClassTag]( sc: SparkContext, @transient var rdds: Seq[RDD[T]]) diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index b56643444aa40..f3d30f6c9b32f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -41,7 +41,7 @@ private[spark] class ZippedPartitionsPartition( } } -abstract class ZippedPartitionsBaseRDD[V: ClassTag]( +private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag]( sc: SparkContext, var rdds: Seq[RDD[_]], preservesPartitioning: Boolean = false) @@ -74,7 +74,7 @@ abstract class ZippedPartitionsBaseRDD[V: ClassTag]( } } -class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag]( +private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag]( sc: SparkContext, f: (Iterator[A], Iterator[B]) => Iterator[V], var rdd1: RDD[A], @@ -94,7 +94,7 @@ class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag]( } } -class ZippedPartitionsRDD3 +private[spark] class ZippedPartitionsRDD3 [A: ClassTag, B: ClassTag, C: ClassTag, V: ClassTag]( sc: SparkContext, f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V], @@ -119,7 +119,7 @@ class ZippedPartitionsRDD3 } } -class ZippedPartitionsRDD4 +private[spark] class ZippedPartitionsRDD4 [A: ClassTag, B: ClassTag, C: ClassTag, D:ClassTag, V: ClassTag]( sc: SparkContext, f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala index 2119e76f0e032..b8110ffc42f2d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala @@ -44,7 +44,7 @@ private[spark] class ZippedPartition[T: ClassTag, U: ClassTag]( } } -class ZippedRDD[T: ClassTag, U: ClassTag]( +private[spark] class ZippedRDD[T: ClassTag, U: ClassTag]( sc: SparkContext, var rdd1: RDD[T], var rdd2: RDD[U]) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala new file mode 100644 index 0000000000000..c1001227151a5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +/** + * A simple listener for application events. + * + * This listener expects to hear events from a single application only. If events + * from multiple applications are seen, the behavior is unspecified. + */ +private[spark] class ApplicationEventListener extends SparkListener { + var appName = "" + var sparkUser = "" + var startTime = -1L + var endTime = -1L + + def applicationStarted = startTime != -1 + + def applicationCompleted = endTime != -1 + + def applicationDuration: Long = { + val difference = endTime - startTime + if (applicationStarted && applicationCompleted && difference > 0) difference else -1L + } + + override def onApplicationStart(applicationStart: SparkListenerApplicationStart) { + appName = applicationStart.appName + startTime = applicationStart.time + sparkUser = applicationStart.sparkUser + } + + override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd) { + endTime = applicationEnd.time + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 4fce47e1ee8de..c6cbf14e20069 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -32,7 +32,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId} -import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils} +import org.apache.spark.util.Utils /** * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of @@ -80,13 +80,13 @@ class DAGScheduler( private[scheduler] def numTotalJobs: Int = nextJobId.get() private val nextStageId = new AtomicInteger(0) - private[scheduler] val jobIdToStageIds = new TimeStampedHashMap[Int, HashSet[Int]] - private[scheduler] val stageIdToJobIds = new TimeStampedHashMap[Int, HashSet[Int]] - private[scheduler] val stageIdToStage = new TimeStampedHashMap[Int, Stage] - private[scheduler] val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] - private[scheduler] val stageIdToActiveJob = new HashMap[Int, ActiveJob] + private[scheduler] val jobIdToStageIds = new HashMap[Int, HashSet[Int]] + private[scheduler] val stageIdToJobIds = new HashMap[Int, HashSet[Int]] + private[scheduler] val stageIdToStage = new HashMap[Int, Stage] + private[scheduler] val shuffleToMapStage = new HashMap[Int, Stage] + private[scheduler] val jobIdToActiveJob = new HashMap[Int, ActiveJob] private[scheduler] val resultStageToJob = new HashMap[Stage, ActiveJob] - private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo] + private[scheduler] val stageToInfos = new HashMap[Stage, StageInfo] // Stages we need to run whose parents aren't done private[scheduler] val waitingStages = new HashSet[Stage] @@ -98,7 +98,7 @@ class DAGScheduler( private[scheduler] val failedStages = new HashSet[Stage] // Missing tasks from each stage - private[scheduler] val pendingTasks = new TimeStampedHashMap[Stage, HashSet[Task[_]]] + private[scheduler] val pendingTasks = new HashMap[Stage, HashSet[Task[_]]] private[scheduler] val activeJobs = new HashSet[ActiveJob] @@ -113,9 +113,6 @@ class DAGScheduler( // stray messages to detect. private val failedEpoch = new HashMap[String, Long] - private val metadataCleaner = - new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup, env.conf) - taskScheduler.setDAGScheduler(this) /** @@ -258,7 +255,7 @@ class DAGScheduler( : Stage = { val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite) - if (mapOutputTracker.has(shuffleDep.shuffleId)) { + if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) { val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) val locs = MapOutputTracker.deserializeMapStatuses(serLocs) for (i <- 0 until locs.size) { @@ -345,22 +342,24 @@ class DAGScheduler( } /** - * Removes job and any stages that are not needed by any other job. Returns the set of ids for - * stages that were removed. The associated tasks for those stages need to be cancelled if we - * got here via job cancellation. + * Removes state for job and any stages that are not needed by any other job. Does not + * handle cancelling tasks or notifying the SparkListener about finished jobs/stages/tasks. + * + * @param job The job whose state to cleanup. + * @param resultStage Specifies the result stage for the job; if set to None, this method + * searches resultStagesToJob to find and cleanup the appropriate result stage. */ - private def removeJobAndIndependentStages(jobId: Int): Set[Int] = { - val registeredStages = jobIdToStageIds(jobId) - val independentStages = new HashSet[Int]() - if (registeredStages.isEmpty) { - logError("No stages registered for job " + jobId) + private def cleanupStateForJobAndIndependentStages(job: ActiveJob, resultStage: Option[Stage]) { + val registeredStages = jobIdToStageIds.get(job.jobId) + if (registeredStages.isEmpty || registeredStages.get.isEmpty) { + logError("No stages registered for job " + job.jobId) } else { - stageIdToJobIds.filterKeys(stageId => registeredStages.contains(stageId)).foreach { + stageIdToJobIds.filterKeys(stageId => registeredStages.get.contains(stageId)).foreach { case (stageId, jobSet) => - if (!jobSet.contains(jobId)) { + if (!jobSet.contains(job.jobId)) { logError( "Job %d not registered for stage %d even though that stage was registered for the job" - .format(jobId, stageId)) + .format(job.jobId, stageId)) } else { def removeStage(stageId: Int) { // data structures based on Stage @@ -390,27 +389,35 @@ class DAGScheduler( stageIdToStage -= stageId stageIdToJobIds -= stageId + ShuffleMapTask.removeStage(stageId) + ResultTask.removeStage(stageId) + logDebug("After removal of stage %d, remaining stages = %d" .format(stageId, stageIdToStage.size)) } - jobSet -= jobId + jobSet -= job.jobId if (jobSet.isEmpty) { // no other job needs this stage - independentStages += stageId removeStage(stageId) } } } } - independentStages.toSet - } - - private def jobIdToStageIdsRemove(jobId: Int) { - if (!jobIdToStageIds.contains(jobId)) { - logDebug("Trying to remove unregistered job " + jobId) + jobIdToStageIds -= job.jobId + jobIdToActiveJob -= job.jobId + activeJobs -= job + + if (resultStage.isEmpty) { + // Clean up result stages. + val resultStagesForJob = resultStageToJob.keySet.filter( + stage => resultStageToJob(stage).jobId == job.jobId) + if (resultStagesForJob.size != 1) { + logWarning( + s"${resultStagesForJob.size} result stages for job ${job.jobId} (expect exactly 1)") + } + resultStageToJob --= resultStagesForJob } else { - removeJobAndIndependentStages(jobId) - jobIdToStageIds -= jobId + resultStageToJob -= resultStage.get } } @@ -460,7 +467,7 @@ class DAGScheduler( val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties) waiter.awaitResult() match { case JobSucceeded => {} - case JobFailed(exception: Exception, _) => + case JobFailed(exception: Exception) => logInfo("Failed to run " + callSite) throw exception } @@ -504,6 +511,13 @@ class DAGScheduler( eventProcessActor ! AllJobsCancelled } + /** + * Cancel all jobs associated with a running or scheduled stage. + */ + def cancelStage(stageId: Int) { + eventProcessActor ! StageCancelled(stageId) + } + /** * Process one event retrieved from the event processing actor. * @@ -536,7 +550,7 @@ class DAGScheduler( listenerBus.post(SparkListenerJobStart(job.jobId, Array[Int](), properties)) runLocally(job) } else { - stageIdToActiveJob(jobId) = job + jobIdToActiveJob(jobId) = job activeJobs += job resultStageToJob(finalStage) = job listenerBus.post( @@ -544,6 +558,9 @@ class DAGScheduler( submitStage(finalStage) } + case StageCancelled(stageId) => + handleStageCancellation(stageId) + case JobCancelled(jobId) => handleJobCancellation(jobId) @@ -553,13 +570,15 @@ class DAGScheduler( val activeInGroup = activeJobs.filter(activeJob => groupId == activeJob.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) val jobIds = activeInGroup.map(_.jobId) - jobIds.foreach(handleJobCancellation) + jobIds.foreach(jobId => handleJobCancellation(jobId, + "as part of cancelled job group %s".format(groupId))) case AllJobsCancelled => // Cancel all running jobs. - runningStages.map(_.jobId).foreach(handleJobCancellation) + runningStages.map(_.jobId).foreach(jobId => handleJobCancellation(jobId, + "as part of cancellation of all jobs")) activeJobs.clear() // These should already be empty by this point, - stageIdToActiveJob.clear() // but just in case we lost track of some jobs... + jobIdToActiveJob.clear() // but just in case we lost track of some jobs... case ExecutorAdded(execId, host) => handleExecutorAdded(execId, host) @@ -569,7 +588,6 @@ class DAGScheduler( case BeginEvent(task, taskInfo) => for ( - job <- stageIdToActiveJob.get(task.stageId); stage <- stageIdToStage.get(task.stageId); stageInfo <- stageToInfos.get(stage) ) { @@ -607,7 +625,16 @@ class DAGScheduler( for (job <- activeJobs) { val error = new SparkException("Job cancelled because SparkContext was shut down") job.listener.jobFailed(error) - listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error, -1))) + // Tell the listeners that all of the running stages have ended. Don't bother + // cancelling the stages because if the DAG scheduler is stopped, the entire application + // is in the process of getting stopped. + val stageFailedMessage = "Stage cancelled because SparkContext was shut down" + runningStages.foreach { stage => + val info = stageToInfos(stage) + info.stageFailed(stageFailedMessage) + listenerBus.post(SparkListenerStageCompleted(info)) + } + listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error))) } return true } @@ -677,7 +704,7 @@ class DAGScheduler( } } catch { case e: Exception => - jobResult = JobFailed(e, job.finalStage.id) + jobResult = JobFailed(e) job.listener.jobFailed(e) } finally { val s = job.finalStage @@ -697,7 +724,7 @@ class DAGScheduler( private def activeJobForStage(stage: Stage): Option[Int] = { if (stageIdToJobIds.contains(stage.id)) { val jobsThatUseStage: Array[Int] = stageIdToJobIds(stage.id).toArray.sorted - jobsThatUseStage.find(stageIdToActiveJob.contains) + jobsThatUseStage.find(jobIdToActiveJob.contains) } else { None } @@ -750,8 +777,8 @@ class DAGScheduler( } } - val properties = if (stageIdToActiveJob.contains(jobId)) { - stageIdToActiveJob(stage.jobId).properties + val properties = if (jobIdToActiveJob.contains(jobId)) { + jobIdToActiveJob(stage.jobId).properties } else { // this stage will be assigned to "default" pool null @@ -827,11 +854,8 @@ class DAGScheduler( job.numFinished += 1 // If the whole job has finished, remove it if (job.numFinished == job.numPartitions) { - stageIdToActiveJob -= stage.jobId - activeJobs -= job - resultStageToJob -= stage markStageAsFinished(stage) - jobIdToStageIdsRemove(job.jobId) + cleanupStateForJobAndIndependentStages(job, Some(stage)) listenerBus.post(SparkListenerJobEnd(job.jobId, JobSucceeded)) } job.listener.taskSucceeded(rt.outputId, event.result) @@ -979,19 +1003,23 @@ class DAGScheduler( } } - private def handleJobCancellation(jobId: Int) { + private def handleStageCancellation(stageId: Int) { + if (stageIdToJobIds.contains(stageId)) { + val jobsThatUseStage: Array[Int] = stageIdToJobIds(stageId).toArray + jobsThatUseStage.foreach(jobId => { + handleJobCancellation(jobId, "because Stage %s was cancelled".format(stageId)) + }) + } else { + logInfo("No active jobs to kill for Stage " + stageId) + } + } + + private def handleJobCancellation(jobId: Int, reason: String = "") { if (!jobIdToStageIds.contains(jobId)) { logDebug("Trying to cancel unregistered job " + jobId) } else { - val independentStages = removeJobAndIndependentStages(jobId) - independentStages.foreach(taskScheduler.cancelTasks) - val error = new SparkException("Job %d cancelled".format(jobId)) - val job = stageIdToActiveJob(jobId) - job.listener.jobFailed(error) - jobIdToStageIds -= jobId - activeJobs -= job - stageIdToActiveJob -= jobId - listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error, job.finalStage.id))) + failJobAndIndependentStages(jobIdToActiveJob(jobId), + "Job %d cancelled %s".format(jobId, reason), None) } } @@ -1008,19 +1036,57 @@ class DAGScheduler( stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis()) for (resultStage <- dependentStages) { val job = resultStageToJob(resultStage) - val error = new SparkException("Job aborted: " + reason) - job.listener.jobFailed(error) - jobIdToStageIdsRemove(job.jobId) - stageIdToActiveJob -= resultStage.jobId - activeJobs -= job - resultStageToJob -= resultStage - listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error, failedStage.id))) + failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason", + Some(resultStage)) } if (dependentStages.isEmpty) { logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") } } + /** + * Fails a job and all stages that are only used by that job, and cleans up relevant state. + * + * @param resultStage The result stage for the job, if known. Used to cleanup state for the job + * slightly more efficiently than when not specified. + */ + private def failJobAndIndependentStages(job: ActiveJob, failureReason: String, + resultStage: Option[Stage]) { + val error = new SparkException(failureReason) + job.listener.jobFailed(error) + + // Cancel all independent, running stages. + val stages = jobIdToStageIds(job.jobId) + if (stages.isEmpty) { + logError("No stages registered for job " + job.jobId) + } + stages.foreach { stageId => + val jobsForStage = stageIdToJobIds.get(stageId) + if (jobsForStage.isEmpty || !jobsForStage.get.contains(job.jobId)) { + logError( + "Job %d not registered for stage %d even though that stage was registered for the job" + .format(job.jobId, stageId)) + } else if (jobsForStage.get.size == 1) { + if (!stageIdToStage.contains(stageId)) { + logError("Missing Stage for stage with id $stageId") + } else { + // This is the only job that uses this stage, so fail the stage if it is running. + val stage = stageIdToStage(stageId) + if (runningStages.contains(stage)) { + taskScheduler.cancelTasks(stageId) + val stageInfo = stageToInfos(stage) + stageInfo.stageFailed(failureReason) + listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage))) + } + } + } + } + + cleanupStateForJobAndIndependentStages(job, resultStage) + + listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error))) + } + /** * Return true if one of stage's ancestors is target. */ @@ -1085,26 +1151,10 @@ class DAGScheduler( Nil } - private def cleanup(cleanupTime: Long) { - Map( - "stageIdToStage" -> stageIdToStage, - "shuffleToMapStage" -> shuffleToMapStage, - "pendingTasks" -> pendingTasks, - "stageToInfos" -> stageToInfos, - "jobIdToStageIds" -> jobIdToStageIds, - "stageIdToJobIds" -> stageIdToJobIds). - foreach { case (s, t) => - val sizeBefore = t.size - t.clearOldValues(cleanupTime) - logInfo("%s %d --> %d".format(s, sizeBefore, t.size)) - } - } - def stop() { if (eventProcessActor != null) { eventProcessActor ! StopDAGScheduler } - metadataCleaner.cancel() taskScheduler.stop() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 04c53d468465a..7367c08b5d324 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -44,6 +44,8 @@ private[scheduler] case class JobSubmitted( properties: Properties = null) extends DAGSchedulerEvent +private[scheduler] case class StageCancelled(stageId: Int) extends DAGSchedulerEvent + private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent @@ -54,7 +56,7 @@ private[scheduler] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent private[scheduler] -case class GettingResultEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent +case class GettingResultEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent private[scheduler] case class CompletionEvent( task: Task[_], diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 217f8825c2ae9..b983c16af14f4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -17,11 +17,14 @@ package org.apache.spark.scheduler +import scala.collection.mutable + +import org.apache.hadoop.fs.{FileSystem, Path} import org.json4s.jackson.JsonMethods._ -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.io.CompressionCodec -import org.apache.spark.util.{JsonProtocol, FileLogger} +import org.apache.spark.util.{FileLogger, JsonProtocol} /** * A SparkListener that logs events to persistent storage. @@ -36,6 +39,8 @@ import org.apache.spark.util.{JsonProtocol, FileLogger} private[spark] class EventLoggingListener(appName: String, conf: SparkConf) extends SparkListener with Logging { + import EventLoggingListener._ + private val shouldCompress = conf.getBoolean("spark.eventLog.compress", false) private val shouldOverwrite = conf.getBoolean("spark.eventLog.overwrite", false) private val outputBufferSize = conf.getInt("spark.eventLog.buffer.kb", 100) * 1024 @@ -46,17 +51,21 @@ private[spark] class EventLoggingListener(appName: String, conf: SparkConf) private val logger = new FileLogger(logDir, conf, outputBufferSize, shouldCompress, shouldOverwrite) - // Information needed to replay the events logged by this listener later - val info = { - val compressionCodec = if (shouldCompress) { - Some(conf.get("spark.io.compression.codec", CompressionCodec.DEFAULT_COMPRESSION_CODEC)) - } else None - EventLoggingInfo(logDir, compressionCodec) + /** + * Begin logging events. + * If compression is used, log a file that indicates which compression library is used. + */ + def start() { + logInfo("Logging events to %s".format(logDir)) + if (shouldCompress) { + val codec = conf.get("spark.io.compression.codec", CompressionCodec.DEFAULT_COMPRESSION_CODEC) + logger.newFile(COMPRESSION_CODEC_PREFIX + codec) + } + logger.newFile(SPARK_VERSION_PREFIX + SparkContext.SPARK_VERSION) + logger.newFile(LOG_PREFIX + logger.fileIndex) } - logInfo("Logging events to %s".format(logDir)) - - /** Log the event as JSON */ + /** Log the event as JSON. */ private def logEvent(event: SparkListenerEvent, flushLogger: Boolean = false) { val eventJson = compact(render(JsonProtocol.sparkEventToJson(event))) logger.logLine(eventJson) @@ -90,9 +99,118 @@ private[spark] class EventLoggingListener(appName: String, conf: SparkConf) logEvent(event, flushLogger = true) override def onUnpersistRDD(event: SparkListenerUnpersistRDD) = logEvent(event, flushLogger = true) + override def onApplicationStart(event: SparkListenerApplicationStart) = + logEvent(event, flushLogger = true) + override def onApplicationEnd(event: SparkListenerApplicationEnd) = + logEvent(event, flushLogger = true) + + /** + * Stop logging events. + * In addition, create an empty special file to indicate application completion. + */ + def stop() = { + logger.newFile(APPLICATION_COMPLETE) + logger.stop() + } +} + +private[spark] object EventLoggingListener extends Logging { + val LOG_PREFIX = "EVENT_LOG_" + val SPARK_VERSION_PREFIX = "SPARK_VERSION_" + val COMPRESSION_CODEC_PREFIX = "COMPRESSION_CODEC_" + val APPLICATION_COMPLETE = "APPLICATION_COMPLETE" + + // A cache for compression codecs to avoid creating the same codec many times + private val codecMap = new mutable.HashMap[String, CompressionCodec] + + def isEventLogFile(fileName: String): Boolean = { + fileName.startsWith(LOG_PREFIX) + } + + def isSparkVersionFile(fileName: String): Boolean = { + fileName.startsWith(SPARK_VERSION_PREFIX) + } + + def isCompressionCodecFile(fileName: String): Boolean = { + fileName.startsWith(COMPRESSION_CODEC_PREFIX) + } + + def isApplicationCompleteFile(fileName: String): Boolean = { + fileName == APPLICATION_COMPLETE + } + + def parseSparkVersion(fileName: String): String = { + if (isSparkVersionFile(fileName)) { + fileName.replaceAll(SPARK_VERSION_PREFIX, "") + } else "" + } + + def parseCompressionCodec(fileName: String): String = { + if (isCompressionCodecFile(fileName)) { + fileName.replaceAll(COMPRESSION_CODEC_PREFIX, "") + } else "" + } + + /** + * Parse the event logging information associated with the logs in the given directory. + * + * Specifically, this looks for event log files, the Spark version file, the compression + * codec file (if event logs are compressed), and the application completion file (if the + * application has run to completion). + */ + def parseLoggingInfo(logDir: Path, fileSystem: FileSystem): EventLoggingInfo = { + try { + val fileStatuses = fileSystem.listStatus(logDir) + val filePaths = + if (fileStatuses != null) { + fileStatuses.filter(!_.isDir).map(_.getPath).toSeq + } else { + Seq[Path]() + } + if (filePaths.isEmpty) { + logWarning("No files found in logging directory %s".format(logDir)) + } + EventLoggingInfo( + logPaths = filePaths.filter { path => isEventLogFile(path.getName) }, + sparkVersion = filePaths + .find { path => isSparkVersionFile(path.getName) } + .map { path => parseSparkVersion(path.getName) } + .getOrElse(""), + compressionCodec = filePaths + .find { path => isCompressionCodecFile(path.getName) } + .map { path => + val codec = EventLoggingListener.parseCompressionCodec(path.getName) + val conf = new SparkConf + conf.set("spark.io.compression.codec", codec) + codecMap.getOrElseUpdate(codec, CompressionCodec.createCodec(conf)) + }, + applicationComplete = filePaths.exists { path => isApplicationCompleteFile(path.getName) } + ) + } catch { + case t: Throwable => + logError("Exception in parsing logging info from directory %s".format(logDir), t) + EventLoggingInfo.empty + } + } - def stop() = logger.stop() + /** + * Parse the event logging information associated with the logs in the given directory. + */ + def parseLoggingInfo(logDir: String, fileSystem: FileSystem): EventLoggingInfo = { + parseLoggingInfo(new Path(logDir), fileSystem) + } } -// If compression is not enabled, compressionCodec is None -private[spark] case class EventLoggingInfo(logDir: String, compressionCodec: Option[String]) + +/** + * Information needed to process the event logs associated with an application. + */ +private[spark] case class EventLoggingInfo( + logPaths: Seq[Path], + sparkVersion: String, + compressionCodec: Option[CompressionCodec], + applicationComplete: Boolean = false) + +private[spark] object EventLoggingInfo { + def empty = EventLoggingInfo(Seq[Path](), "", None, applicationComplete = false) +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala index 5555585c8b4cd..bac37bfdaa23f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -27,11 +27,14 @@ import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.util.ReflectionUtils import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.SparkHadoopUtil /** + * :: DeveloperApi :: * Parses and holds information about inputFormat (and files) specified as a parameter. */ +@DeveloperApi class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Class[_], val path: String) extends Logging { @@ -164,8 +167,7 @@ object InputFormatInfo { PS: I know the wording here is weird, hopefully it makes some sense ! */ - def computePreferredLocations(formats: Seq[InputFormatInfo]): HashMap[String, HashSet[SplitInfo]] - = { + def computePreferredLocations(formats: Seq[InputFormatInfo]): Map[String, Set[SplitInfo]] = { val nodeToSplit = new HashMap[String, HashSet[SplitInfo]] for (inputSplit <- formats) { @@ -178,6 +180,6 @@ object InputFormatInfo { } } - nodeToSplit + nodeToSplit.mapValues(_.toSet).toMap } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 5cecf9416b32c..713aebfa3ce00 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -25,9 +25,11 @@ import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.HashMap import org.apache.spark._ +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics /** + * :: DeveloperApi :: * A logger class to record runtime information for jobs in Spark. This class outputs one log file * for each Spark job, containing tasks start/stop and shuffle information. JobLogger is a subclass * of SparkListener, use addSparkListener to add JobLogger to a SparkContext after the SparkContext @@ -38,7 +40,7 @@ import org.apache.spark.executor.TaskMetrics * to log application information as SparkListenerEvents. To enable this functionality, set * spark.eventLog.enabled to true. */ - +@DeveloperApi @deprecated("Log application information by setting spark.eventLog.enabled.", "1.0.0") class JobLogger(val user: String, val logDirName: String) extends SparkListener with Logging { @@ -191,7 +193,11 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener */ override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { val stageId = stageCompleted.stageInfo.stageId - stageLogInfo(stageId, "STAGE_ID=%d STATUS=COMPLETED".format(stageId)) + if (stageCompleted.stageInfo.failureReason.isEmpty) { + stageLogInfo(stageId, s"STAGE_ID=$stageId STATUS=COMPLETED") + } else { + stageLogInfo(stageId, s"STAGE_ID=$stageId STATUS=FAILED") + } } /** @@ -227,7 +233,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener var info = "JOB_ID=" + jobId jobEnd.jobResult match { case JobSucceeded => info += " STATUS=SUCCESS" - case JobFailed(exception, _) => + case JobFailed(exception) => info += " STATUS=FAILED REASON=" exception.getMessage.split("\\s+").foreach(info += _ + "_") case _ => diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala b/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala index 3cf4e3077e4a4..4cd6cbe189aab 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobResult.scala @@ -17,12 +17,17 @@ package org.apache.spark.scheduler +import org.apache.spark.annotation.DeveloperApi + /** + * :: DeveloperApi :: * A result of a job in the DAGScheduler. */ -private[spark] sealed trait JobResult +@DeveloperApi +sealed trait JobResult -private[spark] case object JobSucceeded extends JobResult +@DeveloperApi +case object JobSucceeded extends JobResult -// A failed stage ID of -1 means there is not a particular stage that caused the failure -private[spark] case class JobFailed(exception: Exception, failedStageId: Int) extends JobResult +@DeveloperApi +private[spark] case class JobFailed(exception: Exception) extends JobResult diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 8007b5418741e..e9bfee2248e5b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -64,7 +64,7 @@ private[spark] class JobWaiter[T]( override def jobFailed(exception: Exception): Unit = synchronized { _jobFinished = true - jobResult = JobFailed(exception, -1) + jobResult = JobFailed(exception) this.notifyAll() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 353a48661b0f7..545fa453b7ccf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -1,101 +1,107 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler - -import java.util.concurrent.LinkedBlockingQueue - -import org.apache.spark.Logging - -/** - * Asynchronously passes SparkListenerEvents to registered SparkListeners. - * - * Until start() is called, all posted events are only buffered. Only after this listener bus - * has started will events be actually propagated to all attached listeners. This listener bus - * is stopped when it receives a SparkListenerShutdown event, which is posted using stop(). - */ -private[spark] class LiveListenerBus extends SparkListenerBus with Logging { - - /* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than - * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */ - private val EVENT_QUEUE_CAPACITY = 10000 - private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](EVENT_QUEUE_CAPACITY) - private var queueFullErrorMessageLogged = false - private var started = false - - /** - * Start sending events to attached listeners. - * - * This first sends out all buffered events posted before this listener bus has started, then - * listens for any additional events asynchronously while the listener bus is still running. - * This should only be called once. - */ - def start() { - if (started) { - throw new IllegalStateException("Listener bus already started!") - } - started = true - new Thread("SparkListenerBus") { - setDaemon(true) - override def run() { - while (true) { - val event = eventQueue.take - if (event == SparkListenerShutdown) { - // Get out of the while loop and shutdown the daemon thread - return - } - postToAll(event) - } - } - }.start() - } - - def post(event: SparkListenerEvent) { - val eventAdded = eventQueue.offer(event) - if (!eventAdded && !queueFullErrorMessageLogged) { - logError("Dropping SparkListenerEvent because no remaining room in event queue. " + - "This likely means one of the SparkListeners is too slow and cannot keep up with the " + - "rate at which tasks are being started by the scheduler.") - queueFullErrorMessageLogged = true - } - } - - /** - * Waits until there are no more events in the queue, or until the specified time has elapsed. - * Used for testing only. Returns true if the queue has emptied and false is the specified time - * elapsed before the queue emptied. - */ - def waitUntilEmpty(timeoutMillis: Int): Boolean = { - val finishTime = System.currentTimeMillis + timeoutMillis - while (!eventQueue.isEmpty) { - if (System.currentTimeMillis > finishTime) { - return false - } - /* Sleep rather than using wait/notify, because this is used only for testing and wait/notify - * add overhead in the general case. */ - Thread.sleep(10) - } - true - } - - def stop() { - if (!started) { - throw new IllegalStateException("Attempted to stop a listener bus that has not yet started!") - } - post(SparkListenerShutdown) - } -} +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import java.util.concurrent.LinkedBlockingQueue + +import org.apache.spark.Logging + +/** + * Asynchronously passes SparkListenerEvents to registered SparkListeners. + * + * Until start() is called, all posted events are only buffered. Only after this listener bus + * has started will events be actually propagated to all attached listeners. This listener bus + * is stopped when it receives a SparkListenerShutdown event, which is posted using stop(). + */ +private[spark] class LiveListenerBus extends SparkListenerBus with Logging { + + /* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than + * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */ + private val EVENT_QUEUE_CAPACITY = 10000 + private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](EVENT_QUEUE_CAPACITY) + private var queueFullErrorMessageLogged = false + private var started = false + private val listenerThread = new Thread("SparkListenerBus") { + setDaemon(true) + override def run() { + while (true) { + val event = eventQueue.take + if (event == SparkListenerShutdown) { + // Get out of the while loop and shutdown the daemon thread + return + } + postToAll(event) + } + } + } + + // Exposed for testing + @volatile private[spark] var stopCalled = false + + /** + * Start sending events to attached listeners. + * + * This first sends out all buffered events posted before this listener bus has started, then + * listens for any additional events asynchronously while the listener bus is still running. + * This should only be called once. + */ + def start() { + if (started) { + throw new IllegalStateException("Listener bus already started!") + } + listenerThread.start() + started = true + } + + def post(event: SparkListenerEvent) { + val eventAdded = eventQueue.offer(event) + if (!eventAdded && !queueFullErrorMessageLogged) { + logError("Dropping SparkListenerEvent because no remaining room in event queue. " + + "This likely means one of the SparkListeners is too slow and cannot keep up with the " + + "rate at which tasks are being started by the scheduler.") + queueFullErrorMessageLogged = true + } + } + + /** + * Waits until there are no more events in the queue, or until the specified time has elapsed. + * Used for testing only. Returns true if the queue has emptied and false is the specified time + * elapsed before the queue emptied. + */ + def waitUntilEmpty(timeoutMillis: Int): Boolean = { + val finishTime = System.currentTimeMillis + timeoutMillis + while (!eventQueue.isEmpty) { + if (System.currentTimeMillis > finishTime) { + return false + } + /* Sleep rather than using wait/notify, because this is used only for testing and wait/notify + * add overhead in the general case. */ + Thread.sleep(10) + } + true + } + + def stop() { + stopCalled = true + if (!started) { + throw new IllegalStateException("Attempted to stop a listener bus that has not yet started!") + } + post(SparkListenerShutdown) + listenerThread.join() + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala index db76178b65501..f89724d4ea196 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala @@ -17,72 +17,54 @@ package org.apache.spark.scheduler -import java.io.InputStream -import java.net.URI +import java.io.{BufferedInputStream, InputStream} import scala.io.Source -import it.unimi.dsi.fastutil.io.FastBufferedInputStream import org.apache.hadoop.fs.{Path, FileSystem} import org.json4s.jackson.JsonMethods._ -import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.Logging import org.apache.spark.io.CompressionCodec -import org.apache.spark.util.{JsonProtocol, Utils} +import org.apache.spark.util.JsonProtocol /** - * An EventBus that replays logged events from persisted storage + * A SparkListenerBus that replays logged events from persisted storage. + * + * This assumes the given paths are valid log files, where each line can be deserialized into + * exactly one SparkListenerEvent. */ -private[spark] class ReplayListenerBus(conf: SparkConf) extends SparkListenerBus with Logging { - private val compressed = conf.getBoolean("spark.eventLog.compress", false) +private[spark] class ReplayListenerBus( + logPaths: Seq[Path], + fileSystem: FileSystem, + compressionCodec: Option[CompressionCodec]) + extends SparkListenerBus with Logging { - // Only used if compression is enabled - private lazy val compressionCodec = CompressionCodec.createCodec(conf) + private var replayed = false - /** - * Return a list of paths representing log files in the given directory. - */ - private def getLogFilePaths(logDir: String, fileSystem: FileSystem): Array[Path] = { - val path = new Path(logDir) - if (!fileSystem.exists(path) || !fileSystem.getFileStatus(path).isDir) { - logWarning("Log path provided is not a valid directory: %s".format(logDir)) - return Array[Path]() - } - val logStatus = fileSystem.listStatus(path) - if (logStatus == null || !logStatus.exists(!_.isDir)) { - logWarning("Log path provided contains no log files: %s".format(logDir)) - return Array[Path]() - } - logStatus.filter(!_.isDir).map(_.getPath).sortBy(_.getName) + if (logPaths.length == 0) { + logWarning("Log path provided contains no log files.") } /** * Replay each event in the order maintained in the given logs. + * This should only be called exactly once. */ - def replay(logDir: String): Boolean = { - val fileSystem = Utils.getHadoopFileSystem(new URI(logDir)) - val logPaths = getLogFilePaths(logDir, fileSystem) - if (logPaths.length == 0) { - return false - } - + def replay() { + assert(!replayed, "ReplayListenerBus cannot replay events more than once") logPaths.foreach { path => // Keep track of input streams at all levels to close them later // This is necessary because an exception can occur in between stream initializations var fileStream: Option[InputStream] = None var bufferedStream: Option[InputStream] = None var compressStream: Option[InputStream] = None - var currentLine = "" + var currentLine = "" try { - currentLine = "" fileStream = Some(fileSystem.open(path)) - bufferedStream = Some(new FastBufferedInputStream(fileStream.get)) - compressStream = - if (compressed) { - Some(compressionCodec.compressedInputStream(bufferedStream.get)) - } else bufferedStream + bufferedStream = Some(new BufferedInputStream(fileStream.get)) + compressStream = Some(wrapForCompression(bufferedStream.get)) - // Parse each line as an event and post it to all attached listeners + // Parse each line as an event and post the event to all attached listeners val lines = Source.fromInputStream(compressStream.get).getLines() lines.foreach { line => currentLine = line @@ -98,7 +80,11 @@ private[spark] class ReplayListenerBus(conf: SparkConf) extends SparkListenerBus compressStream.foreach(_.close()) } } - fileSystem.close() - true + replayed = true + } + + /** If a compression codec is specified, wrap the given stream in a compression stream. */ + private def wrapForCompression(stream: InputStream): InputStream = { + compressionCodec.map(_.compressedInputStream(stream)).getOrElse(stream) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 3fc6cc9850feb..0b381308b61ff 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -20,21 +20,17 @@ package org.apache.spark.scheduler import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} +import scala.collection.mutable.HashMap + import org.apache.spark._ -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.RDDCheckpointData -import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} +import org.apache.spark.rdd.{RDD, RDDCheckpointData} private[spark] object ResultTask { // A simple map between the stage id to the serialized byte array of a task. // Served as a cache for task serialization because serialization can be // expensive on the master node if it needs to launch thousands of tasks. - val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - - // TODO: This object shouldn't have global variables - val metadataCleaner = new MetadataCleaner( - MetadataCleanerType.RESULT_TASK, serializedInfoCache.clearOldValues, new SparkConf) + private val serializedInfoCache = new HashMap[Int, Array[Byte]] def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = { @@ -58,7 +54,6 @@ private[spark] object ResultTask { def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = { - val loader = Thread.currentThread.getContextClassLoader val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) val ser = SparkEnv.get.closureSerializer.newInstance() val objIn = ser.deserializeStream(in) @@ -67,6 +62,10 @@ private[spark] object ResultTask { (rdd, func) } + def removeStage(stageId: Int) { + serializedInfoCache.remove(stageId) + } + def clearCache() { synchronized { serializedInfoCache.clear() diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala index e4eced383c3a5..6c5827f75e636 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala @@ -23,6 +23,7 @@ import java.util.{NoSuchElementException, Properties} import scala.xml.XML import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.util.Utils /** * An interface to build Schedulable tree @@ -72,7 +73,7 @@ private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf) schedulerAllocFile.map { f => new FileInputStream(f) }.getOrElse { - getClass.getClassLoader.getResourceAsStream(DEFAULT_SCHEDULER_FILE) + Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_SCHEDULER_FILE) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 2a9edf4a76b97..23f3b3e824762 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -24,22 +24,16 @@ import scala.collection.mutable.HashMap import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.RDDCheckpointData +import org.apache.spark.rdd.{RDD, RDDCheckpointData} import org.apache.spark.serializer.Serializer import org.apache.spark.storage._ -import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap} private[spark] object ShuffleMapTask { // A simple map between the stage id to the serialized byte array of a task. // Served as a cache for task serialization because serialization can be // expensive on the master node if it needs to launch thousands of tasks. - val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - - // TODO: This object shouldn't have global variables - val metadataCleaner = new MetadataCleaner( - MetadataCleanerType.SHUFFLE_MAP_TASK, serializedInfoCache.clearOldValues, new SparkConf) + private val serializedInfoCache = new HashMap[Int, Array[Byte]] def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { synchronized { @@ -80,6 +74,10 @@ private[spark] object ShuffleMapTask { HashMap(set.toSeq: _*) } + def removeStage(stageId: Int) { + serializedInfoCache.remove(stageId) + } + def clearCache() { synchronized { serializedInfoCache.clear() diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index d4eb0ac88d8e8..378cf1aaebe7b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -23,21 +23,28 @@ import scala.collection.Map import scala.collection.mutable import org.apache.spark.{Logging, TaskEndReason} +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.{Distribution, Utils} +@DeveloperApi sealed trait SparkListenerEvent +@DeveloperApi case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Properties = null) extends SparkListenerEvent +@DeveloperApi case class SparkListenerStageCompleted(stageInfo: StageInfo) extends SparkListenerEvent +@DeveloperApi case class SparkListenerTaskStart(stageId: Int, taskInfo: TaskInfo) extends SparkListenerEvent +@DeveloperApi case class SparkListenerTaskGettingResult(taskInfo: TaskInfo) extends SparkListenerEvent +@DeveloperApi case class SparkListenerTaskEnd( stageId: Int, taskType: String, @@ -46,32 +53,46 @@ case class SparkListenerTaskEnd( taskMetrics: TaskMetrics) extends SparkListenerEvent +@DeveloperApi case class SparkListenerJobStart(jobId: Int, stageIds: Seq[Int], properties: Properties = null) extends SparkListenerEvent +@DeveloperApi case class SparkListenerJobEnd(jobId: Int, jobResult: JobResult) extends SparkListenerEvent +@DeveloperApi case class SparkListenerEnvironmentUpdate(environmentDetails: Map[String, Seq[(String, String)]]) extends SparkListenerEvent +@DeveloperApi case class SparkListenerBlockManagerAdded(blockManagerId: BlockManagerId, maxMem: Long) extends SparkListenerEvent +@DeveloperApi case class SparkListenerBlockManagerRemoved(blockManagerId: BlockManagerId) extends SparkListenerEvent +@DeveloperApi case class SparkListenerUnpersistRDD(rddId: Int) extends SparkListenerEvent +case class SparkListenerApplicationStart(appName: String, time: Long, sparkUser: String) + extends SparkListenerEvent + +case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent + /** An event used in the listener to shutdown the listener daemon thread. */ private[spark] case object SparkListenerShutdown extends SparkListenerEvent /** - * Interface for listening to events from the Spark scheduler. + * :: DeveloperApi :: + * Interface for listening to events from the Spark scheduler. Note that this is an internal + * interface which might change in different Spark releases. */ +@DeveloperApi trait SparkListener { /** - * Called when a stage is completed, with information on the completed stage + * Called when a stage completes successfully or fails, with information on the completed stage. */ def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { } @@ -125,11 +146,23 @@ trait SparkListener { * Called when an RDD is manually unpersisted by the application */ def onUnpersistRDD(unpersistRDD: SparkListenerUnpersistRDD) { } + + /** + * Called when the application starts + */ + def onApplicationStart(applicationStart: SparkListenerApplicationStart) { } + + /** + * Called when the application ends + */ + def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd) { } } /** + * :: DeveloperApi :: * Simple SparkListener that logs a few summary statistics when each stage completes */ +@DeveloperApi class StatsReportListener extends SparkListener with Logging { import org.apache.spark.scheduler.StatsReportListener._ diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index 729e120497571..d6df193d9bcf8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -61,6 +61,10 @@ private[spark] trait SparkListenerBus { sparkListeners.foreach(_.onBlockManagerRemoved(blockManagerRemoved)) case unpersistRDD: SparkListenerUnpersistRDD => sparkListeners.foreach(_.onUnpersistRDD(unpersistRDD)) + case applicationStart: SparkListenerApplicationStart => + sparkListeners.foreach(_.onApplicationStart(applicationStart)) + case applicationEnd: SparkListenerApplicationEnd => + sparkListeners.foreach(_.onApplicationEnd(applicationEnd)) case SparkListenerShutdown => } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala index 5b40a3eb29b30..b85eabd6bbdbc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala @@ -19,8 +19,11 @@ package org.apache.spark.scheduler import collection.mutable.ArrayBuffer +import org.apache.spark.annotation.DeveloperApi + // information about a specific split instance : handles both split instances. // So that we do not need to worry about the differences. +@DeveloperApi class SplitInfo(val inputFormatClazz: Class[_], val hostLocation: String, val path: String, val length: Long, val underlyingSplit: Any) { override def toString(): String = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 8115a7ed7896d..9f732f7191465 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -17,17 +17,28 @@ package org.apache.spark.scheduler +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.storage.RDDInfo /** + * :: DeveloperApi :: * Stores information about a stage to pass from the scheduler to SparkListeners. */ -private[spark] +@DeveloperApi class StageInfo(val stageId: Int, val name: String, val numTasks: Int, val rddInfo: RDDInfo) { /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */ var submissionTime: Option[Long] = None + /** Time when all tasks in the stage completed or when the stage was cancelled. */ var completionTime: Option[Long] = None + /** If the stage failed, the reason why. */ + var failureReason: Option[String] = None + var emittedTaskSizeWarning = false + + def stageFailed(reason: String) { + failureReason = Some(reason) + completionTime = Some(System.currentTimeMillis) + } } private[spark] diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index b85b4a50cd93a..a8bcb7dfe2f3c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -17,13 +17,11 @@ package org.apache.spark.scheduler -import java.io.{DataInputStream, DataOutputStream} +import java.io.{ByteArrayOutputStream, DataInputStream, DataOutputStream} import java.nio.ByteBuffer import scala.collection.mutable.HashMap -import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream - import org.apache.spark.TaskContext import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance @@ -104,7 +102,7 @@ private[spark] object Task { serializer: SerializerInstance) : ByteBuffer = { - val out = new FastByteArrayOutputStream(4096) + val out = new ByteArrayOutputStream(4096) val dataOut = new DataOutputStream(out) // Write currentFiles @@ -125,8 +123,7 @@ private[spark] object Task { dataOut.flush() val taskBytes = serializer.serialize(task).array() out.write(taskBytes) - out.trim() - ByteBuffer.wrap(out.array) + ByteBuffer.wrap(out.toByteArray) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 6183b125def99..4c62e4dc0bac8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -17,10 +17,13 @@ package org.apache.spark.scheduler +import org.apache.spark.annotation.DeveloperApi + /** + * :: DeveloperApi :: * Information about a running task attempt inside a TaskSet. */ -private[spark] +@DeveloperApi class TaskInfo( val taskId: Long, val index: Int, @@ -46,15 +49,15 @@ class TaskInfo( var serializedSize: Int = 0 - def markGettingResult(time: Long = System.currentTimeMillis) { + private[spark] def markGettingResult(time: Long = System.currentTimeMillis) { gettingResultTime = time } - def markSuccessful(time: Long = System.currentTimeMillis) { + private[spark] def markSuccessful(time: Long = System.currentTimeMillis) { finishTime = time } - def markFailed(time: Long = System.currentTimeMillis) { + private[spark] def markFailed(time: Long = System.currentTimeMillis) { finishTime = time failed = true } @@ -83,11 +86,11 @@ class TaskInfo( def duration: Long = { if (!finished) { - throw new UnsupportedOperationException("duration() called on unfinished tasks") + throw new UnsupportedOperationException("duration() called on unfinished task") } else { finishTime - launchTime } } - def timeRunning(currentTime: Long): Long = currentTime - launchTime + private[spark] def timeRunning(currentTime: Long): Long = currentTime - launchTime } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala index 308edb12edd5c..eb920ab0c0b67 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala @@ -17,7 +17,10 @@ package org.apache.spark.scheduler -private[spark] object TaskLocality extends Enumeration { +import org.apache.spark.annotation.DeveloperApi + +@DeveloperApi +object TaskLocality extends Enumeration { // Process local is expected to be used ONLY within TaskSetManager for now. val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index cb4ad4ae9350c..c9ad2b151daf0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -85,13 +85,13 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul try { if (serializedData != null && serializedData.limit() > 0) { reason = serializer.get().deserialize[TaskEndReason]( - serializedData, getClass.getClassLoader) + serializedData, Utils.getSparkClassLoader) } } catch { case cnd: ClassNotFoundException => // Log an error but keep going here -- the task failed, so not catastropic if we can't // deserialize the reason. - val loader = Thread.currentThread.getContextClassLoader + val loader = Utils.getContextOrSparkClassLoader logError( "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) case ex: Throwable => {} diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index a92922166f595..acd152dda89d4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -42,7 +42,7 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode * * THREADING: SchedulerBackends and task-submitting clients can call this class from multiple * threads, so it needs locks in public API methods to maintain its state. In addition, some - * SchedulerBackends sycnchronize on themselves when they want to send events here, and then + * SchedulerBackends synchronize on themselves when they want to send events here, and then * acquire a lock on us, so we need to make sure that we don't try to lock the backend while * we are holding a lock on ourselves. */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 25b7472a99cdb..936e9db80573d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -49,7 +49,7 @@ private[spark] class SparkDeploySchedulerBackend( "org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome() val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, - sparkHome, sc.ui.appUIAddress, sc.eventLoggingInfo) + sparkHome, sc.ui.appUIAddress, sc.eventLogger.map(_.logDir)) client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) client.start() diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 18a68b05fa853..e9163deaf2036 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -21,7 +21,9 @@ import java.io._ import java.nio.ByteBuffer import org.apache.spark.SparkConf +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.ByteBufferInputStream +import org.apache.spark.util.Utils private[spark] class JavaSerializationStream(out: OutputStream, counterReset: Int) extends SerializationStream { @@ -85,7 +87,7 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends Serialize } def deserializeStream(s: InputStream): DeserializationStream = { - new JavaDeserializationStream(s, Thread.currentThread.getContextClassLoader) + new JavaDeserializationStream(s, Utils.getContextOrSparkClassLoader) } def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = { @@ -94,8 +96,14 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends Serialize } /** + * :: DeveloperApi :: * A Spark serializer that uses Java's built-in serialization. + * + * Note that this serializer is not guaranteed to be wire-compatible across different versions of + * Spark. It is intended to be used to serialize/de-serialize data within a single + * Spark application. */ +@DeveloperApi class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable { private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 10000) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 926e71573be32..d1e8c3ef63622 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -33,6 +33,10 @@ import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock} /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. + * + * Note that this serializer is not guaranteed to be wire-compatible across different versions of + * Spark. It is intended to be used to serialize/de-serialize data within a single + * Spark application. */ class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serializer diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index 099143494b851..f2c8f9b6218d6 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -17,25 +17,30 @@ package org.apache.spark.serializer -import java.io.{EOFException, InputStream, OutputStream} +import java.io.{ByteArrayOutputStream, EOFException, InputStream, OutputStream} import java.nio.ByteBuffer -import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream - -import org.apache.spark.util.{ByteBufferInputStream, NextIterator} import org.apache.spark.SparkEnv +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.util.{ByteBufferInputStream, NextIterator} /** + * :: DeveloperApi :: * A serializer. Because some serialization libraries are not thread safe, this class is used to * create [[org.apache.spark.serializer.SerializerInstance]] objects that do the actual * serialization and are guaranteed to only be called from one thread at a time. * * Implementations of this trait should implement: + * * 1. a zero-arg constructor or a constructor that accepts a [[org.apache.spark.SparkConf]] * as parameter. If both constructors are defined, the latter takes precedence. * * 2. Java serialization interface. + * + * Note that serializers are not required to be wire-compatible across different versions of Spark. + * They are intended to be used to serialize/de-serialize data within a single Spark application. */ +@DeveloperApi trait Serializer { def newInstance(): SerializerInstance } @@ -49,8 +54,10 @@ object Serializer { /** + * :: DeveloperApi :: * An instance of a serializer, for use by one thread at a time. */ +@DeveloperApi trait SerializerInstance { def serialize[T](t: T): ByteBuffer @@ -64,10 +71,9 @@ trait SerializerInstance { def serializeMany[T](iterator: Iterator[T]): ByteBuffer = { // Default implementation uses serializeStream - val stream = new FastByteArrayOutputStream() + val stream = new ByteArrayOutputStream() serializeStream(stream).writeAll(iterator) - val buffer = ByteBuffer.allocate(stream.position.toInt) - buffer.put(stream.array, 0, stream.position.toInt) + val buffer = ByteBuffer.wrap(stream.toByteArray) buffer.flip() buffer } @@ -81,8 +87,10 @@ trait SerializerInstance { /** + * :: DeveloperApi :: * A stream for writing serialized objects. */ +@DeveloperApi trait SerializationStream { def writeObject[T](t: T): SerializationStream def flush(): Unit @@ -98,8 +106,10 @@ trait SerializationStream { /** + * :: DeveloperApi :: * A stream for reading serialized objects. */ +@DeveloperApi trait DeserializationStream { def readObject[T](): T def close(): Unit diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index 2fbbda5b76c74..ace9cd51c96b7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -240,7 +240,7 @@ object BlockFetcherIterator { override def numRemoteBlocks: Int = numRemote override def fetchWaitTime: Long = _fetchWaitTime override def remoteBytesRead: Long = _remoteBytesRead - + // Implementing the Iterator methods with an iterator that reads fetched blocks off the queue // as they arrive. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 301d784b350a3..cffea28fbf794 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -34,7 +34,7 @@ private[spark] sealed abstract class BlockId { def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None def isRDD = isInstanceOf[RDDBlockId] def isShuffle = isInstanceOf[ShuffleBlockId] - def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId] + def isBroadcast = isInstanceOf[BroadcastBlockId] override def toString = name override def hashCode = name.hashCode @@ -48,18 +48,13 @@ private[spark] case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockI def name = "rdd_" + rddId + "_" + splitIndex } -private[spark] -case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { +private[spark] case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) + extends BlockId { def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId } -private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId { - def name = "broadcast_" + broadcastId -} - -private[spark] -case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId { - def name = broadcastId.name + "_" + hType +private[spark] case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId { + def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) } private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId { @@ -83,8 +78,7 @@ private[spark] case class TestBlockId(id: String) extends BlockId { private[spark] object BlockId { val RDD = "rdd_([0-9]+)_([0-9]+)".r val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r - val BROADCAST = "broadcast_([0-9]+)".r - val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r + val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r val TEST = "test_(.*)".r @@ -95,10 +89,8 @@ private[spark] object BlockId { RDDBlockId(rddId.toInt, splitIndex.toInt) case SHUFFLE(shuffleId, mapId, reduceId) => ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) - case BROADCAST(broadcastId) => - BroadcastBlockId(broadcastId.toLong) - case BROADCAST_HELPER(broadcastId, hType) => - BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType) + case BROADCAST(broadcastId, field) => + BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_")) case TASKRESULT(taskId) => TaskResultBlockId(taskId.toLong) case STREAM(streamId, uniqueId) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 71584b6eb102a..f14017051fa07 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.{File, InputStream, OutputStream} +import java.io.{File, InputStream, OutputStream, BufferedOutputStream, ByteArrayOutputStream} import java.nio.{ByteBuffer, MappedByteBuffer} import scala.collection.mutable.{ArrayBuffer, HashMap} @@ -26,20 +26,19 @@ import scala.concurrent.duration._ import scala.util.Random import akka.actor.{ActorSystem, Cancellable, Props} -import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream} import sun.nio.ch.DirectBuffer -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException} +import org.apache.spark.{Logging, MapOutputTracker, SecurityManager, SparkConf, SparkEnv, SparkException} import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer import org.apache.spark.util._ -sealed trait Values +private[spark] sealed trait Values -case class ByteBufferValues(buffer: ByteBuffer) extends Values -case class IteratorValues(iterator: Iterator[Any]) extends Values -case class ArrayBufferValues(buffer: ArrayBuffer[Any]) extends Values +private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends Values +private[spark] case class IteratorValues(iterator: Iterator[Any]) extends Values +private[spark] case class ArrayBufferValues(buffer: ArrayBuffer[Any]) extends Values private[spark] class BlockManager( executorId: String, @@ -48,7 +47,8 @@ private[spark] class BlockManager( val defaultSerializer: Serializer, maxMemory: Long, val conf: SparkConf, - securityManager: SecurityManager) + securityManager: SecurityManager, + mapOutputTracker: MapOutputTracker) extends Logging { val shuffleBlockManager = new ShuffleBlockManager(this) @@ -57,8 +57,19 @@ private[spark] class BlockManager( private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] - private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) + private[storage] val memoryStore = new MemoryStore(this, maxMemory) private[storage] val diskStore = new DiskStore(this, diskBlockManager) + var tachyonInitialized = false + private[storage] lazy val tachyonStore: TachyonStore = { + val storeDir = conf.get("spark.tachyonStore.baseDir", "/tmp_spark_tachyon") + val appFolderName = conf.get("spark.tachyonStore.folderName") + val tachyonStorePath = s"${storeDir}/${appFolderName}/${this.executorId}" + val tachyonMaster = conf.get("spark.tachyonStore.url", "tachyon://localhost:19998") + val tachyonBlockManager = new TachyonBlockManager( + shuffleBlockManager, tachyonStorePath, tachyonMaster) + tachyonInitialized = true + new TachyonStore(this, tachyonBlockManager) + } // If we use Netty for shuffle, start a new Netty-based shuffle sender service. private val nettyPort: Int = { @@ -89,7 +100,7 @@ private[spark] class BlockManager( val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf) - val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)), + val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this, mapOutputTracker)), name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) // Pending re-registration action being executed asynchronously or null if none @@ -128,9 +139,10 @@ private[spark] class BlockManager( master: BlockManagerMaster, serializer: Serializer, conf: SparkConf, - securityManager: SecurityManager) = { + securityManager: SecurityManager, + mapOutputTracker: MapOutputTracker) = { this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), - conf, securityManager) + conf, securityManager, mapOutputTracker) } /** @@ -208,9 +220,26 @@ private[spark] class BlockManager( } /** - * Get storage level of local block. If no info exists for the block, then returns null. + * Get the BlockStatus for the block identified by the given ID, if it exists. + * NOTE: This is mainly for testing, and it doesn't fetch information from Tachyon. + */ + def getStatus(blockId: BlockId): Option[BlockStatus] = { + blockInfo.get(blockId).map { info => + val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L + val diskSize = if (diskStore.contains(blockId)) diskStore.getSize(blockId) else 0L + // Assume that block is not in Tachyon + BlockStatus(info.level, memSize, diskSize, 0L) + } + } + + /** + * Get the ids of existing blocks that match the given filter. Note that this will + * query the blocks stored in the disk block manager (that the block manager + * may not know of). */ - def getLevel(blockId: BlockId): StorageLevel = blockInfo.get(blockId).map(_.level).orNull + def getMatchingBlockIds(filter: BlockId => Boolean): Seq[BlockId] = { + (blockInfo.keys ++ diskBlockManager.getAllBlocks()).filter(filter).toSeq + } /** * Tell the master about the current storage status of a block. This will send a block update @@ -248,8 +277,10 @@ private[spark] class BlockManager( if (info.tellMaster) { val storageLevel = status.storageLevel val inMemSize = Math.max(status.memSize, droppedMemorySize) + val inTachyonSize = status.tachyonSize val onDiskSize = status.diskSize - master.updateBlockInfo(blockManagerId, blockId, storageLevel, inMemSize, onDiskSize) + master.updateBlockInfo( + blockManagerId, blockId, storageLevel, inMemSize, onDiskSize, inTachyonSize) } else true } @@ -259,22 +290,24 @@ private[spark] class BlockManager( * and the updated in-memory and on-disk sizes. */ private def getCurrentBlockStatus(blockId: BlockId, info: BlockInfo): BlockStatus = { - val (newLevel, inMemSize, onDiskSize) = info.synchronized { + val (newLevel, inMemSize, onDiskSize, inTachyonSize) = info.synchronized { info.level match { case null => - (StorageLevel.NONE, 0L, 0L) + (StorageLevel.NONE, 0L, 0L, 0L) case level => val inMem = level.useMemory && memoryStore.contains(blockId) + val inTachyon = level.useOffHeap && tachyonStore.contains(blockId) val onDisk = level.useDisk && diskStore.contains(blockId) val deserialized = if (inMem) level.deserialized else false - val replication = if (inMem || onDisk) level.replication else 1 - val storageLevel = StorageLevel(onDisk, inMem, deserialized, replication) + val replication = if (inMem || inTachyon || onDisk) level.replication else 1 + val storageLevel = StorageLevel(onDisk, inMem, inTachyon, deserialized, replication) val memSize = if (inMem) memoryStore.getSize(blockId) else 0L + val tachyonSize = if (inTachyon) tachyonStore.getSize(blockId) else 0L val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L - (storageLevel, memSize, diskSize) + (storageLevel, memSize, diskSize, tachyonSize) } } - BlockStatus(newLevel, inMemSize, onDiskSize) + BlockStatus(newLevel, inMemSize, onDiskSize, inTachyonSize) } /** @@ -355,6 +388,24 @@ private[spark] class BlockManager( } } + // Look for the block in Tachyon + if (level.useOffHeap) { + logDebug("Getting block " + blockId + " from tachyon") + if (tachyonStore.contains(blockId)) { + tachyonStore.getBytes(blockId) match { + case Some(bytes) => { + if (!asValues) { + return Some(bytes) + } else { + return Some(dataDeserialize(blockId, bytes)) + } + } + case None => + logDebug("Block " + blockId + " not found in tachyon") + } + } + } + // Look for block on disk, potentially storing it back into memory if required: if (level.useDisk) { logDebug("Getting block " + blockId + " from disk") @@ -494,9 +545,8 @@ private[spark] class BlockManager( /** * A short circuited method to get a block writer that can write data directly to disk. - * The Block will be appended to the File specified by filename. - * This is currently used for writing shuffle files out. Callers should handle error - * cases. + * The Block will be appended to the File specified by filename. This is currently used for + * writing shuffle files out. Callers should handle error cases. */ def getDiskWriter( blockId: BlockId, @@ -620,6 +670,23 @@ private[spark] class BlockManager( } // Keep track of which blocks are dropped from memory res.droppedBlocks.foreach { block => updatedBlocks += block } + } else if (level.useOffHeap) { + // Save to Tachyon. + val res = data match { + case IteratorValues(iterator) => + tachyonStore.putValues(blockId, iterator, level, false) + case ArrayBufferValues(array) => + tachyonStore.putValues(blockId, array, level, false) + case ByteBufferValues(bytes) => { + bytes.rewind(); + tachyonStore.putBytes(blockId, bytes, level) + } + } + size = res.size + res.data match { + case Right(newBytes) => bytesAfterPut = newBytes + case _ => + } } else { // Save directly to disk. // Don't get back the bytes unless we replicate them. @@ -644,8 +711,8 @@ private[spark] class BlockManager( val putBlockStatus = getCurrentBlockStatus(blockId, putBlockInfo) if (putBlockStatus.storageLevel != StorageLevel.NONE) { - // Now that the block is in either the memory or disk store, let other threads read it, - // and tell the master about it. + // Now that the block is in either the memory, tachyon, or disk store, + // let other threads read it, and tell the master about it. marked = true putBlockInfo.markReady(size) if (tellMaster) { @@ -707,7 +774,8 @@ private[spark] class BlockManager( */ var cachedPeers: Seq[BlockManagerId] = null private def replicate(blockId: BlockId, data: ByteBuffer, level: StorageLevel) { - val tLevel = StorageLevel(level.useDisk, level.useMemory, level.deserialized, 1) + val tLevel = StorageLevel( + level.useDisk, level.useMemory, level.useOffHeap, level.deserialized, 1) if (cachedPeers == null) { cachedPeers = master.getPeers(blockManagerId, level.replication - 1) } @@ -814,11 +882,22 @@ private[spark] class BlockManager( * @return The number of blocks removed. */ def removeRdd(rddId: Int): Int = { - // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps - // from RDD.id to blocks. + // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks. logInfo("Removing RDD " + rddId) val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) - blocksToRemove.foreach(blockId => removeBlock(blockId, tellMaster = false)) + blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) } + blocksToRemove.size + } + + /** + * Remove all blocks belonging to the given broadcast. + */ + def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = { + logInfo("Removing broadcast " + broadcastId) + val blocksToRemove = blockInfo.keys.collect { + case bid @ BroadcastBlockId(`broadcastId`, _) => bid + } + blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster) } blocksToRemove.size } @@ -832,9 +911,10 @@ private[spark] class BlockManager( // Removals are idempotent in disk store and memory store. At worst, we get a warning. val removedFromMemory = memoryStore.remove(blockId) val removedFromDisk = diskStore.remove(blockId) - if (!removedFromMemory && !removedFromDisk) { + val removedFromTachyon = if (tachyonInitialized) tachyonStore.remove(blockId) else false + if (!removedFromMemory && !removedFromDisk && !removedFromTachyon) { logWarning("Block " + blockId + " could not be removed as it was not found in either " + - "the disk or memory store") + "the disk, memory, or tachyon store") } blockInfo.remove(blockId) if (tellMaster && info.tellMaster) { @@ -858,10 +938,10 @@ private[spark] class BlockManager( } private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)) { - val iterator = blockInfo.internalMap.entrySet().iterator() + val iterator = blockInfo.getEntrySet.iterator while (iterator.hasNext) { val entry = iterator.next() - val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2) + val (id, info, time) = (entry.getKey, entry.getValue.value, entry.getValue.timestamp) if (time < cleanupTime && shouldDrop(id)) { info.synchronized { val level = info.level @@ -871,6 +951,9 @@ private[spark] class BlockManager( if (level.useDisk) { diskStore.remove(id) } + if (level.useOffHeap) { + tachyonStore.remove(id) + } iterator.remove() logInfo("Dropped block " + id) } @@ -882,7 +965,7 @@ private[spark] class BlockManager( def shouldCompress(blockId: BlockId): Boolean = blockId match { case ShuffleBlockId(_, _, _) => compressShuffle - case BroadcastBlockId(_) => compressBroadcast + case BroadcastBlockId(_, _) => compressBroadcast case RDDBlockId(_, _) => compressRdds case TempBlockId(_) => compressShuffleSpill case _ => false @@ -908,7 +991,7 @@ private[spark] class BlockManager( outputStream: OutputStream, values: Iterator[Any], serializer: Serializer = defaultSerializer) { - val byteStream = new FastBufferedOutputStream(outputStream) + val byteStream = new BufferedOutputStream(outputStream) val ser = serializer.newInstance() ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close() } @@ -918,10 +1001,9 @@ private[spark] class BlockManager( blockId: BlockId, values: Iterator[Any], serializer: Serializer = defaultSerializer): ByteBuffer = { - val byteStream = new FastByteArrayOutputStream(4096) + val byteStream = new ByteArrayOutputStream(4096) dataSerializeStream(blockId, byteStream, values, serializer) - byteStream.trim() - ByteBuffer.wrap(byteStream.array) + ByteBuffer.wrap(byteStream.toByteArray) } /** @@ -946,6 +1028,9 @@ private[spark] class BlockManager( blockInfo.clear() memoryStore.clear() diskStore.clear() + if (tachyonInitialized) { + tachyonStore.clear() + } metadataCleaner.cancel() broadcastCleaner.cancel() logInfo("BlockManager stopped") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index ed6937851b836..7897fade2df2b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -63,9 +63,10 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log blockId: BlockId, storageLevel: StorageLevel, memSize: Long, - diskSize: Long): Boolean = { + diskSize: Long, + tachyonSize: Long): Boolean = { val res = askDriverWithReply[Boolean]( - UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize)) + UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize, tachyonSize)) logInfo("Updated info of block " + blockId) res } @@ -80,6 +81,14 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log askDriverWithReply[Seq[Seq[BlockManagerId]]](GetLocationsMultipleBlockIds(blockIds)) } + /** + * Check if block manager master has a block. Note that this can be used to check for only + * those blocks that are reported to block manager master. + */ + def contains(blockId: BlockId) = { + !getLocations(blockId).isEmpty + } + /** Get ids of other nodes in the cluster from the driver */ def getPeers(blockManagerId: BlockManagerId, numPeers: Int): Seq[BlockManagerId] = { val result = askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId, numPeers)) @@ -98,12 +107,10 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log askDriverWithReply(RemoveBlock(blockId)) } - /** - * Remove all blocks belonging to the given RDD. - */ + /** Remove all blocks belonging to the given RDD. */ def removeRdd(rddId: Int, blocking: Boolean) { val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) - future onFailure { + future.onFailure { case e: Throwable => logError("Failed to remove RDD " + rddId, e) } if (blocking) { @@ -111,6 +118,31 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log } } + /** Remove all blocks belonging to the given shuffle. */ + def removeShuffle(shuffleId: Int, blocking: Boolean) { + val future = askDriverWithReply[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) + future.onFailure { + case e: Throwable => logError("Failed to remove shuffle " + shuffleId, e) + } + if (blocking) { + Await.result(future, timeout) + } + } + + /** Remove all blocks belonging to the given broadcast. */ + def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean, blocking: Boolean) { + val future = askDriverWithReply[Future[Seq[Int]]]( + RemoveBroadcast(broadcastId, removeFromMaster)) + future.onFailure { + case e: Throwable => + logError("Failed to remove broadcast " + broadcastId + + " with removeFromMaster = " + removeFromMaster, e) + } + if (blocking) { + Await.result(future, timeout) + } + } + /** * Return the memory status for each block manager, in the form of a map from * the block manager's id to two long values. The first value is the maximum @@ -125,6 +157,51 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log askDriverWithReply[Array[StorageStatus]](GetStorageStatus) } + /** + * Return the block's status on all block managers, if any. NOTE: This is a + * potentially expensive operation and should only be used for testing. + * + * If askSlaves is true, this invokes the master to query each block manager for the most + * updated block statuses. This is useful when the master is not informed of the given block + * by all block managers. + */ + def getBlockStatus( + blockId: BlockId, + askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = { + val msg = GetBlockStatus(blockId, askSlaves) + /* + * To avoid potential deadlocks, the use of Futures is necessary, because the master actor + * should not block on waiting for a block manager, which can in turn be waiting for the + * master actor for a response to a prior message. + */ + val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) + val (blockManagerIds, futures) = response.unzip + val result = Await.result(Future.sequence(futures), timeout) + if (result == null) { + throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId) + } + val blockStatus = result.asInstanceOf[Iterable[Option[BlockStatus]]] + blockManagerIds.zip(blockStatus).flatMap { case (blockManagerId, status) => + status.map { s => (blockManagerId, s) } + }.toMap + } + + /** + * Return a list of ids of existing blocks such that the ids match the given filter. NOTE: This + * is a potentially expensive operation and should only be used for testing. + * + * If askSlaves is true, this invokes the master to query each block manager for the most + * updated block statuses. This is useful when the master is not informed of the given block + * by all block managers. + */ + def getMatchingBlockIds( + filter: BlockId => Boolean, + askSlaves: Boolean): Seq[BlockId] = { + val msg = GetMatchingBlockIds(filter, askSlaves) + val future = askDriverWithReply[Future[Seq[BlockId]]](msg) + Await.result(future, timeout) + } + /** Stop the driver actor, called only on the Spark driver node */ def stop() { if (driverActor != null) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index ff2652b640272..c57b6e8391b13 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -73,10 +73,11 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus register(blockManagerId, maxMemSize, slaveActor) sender ! true - case UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => + case UpdateBlockInfo( + blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) => // TODO: Ideally we want to handle all the message replies in receive instead of in the // individual private methods. - updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) + updateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) case GetLocations(blockId) => sender ! getLocations(blockId) @@ -93,9 +94,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case GetStorageStatus => sender ! storageStatus + case GetBlockStatus(blockId, askSlaves) => + sender ! blockStatus(blockId, askSlaves) + + case GetMatchingBlockIds(filter, askSlaves) => + sender ! getMatchingBlockIds(filter, askSlaves) + case RemoveRdd(rddId) => sender ! removeRdd(rddId) + case RemoveShuffle(shuffleId) => + sender ! removeShuffle(shuffleId) + + case RemoveBroadcast(broadcastId, removeFromDriver) => + sender ! removeBroadcast(broadcastId, removeFromDriver) + case RemoveBlock(blockId) => removeBlockFromWorkers(blockId) sender ! true @@ -139,9 +152,41 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // The dispatcher is used as an implicit argument into the Future sequence construction. import context.dispatcher val removeMsg = RemoveRdd(rddId) - Future.sequence(blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] - }.toSeq) + Future.sequence( + blockManagerInfo.values.map { bm => + bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] + }.toSeq + ) + } + + private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = { + // Nothing to do in the BlockManagerMasterActor data structures + import context.dispatcher + val removeMsg = RemoveShuffle(shuffleId) + Future.sequence( + blockManagerInfo.values.map { bm => + bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Boolean] + }.toSeq + ) + } + + /** + * Delegate RemoveBroadcast messages to each BlockManager because the master may not notified + * of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only removed + * from the executors, but not from the driver. + */ + private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = { + // TODO: Consolidate usages of + import context.dispatcher + val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver) + val requiredBlockManagers = blockManagerInfo.values.filter { info => + removeFromDriver || info.blockManagerId.executorId != "" + } + Future.sequence( + requiredBlockManagers.map { bm => + bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] + }.toSeq + ) } private def removeBlockManager(blockManagerId: BlockManagerId) { @@ -224,6 +269,61 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus }.toArray } + /** + * Return the block's status for all block managers, if any. NOTE: This is a + * potentially expensive operation and should only be used for testing. + * + * If askSlaves is true, the master queries each block manager for the most updated block + * statuses. This is useful when the master is not informed of the given block by all block + * managers. + */ + private def blockStatus( + blockId: BlockId, + askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = { + import context.dispatcher + val getBlockStatus = GetBlockStatus(blockId) + /* + * Rather than blocking on the block status query, master actor should simply return + * Futures to avoid potential deadlocks. This can arise if there exists a block manager + * that is also waiting for this master actor's response to a previous message. + */ + blockManagerInfo.values.map { info => + val blockStatusFuture = + if (askSlaves) { + info.slaveActor.ask(getBlockStatus)(akkaTimeout).mapTo[Option[BlockStatus]] + } else { + Future { info.getStatus(blockId) } + } + (info.blockManagerId, blockStatusFuture) + }.toMap + } + + /** + * Return the ids of blocks present in all the block managers that match the given filter. + * NOTE: This is a potentially expensive operation and should only be used for testing. + * + * If askSlaves is true, the master queries each block manager for the most updated block + * statuses. This is useful when the master is not informed of the given block by all block + * managers. + */ + private def getMatchingBlockIds( + filter: BlockId => Boolean, + askSlaves: Boolean): Future[Seq[BlockId]] = { + import context.dispatcher + val getMatchingBlockIds = GetMatchingBlockIds(filter) + Future.sequence( + blockManagerInfo.values.map { info => + val future = + if (askSlaves) { + info.slaveActor.ask(getMatchingBlockIds)(akkaTimeout).mapTo[Seq[BlockId]] + } else { + Future { info.blocks.keys.filter(filter).toSeq } + } + future + } + ).map(_.flatten.toSeq) + } + private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { @@ -246,7 +346,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus blockId: BlockId, storageLevel: StorageLevel, memSize: Long, - diskSize: Long) { + diskSize: Long, + tachyonSize: Long) { if (!blockManagerInfo.contains(blockManagerId)) { if (blockManagerId.executorId == "" && !isLocal) { @@ -265,7 +366,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus return } - blockManagerInfo(blockManagerId).updateBlockInfo(blockId, storageLevel, memSize, diskSize) + blockManagerInfo(blockManagerId).updateBlockInfo( + blockId, storageLevel, memSize, diskSize, tachyonSize) var locations: mutable.HashSet[BlockManagerId] = null if (blockLocations.containsKey(blockId)) { @@ -309,8 +411,11 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus } } - -private[spark] case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) +private[spark] case class BlockStatus( + storageLevel: StorageLevel, + memSize: Long, + diskSize: Long, + tachyonSize: Long) private[spark] class BlockManagerInfo( val blockManagerId: BlockManagerId, @@ -328,6 +433,8 @@ private[spark] class BlockManagerInfo( logInfo("Registering block manager %s with %s RAM".format( blockManagerId.hostPort, Utils.bytesToString(maxMem))) + def getStatus(blockId: BlockId) = Option(_blocks.get(blockId)) + def updateLastSeenMs() { _lastSeenMs = System.currentTimeMillis() } @@ -336,7 +443,8 @@ private[spark] class BlockManagerInfo( blockId: BlockId, storageLevel: StorageLevel, memSize: Long, - diskSize: Long) { + diskSize: Long, + tachyonSize: Long) { updateLastSeenMs() @@ -350,23 +458,29 @@ private[spark] class BlockManagerInfo( } if (storageLevel.isValid) { - /* isValid means it is either stored in-memory or on-disk. + /* isValid means it is either stored in-memory, on-disk or on-Tachyon. * But the memSize here indicates the data size in or dropped from memory, + * tachyonSize here indicates the data size in or dropped from Tachyon, * and the diskSize here indicates the data size in or dropped to disk. * They can be both larger than 0, when a block is dropped from memory to disk. * Therefore, a safe way to set BlockStatus is to set its info in accurate modes. */ if (storageLevel.useMemory) { - _blocks.put(blockId, BlockStatus(storageLevel, memSize, 0)) + _blocks.put(blockId, BlockStatus(storageLevel, memSize, 0, 0)) _remainingMem -= memSize logInfo("Added %s in memory on %s (size: %s, free: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(memSize), Utils.bytesToString(_remainingMem))) } if (storageLevel.useDisk) { - _blocks.put(blockId, BlockStatus(storageLevel, 0, diskSize)) + _blocks.put(blockId, BlockStatus(storageLevel, 0, diskSize, 0)) logInfo("Added %s on disk on %s (size: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) } + if (storageLevel.useOffHeap) { + _blocks.put(blockId, BlockStatus(storageLevel, 0, 0, tachyonSize)) + logInfo("Added %s on tachyon on %s (size: %s)".format( + blockId, blockManagerId.hostPort, Utils.bytesToString(tachyonSize))) + } } else if (_blocks.containsKey(blockId)) { // If isValid is not true, drop the block. val blockStatus: BlockStatus = _blocks.get(blockId) @@ -381,6 +495,10 @@ private[spark] class BlockManagerInfo( logInfo("Removed %s on %s on disk (size: %s)".format( blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize))) } + if (blockStatus.storageLevel.useOffHeap) { + logInfo("Removed %s on %s on tachyon (size: %s)".format( + blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.tachyonSize))) + } } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index bbb9529b5a0ca..2b53bf33b5fba 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -34,6 +34,13 @@ private[storage] object BlockManagerMessages { // Remove all blocks belonging to a specific RDD. case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave + // Remove all blocks belonging to a specific shuffle. + case class RemoveShuffle(shuffleId: Int) extends ToBlockManagerSlave + + // Remove all blocks belonging to a specific broadcast. + case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) + extends ToBlockManagerSlave + ////////////////////////////////////////////////////////////////////////////////// // Messages from slaves to the master. @@ -53,11 +60,12 @@ private[storage] object BlockManagerMessages { var blockId: BlockId, var storageLevel: StorageLevel, var memSize: Long, - var diskSize: Long) + var diskSize: Long, + var tachyonSize: Long) extends ToBlockManagerMaster with Externalizable { - def this() = this(null, null, null, 0, 0) // For deserialization only + def this() = this(null, null, null, 0, 0, 0) // For deserialization only override def writeExternal(out: ObjectOutput) { blockManagerId.writeExternal(out) @@ -65,6 +73,7 @@ private[storage] object BlockManagerMessages { storageLevel.writeExternal(out) out.writeLong(memSize) out.writeLong(diskSize) + out.writeLong(tachyonSize) } override def readExternal(in: ObjectInput) { @@ -73,21 +82,25 @@ private[storage] object BlockManagerMessages { storageLevel = StorageLevel(in) memSize = in.readLong() diskSize = in.readLong() + tachyonSize = in.readLong() } } object UpdateBlockInfo { - def apply(blockManagerId: BlockManagerId, + def apply( + blockManagerId: BlockManagerId, blockId: BlockId, storageLevel: StorageLevel, memSize: Long, - diskSize: Long): UpdateBlockInfo = { - new UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize) + diskSize: Long, + tachyonSize: Long): UpdateBlockInfo = { + new UpdateBlockInfo(blockManagerId, blockId, storageLevel, memSize, diskSize, tachyonSize) } // For pattern-matching - def unapply(h: UpdateBlockInfo): Option[(BlockManagerId, BlockId, StorageLevel, Long, Long)] = { - Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize)) + def unapply(h: UpdateBlockInfo) + : Option[(BlockManagerId, BlockId, StorageLevel, Long, Long, Long)] = { + Some((h.blockManagerId, h.blockId, h.storageLevel, h.memSize, h.diskSize, h.tachyonSize)) } } @@ -103,7 +116,13 @@ private[storage] object BlockManagerMessages { case object GetMemoryStatus extends ToBlockManagerMaster - case object ExpireDeadHosts extends ToBlockManagerMaster - case object GetStorageStatus extends ToBlockManagerMaster + + case class GetBlockStatus(blockId: BlockId, askSlaves: Boolean = true) + extends ToBlockManagerMaster + + case class GetMatchingBlockIds(filter: BlockId => Boolean, askSlaves: Boolean = true) + extends ToBlockManagerMaster + + case object ExpireDeadHosts extends ToBlockManagerMaster } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index bcfb82d3c7336..6d4db064dff58 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -17,8 +17,11 @@ package org.apache.spark.storage -import akka.actor.Actor +import scala.concurrent.Future +import akka.actor.{ActorRef, Actor} + +import org.apache.spark.{Logging, MapOutputTracker} import org.apache.spark.storage.BlockManagerMessages._ /** @@ -26,14 +29,59 @@ import org.apache.spark.storage.BlockManagerMessages._ * this is used to remove blocks from the slave's BlockManager. */ private[storage] -class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor { - override def receive = { +class BlockManagerSlaveActor( + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker) + extends Actor with Logging { + + import context.dispatcher + // Operations that involve removing blocks may be slow and should be done asynchronously + override def receive = { case RemoveBlock(blockId) => - blockManager.removeBlock(blockId) + doAsync[Boolean]("removing block " + blockId, sender) { + blockManager.removeBlock(blockId) + true + } case RemoveRdd(rddId) => - val numBlocksRemoved = blockManager.removeRdd(rddId) - sender ! numBlocksRemoved + doAsync[Int]("removing RDD " + rddId, sender) { + blockManager.removeRdd(rddId) + } + + case RemoveShuffle(shuffleId) => + doAsync[Boolean]("removing shuffle " + shuffleId, sender) { + if (mapOutputTracker != null) { + mapOutputTracker.unregisterShuffle(shuffleId) + } + blockManager.shuffleBlockManager.removeShuffle(shuffleId) + } + + case RemoveBroadcast(broadcastId, tellMaster) => + doAsync[Int]("removing broadcast " + broadcastId, sender) { + blockManager.removeBroadcast(broadcastId, tellMaster) + } + + case GetBlockStatus(blockId, _) => + sender ! blockManager.getStatus(blockId) + + case GetMatchingBlockIds(filter, _) => + sender ! blockManager.getMatchingBlockIds(filter) + } + + private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) { + val future = Future { + logDebug(actionMessage) + body + } + future.onSuccess { case response => + logDebug("Done " + actionMessage + ", response is " + response) + responseActor ! response + logDebug("Sent response: " + response + " to " + responseActor) + } + future.onFailure { case t: Throwable => + logError("Error in " + actionMessage, t) + responseActor ! null.asInstanceOf[T] + } } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala index 7168ae18c2615..337b45b727dec 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockMessage.scala @@ -37,7 +37,7 @@ private[spark] class BlockMessage() { private var id: BlockId = null private var data: ByteBuffer = null private var level: StorageLevel = null - + def set(getBlock: GetBlock) { typ = BlockMessage.TYPE_GET_BLOCK id = getBlock.id @@ -75,13 +75,13 @@ private[spark] class BlockMessage() { idBuilder += buffer.getChar() } id = BlockId(idBuilder.toString) - + if (typ == BlockMessage.TYPE_PUT_BLOCK) { val booleanInt = buffer.getInt() val replication = buffer.getInt() level = StorageLevel(booleanInt, replication) - + val dataLength = buffer.getInt() data = ByteBuffer.allocate(dataLength) if (dataLength != buffer.remaining) { @@ -108,12 +108,12 @@ private[spark] class BlockMessage() { buffer.clear() set(buffer) } - + def getType: Int = typ def getId: BlockId = id def getData: ByteBuffer = data def getLevel: StorageLevel = level - + def toBufferMessage: BufferMessage = { val startTime = System.currentTimeMillis val buffers = new ArrayBuffer[ByteBuffer]() @@ -127,7 +127,7 @@ private[spark] class BlockMessage() { buffer = ByteBuffer.allocate(8).putInt(level.toInt).putInt(level.replication) buffer.flip() buffers += buffer - + buffer = ByteBuffer.allocate(4).putInt(data.remaining) buffer.flip() buffers += buffer @@ -140,7 +140,7 @@ private[spark] class BlockMessage() { buffers += data } - + /* println() println("BlockMessage: ") @@ -158,7 +158,7 @@ private[spark] class BlockMessage() { } override def toString: String = { - "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level + + "BlockMessage [type = " + typ + ", id = " + id + ", level = " + level + ", data = " + (if (data != null) data.remaining.toString else "null") + "]" } } @@ -168,7 +168,7 @@ private[spark] object BlockMessage { val TYPE_GET_BLOCK: Int = 1 val TYPE_GOT_BLOCK: Int = 2 val TYPE_PUT_BLOCK: Int = 3 - + def fromBufferMessage(bufferMessage: BufferMessage): BlockMessage = { val newBlockMessage = new BlockMessage() newBlockMessage.set(bufferMessage) @@ -192,7 +192,7 @@ private[spark] object BlockMessage { newBlockMessage.set(gotBlock) newBlockMessage } - + def fromPutBlock(putBlock: PutBlock): BlockMessage = { val newBlockMessage = new BlockMessage() newBlockMessage.set(putBlock) @@ -206,7 +206,7 @@ private[spark] object BlockMessage { val bMsg = B.toBufferMessage val C = new BlockMessage() C.set(bMsg) - + println(B.getId + " " + B.getLevel) println(C.getId + " " + C.getLevel) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala index dc62b1efaa7d4..973d85c0a9b3a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockMessageArray.scala @@ -27,16 +27,16 @@ import org.apache.spark.network._ private[spark] class BlockMessageArray(var blockMessages: Seq[BlockMessage]) extends Seq[BlockMessage] with Logging { - + def this(bm: BlockMessage) = this(Array(bm)) def this() = this(null.asInstanceOf[Seq[BlockMessage]]) - def apply(i: Int) = blockMessages(i) + def apply(i: Int) = blockMessages(i) def iterator = blockMessages.iterator - def length = blockMessages.length + def length = blockMessages.length def set(bufferMessage: BufferMessage) { val startTime = System.currentTimeMillis @@ -62,15 +62,15 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) logDebug("Trying to convert buffer " + newBuffer + " to block message") val newBlockMessage = BlockMessage.fromByteBuffer(newBuffer) logDebug("Created " + newBlockMessage) - newBlockMessages += newBlockMessage + newBlockMessages += newBlockMessage buffer.position(buffer.position() + size) } val finishTime = System.currentTimeMillis logDebug("Converted block message array from buffer message in " + (finishTime - startTime) / 1000.0 + " s") - this.blockMessages = newBlockMessages + this.blockMessages = newBlockMessages } - + def toBufferMessage: BufferMessage = { val buffers = new ArrayBuffer[ByteBuffer]() @@ -83,7 +83,7 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) buffers ++= bufferMessage.buffers logDebug("Added " + bufferMessage) }) - + logDebug("Buffer list:") buffers.foreach((x: ByteBuffer) => logDebug("" + x)) /* @@ -103,13 +103,13 @@ class BlockMessageArray(var blockMessages: Seq[BlockMessage]) } private[spark] object BlockMessageArray { - + def fromBufferMessage(bufferMessage: BufferMessage): BlockMessageArray = { val newBlockMessageArray = new BlockMessageArray() newBlockMessageArray.set(bufferMessage) newBlockMessageArray } - + def main(args: Array[String]) { val blockMessages = (0 until 10).map { i => @@ -124,10 +124,10 @@ private[spark] object BlockMessageArray { } val blockMessageArray = new BlockMessageArray(blockMessages) println("Block message array created") - + val bufferMessage = blockMessageArray.toBufferMessage println("Converted to buffer message") - + val totalSize = bufferMessage.size val newBuffer = ByteBuffer.allocate(totalSize) newBuffer.clear() @@ -137,7 +137,7 @@ private[spark] object BlockMessageArray { buffer.rewind() }) newBuffer.flip - val newBufferMessage = Message.createBufferMessage(newBuffer) + val newBufferMessage = Message.createBufferMessage(newBuffer) println("Copied to new buffer message, size = " + newBufferMessage.size) val newBlockMessageArray = BlockMessageArray.fromBufferMessage(newBufferMessage) @@ -147,7 +147,7 @@ private[spark] object BlockMessageArray { case BlockMessage.TYPE_PUT_BLOCK => { val pB = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel) println(pB) - } + } case BlockMessage.TYPE_GET_BLOCK => { val gB = new GetBlock(blockMessage.getId) println(gB) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 696b930a26b9e..a2687e6be4e34 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -17,11 +17,9 @@ package org.apache.spark.storage -import java.io.{FileOutputStream, File, OutputStream} +import java.io.{BufferedOutputStream, FileOutputStream, File, OutputStream} import java.nio.channels.FileChannel -import it.unimi.dsi.fastutil.io.FastBufferedOutputStream - import org.apache.spark.Logging import org.apache.spark.serializer.{SerializationStream, Serializer} @@ -119,7 +117,7 @@ private[spark] class DiskBlockObjectWriter( ts = new TimeTrackingOutputStream(fos) channel = fos.getChannel() lastValidPosition = initialPosition - bs = compressStream(new FastBufferedOutputStream(ts, bufferSize)) + bs = compressStream(new BufferedOutputStream(ts, bufferSize)) objOut = serializer.newInstance().serializeStream(bs) initialized = true this diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index f3e1c38744d78..7a24c8f57f43b 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -90,6 +90,20 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD def getFile(blockId: BlockId): File = getFile(blockId.name) + /** Check if disk block manager has a block. */ + def containsBlock(blockId: BlockId): Boolean = { + getBlockLocation(blockId).file.exists() + } + + /** List all the blocks currently stored on disk by the disk manager. */ + def getAllBlocks(): Seq[BlockId] = { + // Get all the files inside the array of array of directories + subDirs.flatten.filter(_ != null).flatMap { dir => + val files = dir.list() + if (files != null) files else Seq.empty + }.map(BlockId.apply) + } + /** Produces a unique block id and File suitable for intermediate results. */ def createTempBlock(): (TempBlockId, File) = { var blockId = new TempBlockId(UUID.randomUUID()) diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala index 555486830a769..132502b75f8cd 100644 --- a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala +++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala @@ -23,6 +23,6 @@ import java.io.File * References a particular segment of a file (potentially the entire file), * based off an offset and a length. */ -private[spark] class FileSegment(val file: File, val offset: Long, val length : Long) { +private[spark] class FileSegment(val file: File, val offset: Long, val length: Long) { override def toString = "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length) } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index bb07c8cb134cc..4cd4cdbd9909d 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -169,23 +169,43 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { throw new IllegalStateException("Failed to find shuffle block: " + id) } + /** Remove all the blocks / files and metadata related to a particular shuffle. */ + def removeShuffle(shuffleId: ShuffleId): Boolean = { + // Do not change the ordering of this, if shuffleStates should be removed only + // after the corresponding shuffle blocks have been removed + val cleaned = removeShuffleBlocks(shuffleId) + shuffleStates.remove(shuffleId) + cleaned + } + + /** Remove all the blocks / files related to a particular shuffle. */ + private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = { + shuffleStates.get(shuffleId) match { + case Some(state) => + if (consolidateShuffleFiles) { + for (fileGroup <- state.allFileGroups; file <- fileGroup.files) { + file.delete() + } + } else { + for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) { + val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) + blockManager.diskBlockManager.getFile(blockId).delete() + } + } + logInfo("Deleted all files for shuffle " + shuffleId) + true + case None => + logInfo("Could not find files for shuffle " + shuffleId + " for deleting") + false + } + } + private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = { "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId) } private def cleanup(cleanupTime: Long) { - shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => { - if (consolidateShuffleFiles) { - for (fileGroup <- state.allFileGroups; file <- fileGroup.files) { - file.delete() - } - } else { - for (mapId <- state.completedMapTasks; reduceId <- 0 until state.numBuckets) { - val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) - blockManager.diskBlockManager.getFile(blockId).delete() - } - } - }) + shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId)) } } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index 4212a539dab4b..95e71de2d3f1d 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -21,8 +21,9 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} /** * Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory, - * whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory - * in a serialized format, and whether to replicate the RDD partitions on multiple nodes. + * or Tachyon, whether to drop the RDD to disk if it falls out of memory or Tachyon , whether to + * keep the data in memory in a serialized format, and whether to replicate the RDD partitions on + * multiple nodes. * The [[org.apache.spark.storage.StorageLevel$]] singleton object contains some static constants * for commonly useful storage levels. To create your own storage level object, use the * factory method of the singleton object (`StorageLevel(...)`). @@ -30,45 +31,58 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} class StorageLevel private( private var useDisk_ : Boolean, private var useMemory_ : Boolean, + private var useOffHeap_ : Boolean, private var deserialized_ : Boolean, private var replication_ : Int = 1) extends Externalizable { // TODO: Also add fields for caching priority, dataset ID, and flushing. private def this(flags: Int, replication: Int) { - this((flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication) + this((flags & 8) != 0, (flags & 4) != 0, (flags & 2) != 0, (flags & 1) != 0, replication) } - def this() = this(false, true, false) // For deserialization + def this() = this(false, true, false, false) // For deserialization def useDisk = useDisk_ def useMemory = useMemory_ + def useOffHeap = useOffHeap_ def deserialized = deserialized_ def replication = replication_ assert(replication < 40, "Replication restricted to be less than 40 for calculating hashcodes") + if (useOffHeap) { + require(useDisk == false, "Off-heap storage level does not support using disk") + require(useMemory == false, "Off-heap storage level does not support using heap memory") + require(deserialized == false, "Off-heap storage level does not support deserialized storage") + require(replication == 1, "Off-heap storage level does not support multiple replication") + } + override def clone(): StorageLevel = new StorageLevel( - this.useDisk, this.useMemory, this.deserialized, this.replication) + this.useDisk, this.useMemory, this.useOffHeap, this.deserialized, this.replication) override def equals(other: Any): Boolean = other match { case s: StorageLevel => s.useDisk == useDisk && s.useMemory == useMemory && + s.useOffHeap == useOffHeap && s.deserialized == deserialized && s.replication == replication case _ => false } - def isValid = ((useMemory || useDisk) && (replication > 0)) + def isValid = ((useMemory || useDisk || useOffHeap) && (replication > 0)) def toInt: Int = { var ret = 0 if (useDisk_) { - ret |= 4 + ret |= 8 } if (useMemory_) { + ret |= 4 + } + if (useOffHeap_) { ret |= 2 } if (deserialized_) { @@ -84,8 +98,9 @@ class StorageLevel private( override def readExternal(in: ObjectInput) { val flags = in.readByte() - useDisk_ = (flags & 4) != 0 - useMemory_ = (flags & 2) != 0 + useDisk_ = (flags & 8) != 0 + useMemory_ = (flags & 4) != 0 + useOffHeap_ = (flags & 2) != 0 deserialized_ = (flags & 1) != 0 replication_ = in.readByte() } @@ -93,14 +108,15 @@ class StorageLevel private( @throws(classOf[IOException]) private def readResolve(): Object = StorageLevel.getCachedStorageLevel(this) - override def toString: String = - "StorageLevel(%b, %b, %b, %d)".format(useDisk, useMemory, deserialized, replication) + override def toString: String = "StorageLevel(%b, %b, %b, %b, %d)".format( + useDisk, useMemory, useOffHeap, deserialized, replication) override def hashCode(): Int = toInt * 41 + replication def description : String = { var result = "" result += (if (useDisk) "Disk " else "") result += (if (useMemory) "Memory " else "") + result += (if (useOffHeap) "Tachyon " else "") result += (if (deserialized) "Deserialized " else "Serialized ") result += "%sx Replicated".format(replication) result @@ -113,22 +129,28 @@ class StorageLevel private( * new storage levels. */ object StorageLevel { - val NONE = new StorageLevel(false, false, false) - val DISK_ONLY = new StorageLevel(true, false, false) - val DISK_ONLY_2 = new StorageLevel(true, false, false, 2) - val MEMORY_ONLY = new StorageLevel(false, true, true) - val MEMORY_ONLY_2 = new StorageLevel(false, true, true, 2) - val MEMORY_ONLY_SER = new StorageLevel(false, true, false) - val MEMORY_ONLY_SER_2 = new StorageLevel(false, true, false, 2) - val MEMORY_AND_DISK = new StorageLevel(true, true, true) - val MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2) - val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false) - val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2) + val NONE = new StorageLevel(false, false, false, false) + val DISK_ONLY = new StorageLevel(true, false, false, false) + val DISK_ONLY_2 = new StorageLevel(true, false, false, false, 2) + val MEMORY_ONLY = new StorageLevel(false, true, false, true) + val MEMORY_ONLY_2 = new StorageLevel(false, true, false, true, 2) + val MEMORY_ONLY_SER = new StorageLevel(false, true, false, false) + val MEMORY_ONLY_SER_2 = new StorageLevel(false, true, false, false, 2) + val MEMORY_AND_DISK = new StorageLevel(true, true, false, true) + val MEMORY_AND_DISK_2 = new StorageLevel(true, true, false, true, 2) + val MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, false) + val MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, false, 2) + val OFF_HEAP = new StorageLevel(false, false, true, false) + + /** Create a new StorageLevel object without setting useOffHeap */ + def apply(useDisk: Boolean, useMemory: Boolean, useOffHeap: Boolean, + deserialized: Boolean, replication: Int) = getCachedStorageLevel( + new StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication)) /** Create a new StorageLevel object */ - def apply(useDisk: Boolean, useMemory: Boolean, deserialized: Boolean, - replication: Int = 1): StorageLevel = - getCachedStorageLevel(new StorageLevel(useDisk, useMemory, deserialized, replication)) + def apply(useDisk: Boolean, useMemory: Boolean, + deserialized: Boolean, replication: Int = 1) = getCachedStorageLevel( + new StorageLevel(useDisk, useMemory, false, deserialized, replication)) /** Create a new StorageLevel object from its integer representation */ def apply(flags: Int, replication: Int): StorageLevel = diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index 26565f56ad858..7a174959037be 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -44,7 +44,7 @@ private[spark] class StorageStatusListener extends SparkListener { storageStatusList.foreach { storageStatus => val unpersistedBlocksIds = storageStatus.rddBlocks.keys.filter(_.rddId == unpersistedRDDId) unpersistedBlocksIds.foreach { blockId => - storageStatus.blocks(blockId) = BlockStatus(StorageLevel.NONE, 0L, 0L) + storageStatus.blocks(blockId) = BlockStatus(StorageLevel.NONE, 0L, 0L, 0L) } } } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 6153dfe0b7e13..7ed371326855d 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -21,6 +21,7 @@ import scala.collection.Map import scala.collection.mutable import org.apache.spark.SparkContext +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.Utils private[spark] @@ -41,24 +42,29 @@ class StorageStatus( def memRemaining : Long = maxMem - memUsed() - def rddBlocks = blocks.flatMap { - case (rdd: RDDBlockId, status) => Some(rdd, status) - case _ => None - } + def rddBlocks = blocks.collect { case (rdd: RDDBlockId, status) => (rdd, status) } } +@DeveloperApi private[spark] -class RDDInfo(val id: Int, val name: String, val numPartitions: Int, val storageLevel: StorageLevel) +class RDDInfo( + val id: Int, + val name: String, + val numPartitions: Int, + val storageLevel: StorageLevel) extends Ordered[RDDInfo] { var numCachedPartitions = 0 var memSize = 0L var diskSize = 0L + var tachyonSize = 0L override def toString = { - ("RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s; " + - "DiskSize: %s").format(name, id, storageLevel.toString, numCachedPartitions, - numPartitions, Utils.bytesToString(memSize), Utils.bytesToString(diskSize)) + import Utils.bytesToString + ("RDD \"%s\" (%d) Storage: %s; CachedPartitions: %d; TotalPartitions: %d; MemorySize: %s;" + + "TachyonSize: %s; DiskSize: %s").format( + name, id, storageLevel.toString, numCachedPartitions, numPartitions, + bytesToString(memSize), bytesToString(tachyonSize), bytesToString(diskSize)) } override def compare(that: RDDInfo) = { @@ -105,14 +111,17 @@ object StorageUtils { val rddInfoMap = rddInfos.map { info => (info.id, info) }.toMap val rddStorageInfos = blockStatusMap.flatMap { case (rddId, blocks) => - // Add up memory and disk sizes - val persistedBlocks = blocks.filter { status => status.memSize + status.diskSize > 0 } + // Add up memory, disk and Tachyon sizes + val persistedBlocks = + blocks.filter { status => status.memSize + status.diskSize + status.tachyonSize > 0 } val memSize = persistedBlocks.map(_.memSize).reduceOption(_ + _).getOrElse(0L) val diskSize = persistedBlocks.map(_.diskSize).reduceOption(_ + _).getOrElse(0L) + val tachyonSize = persistedBlocks.map(_.tachyonSize).reduceOption(_ + _).getOrElse(0L) rddInfoMap.get(rddId).map { rddInfo => rddInfo.numCachedPartitions = persistedBlocks.length rddInfo.memSize = memSize rddInfo.diskSize = diskSize + rddInfo.tachyonSize = tachyonSize rddInfo } }.toArray diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala new file mode 100644 index 0000000000000..b0b9674856568 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.text.SimpleDateFormat +import java.util.{Date, Random} + +import tachyon.client.TachyonFS +import tachyon.client.TachyonFile + +import org.apache.spark.Logging +import org.apache.spark.executor.ExecutorExitCode +import org.apache.spark.network.netty.ShuffleSender +import org.apache.spark.util.Utils + + +/** + * Creates and maintains the logical mapping between logical blocks and tachyon fs locations. By + * default, one block is mapped to one file with a name given by its BlockId. + * + * @param rootDirs The directories to use for storing block files. Data will be hashed among these. + */ +private[spark] class TachyonBlockManager( + shuffleManager: ShuffleBlockManager, + rootDirs: String, + val master: String) + extends Logging { + + val client = if (master != null && master != "") TachyonFS.get(master) else null + + if (client == null) { + logError("Failed to connect to the Tachyon as the master address is not configured") + System.exit(ExecutorExitCode.TACHYON_STORE_FAILED_TO_INITIALIZE) + } + + private val MAX_DIR_CREATION_ATTEMPTS = 10 + private val subDirsPerTachyonDir = + shuffleManager.conf.get("spark.tachyonStore.subDirectories", "64").toInt + + // Create one Tachyon directory for each path mentioned in spark.tachyonStore.folderName; + // then, inside this directory, create multiple subdirectories that we will hash files into, + // in order to avoid having really large inodes at the top level in Tachyon. + private val tachyonDirs: Array[TachyonFile] = createTachyonDirs() + private val subDirs = Array.fill(tachyonDirs.length)(new Array[TachyonFile](subDirsPerTachyonDir)) + + addShutdownHook() + + def removeFile(file: TachyonFile): Boolean = { + client.delete(file.getPath(), false) + } + + def fileExists(file: TachyonFile): Boolean = { + client.exist(file.getPath()) + } + + def getFile(filename: String): TachyonFile = { + // Figure out which tachyon directory it hashes to, and which subdirectory in that + val hash = Utils.nonNegativeHash(filename) + val dirId = hash % tachyonDirs.length + val subDirId = (hash / tachyonDirs.length) % subDirsPerTachyonDir + + // Create the subdirectory if it doesn't already exist + var subDir = subDirs(dirId)(subDirId) + if (subDir == null) { + subDir = subDirs(dirId).synchronized { + val old = subDirs(dirId)(subDirId) + if (old != null) { + old + } else { + val path = tachyonDirs(dirId) + "/" + "%02x".format(subDirId) + client.mkdir(path) + val newDir = client.getFile(path) + subDirs(dirId)(subDirId) = newDir + newDir + } + } + } + val filePath = subDir + "/" + filename + if(!client.exist(filePath)) { + client.createFile(filePath) + } + val file = client.getFile(filePath) + file + } + + def getFile(blockId: BlockId): TachyonFile = getFile(blockId.name) + + // TODO: Some of the logic here could be consolidated/de-duplicated with that in the DiskStore. + private def createTachyonDirs(): Array[TachyonFile] = { + logDebug("Creating tachyon directories at root dirs '" + rootDirs + "'") + val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss") + rootDirs.split(",").map { rootDir => + var foundLocalDir = false + var tachyonDir: TachyonFile = null + var tachyonDirId: String = null + var tries = 0 + val rand = new Random() + while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) { + tries += 1 + try { + tachyonDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) + val path = rootDir + "/" + "spark-tachyon-" + tachyonDirId + if (!client.exist(path)) { + foundLocalDir = client.mkdir(path) + tachyonDir = client.getFile(path) + } + } catch { + case e: Exception => + logWarning("Attempt " + tries + " to create tachyon dir " + tachyonDir + " failed", e) + } + } + if (!foundLocalDir) { + logError("Failed " + MAX_DIR_CREATION_ATTEMPTS + " attempts to create tachyon dir in " + + rootDir) + System.exit(ExecutorExitCode.TACHYON_STORE_FAILED_TO_CREATE_DIR) + } + logInfo("Created tachyon directory at " + tachyonDir) + tachyonDir + } + } + + private def addShutdownHook() { + tachyonDirs.foreach(tachyonDir => Utils.registerShutdownDeleteDir(tachyonDir)) + Runtime.getRuntime.addShutdownHook(new Thread("delete Spark tachyon dirs") { + override def run() { + logDebug("Shutdown hook called") + tachyonDirs.foreach { tachyonDir => + try { + if (!Utils.hasRootAsShutdownDeleteDir(tachyonDir)) { + Utils.deleteRecursively(tachyonDir, client) + } + } catch { + case t: Throwable => + logError("Exception while deleting tachyon spark dir: " + tachyonDir, t) + } + } + } + }) + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala b/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala new file mode 100644 index 0000000000000..b86abbda1d3e7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/TachyonFileSegment.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import tachyon.client.TachyonFile + +/** + * References a particular segment of a file (potentially the entire file), based off an offset and + * a length. + */ +private[spark] class TachyonFileSegment(val file: TachyonFile, val offset: Long, val length: Long) { + override def toString = "(name=%s, offset=%d, length=%d)".format(file.getPath(), offset, length) +} diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala new file mode 100644 index 0000000000000..c37e76f893605 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.storage + +import java.io.IOException +import java.nio.ByteBuffer + +import scala.collection.mutable.ArrayBuffer + +import tachyon.client.{WriteType, ReadType} + +import org.apache.spark.Logging +import org.apache.spark.util.Utils +import org.apache.spark.serializer.Serializer + + +private class Entry(val size: Long) + + +/** + * Stores BlockManager blocks on Tachyon. + */ +private class TachyonStore( + blockManager: BlockManager, + tachyonManager: TachyonBlockManager) + extends BlockStore(blockManager: BlockManager) with Logging { + + logInfo("TachyonStore started") + + override def getSize(blockId: BlockId): Long = { + tachyonManager.getFile(blockId.name).length + } + + override def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel): PutResult = { + putToTachyonStore(blockId, bytes, true) + } + + override def putValues( + blockId: BlockId, + values: ArrayBuffer[Any], + level: StorageLevel, + returnValues: Boolean): PutResult = { + return putValues(blockId, values.toIterator, level, returnValues) + } + + override def putValues( + blockId: BlockId, + values: Iterator[Any], + level: StorageLevel, + returnValues: Boolean): PutResult = { + logDebug("Attempting to write values for block " + blockId) + val _bytes = blockManager.dataSerialize(blockId, values) + putToTachyonStore(blockId, _bytes, returnValues) + } + + private def putToTachyonStore( + blockId: BlockId, + bytes: ByteBuffer, + returnValues: Boolean): PutResult = { + // So that we do not modify the input offsets ! + // duplicate does not copy buffer, so inexpensive + val byteBuffer = bytes.duplicate() + byteBuffer.rewind() + logDebug("Attempting to put block " + blockId + " into Tachyon") + val startTime = System.currentTimeMillis + val file = tachyonManager.getFile(blockId) + val os = file.getOutStream(WriteType.TRY_CACHE) + os.write(byteBuffer.array()) + os.close() + val finishTime = System.currentTimeMillis + logDebug("Block %s stored as %s file in Tachyon in %d ms".format( + blockId, Utils.bytesToString(byteBuffer.limit), (finishTime - startTime))) + + if (returnValues) { + PutResult(bytes.limit(), Right(bytes.duplicate())) + } else { + PutResult(bytes.limit(), null) + } + } + + override def remove(blockId: BlockId): Boolean = { + val file = tachyonManager.getFile(blockId) + if (tachyonManager.fileExists(file)) { + tachyonManager.removeFile(file) + } else { + false + } + } + + override def getValues(blockId: BlockId): Option[Iterator[Any]] = { + getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer)) + } + + + override def getBytes(blockId: BlockId): Option[ByteBuffer] = { + val file = tachyonManager.getFile(blockId) + if (file == null || file.getLocationHosts().size == 0) { + return None + } + val is = file.getInStream(ReadType.CACHE) + var buffer: ByteBuffer = null + try { + if (is != null) { + val size = file.length + val bs = new Array[Byte](size.asInstanceOf[Int]) + val fetchSize = is.read(bs, 0, size.asInstanceOf[Int]) + buffer = ByteBuffer.wrap(bs) + if (fetchSize != size) { + logWarning("Failed to fetch the block " + blockId + " from Tachyon : Size " + size + + " is not equal to fetched size " + fetchSize) + return None + } + } + } catch { + case ioe: IOException => { + logWarning("Failed to fetch the block " + blockId + " from Tachyon", ioe) + return None + } + } + Some(buffer) + } + + override def contains(blockId: BlockId): Boolean = { + val file = tachyonManager.getFile(blockId) + tachyonManager.fileExists(file) + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index 226ed2a132b00..a107c5182b3be 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -22,7 +22,7 @@ import java.util.concurrent.ArrayBlockingQueue import akka.actor._ import util.Random -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer @@ -48,7 +48,7 @@ private[spark] object ThreadingTest { val block = (1 to blockSize).map(_ => Random.nextInt()) val level = randomLevel() val startTime = System.currentTimeMillis() - manager.put(blockId, block.iterator, level, true) + manager.put(blockId, block.iterator, level, tellMaster = true) println("Pushed block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms") queue.add((blockId, block)) } @@ -101,7 +101,7 @@ private[spark] object ThreadingTest { conf) val blockManager = new BlockManager( "", actorSystem, blockManagerMaster, serializer, 1024 * 1024, conf, - new SecurityManager(conf)) + new SecurityManager(conf), new MapOutputTrackerMaster(conf)) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) val consumers = producers.map(p => new ConsumerThread(blockManager, p.queue)) producers.foreach(_.start) diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index e1a1f209c9282..3ae147a36c8a4 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -33,6 +33,7 @@ import org.json4s.JValue import org.json4s.jackson.JsonMethods.{pretty, render} import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.util.Utils /** * Utilities for launching a web server using Jetty's HTTP Server class @@ -104,10 +105,12 @@ private[spark] object JettyUtils extends Logging { def createRedirectHandler( srcPath: String, destPath: String, + beforeRedirect: HttpServletRequest => Unit = x => (), basePath: String = ""): ServletContextHandler = { val prefixedDestPath = attachPrefix(basePath, destPath) val servlet = new HttpServlet { override def doGet(request: HttpServletRequest, response: HttpServletResponse) { + beforeRedirect(request) // Make sure we don't end up with "//" in the middle val newUrl = new URL(new URL(request.getRequestURL.toString), prefixedDestPath).toString response.sendRedirect(newUrl) @@ -119,9 +122,10 @@ private[spark] object JettyUtils extends Logging { /** Create a handler for serving files from a static directory */ def createStaticHandler(resourceBase: String, path: String): ServletContextHandler = { val contextHandler = new ServletContextHandler + contextHandler.setInitParameter("org.eclipse.jetty.servlet.Default.gzip", "false") val staticHandler = new DefaultServlet val holder = new ServletHolder(staticHandler) - Option(getClass.getClassLoader.getResource(resourceBase)) match { + Option(Utils.getSparkClassLoader.getResource(resourceBase)) match { case Some(res) => holder.setInitParameter("resourceBase", res.toString) case None => @@ -136,7 +140,7 @@ private[spark] object JettyUtils extends Logging { private def addFilters(handlers: Seq[ServletContextHandler], conf: SparkConf) { val filters: Array[String] = conf.get("spark.ui.filters", "").split(',').map(_.trim()) filters.foreach { - case filter : String => + case filter : String => if (!filter.isEmpty) { logInfo("Adding filter: " + filter) val holder : FilterHolder = new FilterHolder() @@ -151,7 +155,7 @@ private[spark] object JettyUtils extends Logging { if (parts.length == 2) holder.setInitParameter(parts(0), parts(1)) } } - val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.ERROR, + val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.ERROR, DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.REQUEST) handlers.foreach { case(handler) => handler.addFilter(holder, "/*", enumDispatcher) } } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index ef1ad872c8ef7..2fef1a635427c 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -17,106 +17,86 @@ package org.apache.spark.ui -import org.eclipse.jetty.servlet.ServletContextHandler - -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv} +import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.scheduler._ import org.apache.spark.storage.StorageStatusListener import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.ui.env.EnvironmentUI -import org.apache.spark.ui.exec.ExecutorsUI -import org.apache.spark.ui.jobs.JobProgressUI -import org.apache.spark.ui.storage.BlockManagerUI -import org.apache.spark.util.Utils +import org.apache.spark.ui.env.EnvironmentTab +import org.apache.spark.ui.exec.ExecutorsTab +import org.apache.spark.ui.jobs.JobProgressTab +import org.apache.spark.ui.storage.StorageTab -/** Top level user interface for Spark */ +/** + * Top level user interface for a Spark application. + */ private[spark] class SparkUI( val sc: SparkContext, - conf: SparkConf, + val conf: SparkConf, + val securityManager: SecurityManager, val listenerBus: SparkListenerBus, - val appName: String, + var appName: String, val basePath: String = "") - extends Logging { + extends WebUI(securityManager, SparkUI.getUIPort(conf), conf, basePath) + with Logging { - def this(sc: SparkContext) = this(sc, sc.conf, sc.listenerBus, sc.appName) + def this(sc: SparkContext) = this(sc, sc.conf, sc.env.securityManager, sc.listenerBus, sc.appName) def this(conf: SparkConf, listenerBus: SparkListenerBus, appName: String, basePath: String) = - this(null, conf, listenerBus, appName, basePath) + this(null, conf, new SecurityManager(conf), listenerBus, appName, basePath) // If SparkContext is not provided, assume the associated application is not live val live = sc != null - val securityManager = if (live) sc.env.securityManager else new SecurityManager(conf) - - private val bindHost = Utils.localHostName() - private val publicHost = Option(System.getenv("SPARK_PUBLIC_DNS")).getOrElse(bindHost) - private val port = conf.get("spark.ui.port", SparkUI.DEFAULT_PORT).toInt - private var serverInfo: Option[ServerInfo] = None - - private val storage = new BlockManagerUI(this) - private val jobs = new JobProgressUI(this) - private val env = new EnvironmentUI(this) - private val exec = new ExecutorsUI(this) - - val handlers: Seq[ServletContextHandler] = { - val metricsServletHandlers = if (live) { - SparkEnv.get.metricsSystem.getServletHandlers - } else { - Array[ServletContextHandler]() - } - storage.getHandlers ++ - jobs.getHandlers ++ - env.getHandlers ++ - exec.getHandlers ++ - metricsServletHandlers ++ - Seq[ServletContextHandler] ( - createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static"), - createRedirectHandler("/", "/stages", basePath) - ) - } - // Maintain executor storage status through Spark events val storageStatusListener = new StorageStatusListener - /** Bind the HTTP server which backs this web interface */ - def bind() { - try { - serverInfo = Some(startJettyServer(bindHost, port, handlers, sc.conf)) - logInfo("Started Spark Web UI at http://%s:%d".format(publicHost, boundPort)) - } catch { - case e: Exception => - logError("Failed to create Spark JettyUtils", e) - System.exit(1) + initialize() + + /** Initialize all components of the server. */ + def initialize() { + listenerBus.addListener(storageStatusListener) + val jobProgressTab = new JobProgressTab(this) + attachTab(jobProgressTab) + attachTab(new StorageTab(this)) + attachTab(new EnvironmentTab(this)) + attachTab(new ExecutorsTab(this)) + attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static")) + attachHandler(createRedirectHandler("/", "/stages", basePath = basePath)) + attachHandler( + createRedirectHandler("/stages/stage/kill", "/stages", jobProgressTab.handleKillRequest)) + if (live) { + sc.env.metricsSystem.getServletHandlers.foreach(attachHandler) } } - def boundPort: Int = serverInfo.map(_.boundPort).getOrElse(-1) - - /** Initialize all components of the server */ - def start() { - storage.start() - jobs.start() - env.start() - exec.start() + /** Set the app name for this UI. */ + def setAppName(name: String) { + appName = name + } - // Storage status listener must receive events first, as other listeners depend on its state - listenerBus.addListener(storageStatusListener) - listenerBus.addListener(storage.listener) - listenerBus.addListener(jobs.listener) - listenerBus.addListener(env.listener) - listenerBus.addListener(exec.listener) + /** Register the given listener with the listener bus. */ + def registerListener(listener: SparkListener) { + listenerBus.addListener(listener) } - def stop() { - assert(serverInfo.isDefined, "Attempted to stop a SparkUI that was not bound to a server!") - serverInfo.get.server.stop() - logInfo("Stopped Spark Web UI at %s".format(appUIAddress)) + /** Stop the server behind this web interface. Only valid after bind(). */ + override def stop() { + super.stop() + logInfo("Stopped Spark web UI at %s".format(appUIAddress)) } - private[spark] def appUIAddress = "http://" + publicHost + ":" + boundPort + /** + * Return the application UI host:port. This does not include the scheme (http://). + */ + private[spark] def appUIHostPort = publicHostName + ":" + boundPort + private[spark] def appUIAddress = s"http://$appUIHostPort" } private[spark] object SparkUI { - val DEFAULT_PORT = "4040" + val DEFAULT_PORT = 4040 val STATIC_RESOURCE_DIR = "org/apache/spark/ui/static" + + def getUIPort(conf: SparkConf): Int = { + conf.getInt("spark.ui.port", SparkUI.DEFAULT_PORT) + } } diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index a487924effbff..6a2d652528d8a 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -17,16 +17,115 @@ package org.apache.spark.ui +import java.text.SimpleDateFormat +import java.util.{Locale, Date} + import scala.xml.Node +import org.apache.spark.Logging /** Utility functions for generating XML pages with spark content. */ -private[spark] object UIUtils { +private[spark] object UIUtils extends Logging { + + // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. + private val dateFormat = new ThreadLocal[SimpleDateFormat]() { + override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") + } + + def formatDate(date: Date): String = dateFormat.get.format(date) + + def formatDate(timestamp: Long): String = dateFormat.get.format(new Date(timestamp)) + + def formatDuration(milliseconds: Long): String = { + val seconds = milliseconds.toDouble / 1000 + if (seconds < 60) { + return "%.0f s".format(seconds) + } + val minutes = seconds / 60 + if (minutes < 10) { + return "%.1f min".format(minutes) + } else if (minutes < 60) { + return "%.0f min".format(minutes) + } + val hours = minutes / 60 + "%.1f h".format(hours) + } + + /** Generate a verbose human-readable string representing a duration such as "5 second 35 ms" */ + def formatDurationVerbose(ms: Long): String = { + try { + val second = 1000L + val minute = 60 * second + val hour = 60 * minute + val day = 24 * hour + val week = 7 * day + val year = 365 * day + + def toString(num: Long, unit: String): String = { + if (num == 0) { + "" + } else if (num == 1) { + s"$num $unit" + } else { + s"$num ${unit}s" + } + } + + val millisecondsString = if (ms >= second && ms % second == 0) "" else s"${ms % second} ms" + val secondString = toString((ms % minute) / second, "second") + val minuteString = toString((ms % hour) / minute, "minute") + val hourString = toString((ms % day) / hour, "hour") + val dayString = toString((ms % week) / day, "day") + val weekString = toString((ms % year) / week, "week") + val yearString = toString(ms / year, "year") - import Page._ + Seq( + second -> millisecondsString, + minute -> s"$secondString $millisecondsString", + hour -> s"$minuteString $secondString", + day -> s"$hourString $minuteString $secondString", + week -> s"$dayString $hourString $minuteString", + year -> s"$weekString $dayString $hourString" + ).foreach { case (durationLimit, durationString) => + if (ms < durationLimit) { + // if time is less than the limit (upto year) + return durationString + } + } + // if time is more than a year + return s"$yearString $weekString $dayString" + } catch { + case e: Exception => + logError("Error converting time to string", e) + // if there is some error, return blank string + return "" + } + } + + /** Generate a human-readable string representing a number (e.g. 100 K) */ + def formatNumber(records: Double): String = { + val trillion = 1e12 + val billion = 1e9 + val million = 1e6 + val thousand = 1e3 + + val (value, unit) = { + if (records >= 2*trillion) { + (records / trillion, " T") + } else if (records >= 2*billion) { + (records / billion, " B") + } else if (records >= 2*million) { + (records / million, " M") + } else if (records >= 2*thousand) { + (records / thousand, " K") + } else { + (records, "") + } + } + "%.1f%s".formatLocal(Locale.US, value, unit) + } // Yarn has to go through a proxy so the base uri is provided and has to be on all links - private[spark] val uiRoot : String = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")). - getOrElse("") + val uiRoot : String = Option(System.getenv("APPLICATION_WEB_PROXY_BASE")).getOrElse("") def prependBaseUri(basePath: String = "", resource: String = "") = uiRoot + basePath + resource @@ -36,26 +135,14 @@ private[spark] object UIUtils { basePath: String, appName: String, title: String, - page: Page.Value) : Seq[Node] = { - val jobs = page match { - case Stages => -
  • Stages
  • - case _ =>
  • Stages
  • - } - val storage = page match { - case Storage => -
  • Storage
  • - case _ =>
  • Storage
  • - } - val environment = page match { - case Environment => -
  • Environment
  • - case _ =>
  • Environment
  • - } - val executors = page match { - case Executors => -
  • Executors
  • - case _ =>
  • Executors
  • + tabs: Seq[WebUITab], + activeTab: WebUITab, + refreshInterval: Option[Int] = None): Seq[Node] = { + + val header = tabs.map { tab => +
  • + {tab.name} +
  • } @@ -74,16 +161,10 @@ private[spark] object UIUtils { - + -
    @@ -129,21 +210,36 @@ private[spark] object UIUtils { /** Returns an HTML table constructed by generating a row for each object in a sequence. */ def listingTable[T]( headers: Seq[String], - makeRow: T => Seq[Node], - rows: Seq[T], + generateDataRow: T => Seq[Node], + data: Seq[T], fixedWidth: Boolean = false): Seq[Node] = { - val colWidth = 100.toDouble / headers.size - val colWidthAttr = if (fixedWidth) colWidth + "%" else "" var tableClass = "table table-bordered table-striped table-condensed sortable" if (fixedWidth) { tableClass += " table-fixed" } - + val colWidth = 100.toDouble / headers.size + val colWidthAttr = if (fixedWidth) colWidth + "%" else "" + val headerRow: Seq[Node] = { + // if none of the headers have "\n" in them + if (headers.forall(!_.contains("\n"))) { + // represent header as simple text + headers.map(h => {h}) + } else { + // represent header text as list while respecting "\n" + headers.map { case h => + +
      + { h.split("\n").map { case t =>
    • {t}
    • } } +
    + + } + } + } - {headers.map(h => )} + {headerRow} - {rows.map(r => makeRow(r))} + {data.map(r => generateDataRow(r))}
    {h}
    } diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index a7b872f3445a4..b08f308fda1dd 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -17,34 +17,134 @@ package org.apache.spark.ui -import java.text.SimpleDateFormat -import java.util.Date +import javax.servlet.http.HttpServletRequest + +import scala.collection.mutable.ArrayBuffer +import scala.xml.Node + +import org.eclipse.jetty.servlet.ServletContextHandler +import org.json4s.JsonAST.{JNothing, JValue} + +import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.util.Utils /** - * Utilities used throughout the web UI. + * The top level component of the UI hierarchy that contains the server. + * + * Each WebUI represents a collection of tabs, each of which in turn represents a collection of + * pages. The use of tabs is optional, however; a WebUI may choose to include pages directly. */ -private[spark] object WebUI { - // SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use. - private val dateFormat = new ThreadLocal[SimpleDateFormat]() { - override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") +private[spark] abstract class WebUI( + securityManager: SecurityManager, + port: Int, + conf: SparkConf, + basePath: String = "") + extends Logging { + + protected val tabs = ArrayBuffer[WebUITab]() + protected val handlers = ArrayBuffer[ServletContextHandler]() + protected var serverInfo: Option[ServerInfo] = None + protected val localHostName = Utils.localHostName() + protected val publicHostName = Option(System.getenv("SPARK_PUBLIC_DNS")).getOrElse(localHostName) + private val className = Utils.getFormattedClassName(this) + + def getTabs: Seq[WebUITab] = tabs.toSeq + def getHandlers: Seq[ServletContextHandler] = handlers.toSeq + + /** Attach a tab to this UI, along with all of its attached pages. */ + def attachTab(tab: WebUITab) { + tab.pages.foreach(attachPage) + tabs += tab } - def formatDate(date: Date): String = dateFormat.get.format(date) + /** Attach a page to this UI. */ + def attachPage(page: WebUIPage) { + val pagePath = "/" + page.prefix + attachHandler(createServletHandler(pagePath, + (request: HttpServletRequest) => page.render(request), securityManager, basePath)) + attachHandler(createServletHandler(pagePath.stripSuffix("/") + "/json", + (request: HttpServletRequest) => page.renderJson(request), securityManager, basePath)) + } - def formatDate(timestamp: Long): String = dateFormat.get.format(new Date(timestamp)) + /** Attach a handler to this UI. */ + def attachHandler(handler: ServletContextHandler) { + handlers += handler + serverInfo.foreach { info => + info.rootHandler.addHandler(handler) + if (!handler.isStarted) { + handler.start() + } + } + } - def formatDuration(milliseconds: Long): String = { - val seconds = milliseconds.toDouble / 1000 - if (seconds < 60) { - return "%.0f s".format(seconds) + /** Detach a handler from this UI. */ + def detachHandler(handler: ServletContextHandler) { + handlers -= handler + serverInfo.foreach { info => + info.rootHandler.removeHandler(handler) + if (handler.isStarted) { + handler.stop() + } } - val minutes = seconds / 60 - if (minutes < 10) { - return "%.1f min".format(minutes) - } else if (minutes < 60) { - return "%.0f min".format(minutes) + } + + /** Initialize all components of the server. */ + def initialize() + + /** Bind to the HTTP server behind this web interface. */ + def bind() { + assert(!serverInfo.isDefined, "Attempted to bind %s more than once!".format(className)) + try { + serverInfo = Some(startJettyServer("0.0.0.0", port, handlers, conf)) + logInfo("Started %s at http://%s:%d".format(className, publicHostName, boundPort)) + } catch { + case e: Exception => + logError("Failed to bind %s".format(className), e) + System.exit(1) } - val hours = minutes / 60 - return "%.1f h".format(hours) } + + /** Return the actual port to which this server is bound. Only valid after bind(). */ + def boundPort: Int = serverInfo.map(_.boundPort).getOrElse(-1) + + /** Stop the server behind this web interface. Only valid after bind(). */ + def stop() { + assert(serverInfo.isDefined, + "Attempted to stop %s before binding to a server!".format(className)) + serverInfo.get.server.stop() + } +} + + +/** + * A tab that represents a collection of pages. + * The prefix is appended to the parent address to form a full path, and must not contain slashes. + */ +private[spark] abstract class WebUITab(parent: WebUI, val prefix: String) { + val pages = ArrayBuffer[WebUIPage]() + val name = prefix.capitalize + + /** Attach a page to this tab. This prepends the page's prefix with the tab's own prefix. */ + def attachPage(page: WebUIPage) { + page.prefix = (prefix + "/" + page.prefix).stripSuffix("/") + pages += page + } + + /** Get a list of header tabs from the parent UI. */ + def headerTabs: Seq[WebUITab] = parent.getTabs +} + + +/** + * A page that represents the leaf node in the UI hierarchy. + * + * The direct parent of a WebUIPage is not specified as it can be either a WebUI or a WebUITab. + * If the parent is a WebUI, the prefix is appended to the parent's address to form a full path. + * Else, if the parent is a WebUITab, the prefix is appended to the super prefix of the parent + * to form a relative path. The prefix must not contain slashes. + */ +private[spark] abstract class WebUIPage(var prefix: String) { + def render(request: HttpServletRequest): Seq[Node] + def renderJson(request: HttpServletRequest): JValue = JNothing } diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala similarity index 62% rename from core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala rename to core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala index 23e90c34d5b33..b347eb1b83c1f 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentPage.scala @@ -21,28 +21,12 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.eclipse.jetty.servlet.ServletContextHandler +import org.apache.spark.ui.{UIUtils, WebUIPage} -import org.apache.spark.scheduler._ -import org.apache.spark.ui._ -import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.ui.Page.Environment - -private[ui] class EnvironmentUI(parent: SparkUI) { +private[ui] class EnvironmentPage(parent: EnvironmentTab) extends WebUIPage("") { private val appName = parent.appName private val basePath = parent.basePath - private var _listener: Option[EnvironmentListener] = None - - lazy val listener = _listener.get - - def start() { - _listener = Some(new EnvironmentListener) - } - - def getHandlers = Seq[ServletContextHandler]( - createServletHandler("/environment", - (request: HttpServletRequest) => render(request), parent.securityManager, basePath) - ) + private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { val runtimeInformationTable = UIUtils.listingTable( @@ -61,7 +45,7 @@ private[ui] class EnvironmentUI(parent: SparkUI) {

    Classpath Entries

    {classpathEntriesTable} - UIUtils.headerSparkPage(content, basePath, appName, "Environment", Environment) + UIUtils.headerSparkPage(content, basePath, appName, "Environment", parent.headerTabs, parent) } private def propertyHeader = Seq("Name", "Value") @@ -70,23 +54,3 @@ private[ui] class EnvironmentUI(parent: SparkUI) { private def propertyRow(kv: (String, String)) = {kv._1}{kv._2} private def classPathRow(data: (String, String)) = {data._1}{data._2} } - -/** - * A SparkListener that prepares information to be displayed on the EnvironmentUI - */ -private[ui] class EnvironmentListener extends SparkListener { - var jvmInformation = Seq[(String, String)]() - var sparkProperties = Seq[(String, String)]() - var systemProperties = Seq[(String, String)]() - var classpathEntries = Seq[(String, String)]() - - override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { - synchronized { - val environmentDetails = environmentUpdate.environmentDetails - jvmInformation = environmentDetails("JVM Information") - sparkProperties = environmentDetails("Spark Properties") - systemProperties = environmentDetails("System Properties") - classpathEntries = environmentDetails("Classpath Entries") - } - } -} diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala new file mode 100644 index 0000000000000..03b46e1bd59af --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.env + +import org.apache.spark.scheduler._ +import org.apache.spark.ui._ + +private[ui] class EnvironmentTab(parent: SparkUI) extends WebUITab(parent, "environment") { + val appName = parent.appName + val basePath = parent.basePath + val listener = new EnvironmentListener + + attachPage(new EnvironmentPage(this)) + parent.registerListener(listener) +} + +/** + * A SparkListener that prepares information to be displayed on the EnvironmentTab + */ +private[ui] class EnvironmentListener extends SparkListener { + var jvmInformation = Seq[(String, String)]() + var sparkProperties = Seq[(String, String)]() + var systemProperties = Seq[(String, String)]() + var classpathEntries = Seq[(String, String)]() + + override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { + synchronized { + val environmentDetails = environmentUpdate.environmentDetails + jvmInformation = environmentDetails("JVM Information") + sparkProperties = environmentDetails("Spark Properties") + systemProperties = environmentDetails("System Properties") + classpathEntries = environmentDetails("Classpath Entries") + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala similarity index 61% rename from core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala rename to core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 031ed88a493a8..c1e69f6cdaffb 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -19,34 +19,15 @@ package org.apache.spark.ui.exec import javax.servlet.http.HttpServletRequest -import scala.collection.mutable.HashMap import scala.xml.Node -import org.eclipse.jetty.servlet.ServletContextHandler - -import org.apache.spark.ExceptionFailure -import org.apache.spark.scheduler._ -import org.apache.spark.storage.StorageStatusListener -import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.ui.Page.Executors -import org.apache.spark.ui.{SparkUI, UIUtils} +import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils -private[ui] class ExecutorsUI(parent: SparkUI) { +private[ui] class ExecutorsPage(parent: ExecutorsTab) extends WebUIPage("") { private val appName = parent.appName private val basePath = parent.basePath - private var _listener: Option[ExecutorsListener] = None - - lazy val listener = _listener.get - - def start() { - _listener = Some(new ExecutorsListener(parent.storageStatusListener)) - } - - def getHandlers = Seq[ServletContextHandler]( - createServletHandler("/executors", - (request: HttpServletRequest) => render(request), parent.securityManager, basePath) - ) + private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { val storageStatusList = listener.storageStatusList @@ -74,8 +55,8 @@ private[ui] class ExecutorsUI(parent: SparkUI) {
    ; - UIUtils.headerSparkPage( - content, basePath, appName, "Executors (" + execInfo.size + ")", Executors) + UIUtils.headerSparkPage(content, basePath, appName, "Executors (" + execInfo.size + ")", + parent.headerTabs, parent) } /** Header fields for the executors table */ @@ -158,55 +139,3 @@ private[ui] class ExecutorsUI(parent: SparkUI) { execFields.zip(execValues).toMap } } - -/** - * A SparkListener that prepares information to be displayed on the ExecutorsUI - */ -private[ui] class ExecutorsListener(storageStatusListener: StorageStatusListener) - extends SparkListener { - - val executorToTasksActive = HashMap[String, Int]() - val executorToTasksComplete = HashMap[String, Int]() - val executorToTasksFailed = HashMap[String, Int]() - val executorToDuration = HashMap[String, Long]() - val executorToShuffleRead = HashMap[String, Long]() - val executorToShuffleWrite = HashMap[String, Long]() - - def storageStatusList = storageStatusListener.storageStatusList - - override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { - val eid = formatExecutorId(taskStart.taskInfo.executorId) - executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 0) + 1 - } - - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { - val info = taskEnd.taskInfo - if (info != null) { - val eid = formatExecutorId(info.executorId) - executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 1) - 1 - executorToDuration(eid) = executorToDuration.getOrElse(eid, 0L) + info.duration - taskEnd.reason match { - case e: ExceptionFailure => - executorToTasksFailed(eid) = executorToTasksFailed.getOrElse(eid, 0) + 1 - case _ => - executorToTasksComplete(eid) = executorToTasksComplete.getOrElse(eid, 0) + 1 - } - - // Update shuffle read/write - val metrics = taskEnd.taskMetrics - if (metrics != null) { - metrics.shuffleReadMetrics.foreach { shuffleRead => - executorToShuffleRead(eid) = - executorToShuffleRead.getOrElse(eid, 0L) + shuffleRead.remoteBytesRead - } - metrics.shuffleWriteMetrics.foreach { shuffleWrite => - executorToShuffleWrite(eid) = - executorToShuffleWrite.getOrElse(eid, 0L) + shuffleWrite.shuffleBytesWritten - } - } - } - } - - // This addresses executor ID inconsistencies in the local mode - private def formatExecutorId(execId: String) = storageStatusListener.formatExecutorId(execId) -} diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala new file mode 100644 index 0000000000000..5678bf34ac730 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.exec + +import scala.collection.mutable.HashMap + +import org.apache.spark.ExceptionFailure +import org.apache.spark.scheduler._ +import org.apache.spark.storage.StorageStatusListener +import org.apache.spark.ui.{SparkUI, WebUITab} + +private[ui] class ExecutorsTab(parent: SparkUI) extends WebUITab(parent, "executors") { + val appName = parent.appName + val basePath = parent.basePath + val listener = new ExecutorsListener(parent.storageStatusListener) + + attachPage(new ExecutorsPage(this)) + parent.registerListener(listener) +} + +/** + * A SparkListener that prepares information to be displayed on the ExecutorsTab + */ +private[ui] class ExecutorsListener(storageStatusListener: StorageStatusListener) + extends SparkListener { + + val executorToTasksActive = HashMap[String, Int]() + val executorToTasksComplete = HashMap[String, Int]() + val executorToTasksFailed = HashMap[String, Int]() + val executorToDuration = HashMap[String, Long]() + val executorToShuffleRead = HashMap[String, Long]() + val executorToShuffleWrite = HashMap[String, Long]() + + def storageStatusList = storageStatusListener.storageStatusList + + override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { + val eid = formatExecutorId(taskStart.taskInfo.executorId) + executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 0) + 1 + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + val info = taskEnd.taskInfo + if (info != null) { + val eid = formatExecutorId(info.executorId) + executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 1) - 1 + executorToDuration(eid) = executorToDuration.getOrElse(eid, 0L) + info.duration + taskEnd.reason match { + case e: ExceptionFailure => + executorToTasksFailed(eid) = executorToTasksFailed.getOrElse(eid, 0) + 1 + case _ => + executorToTasksComplete(eid) = executorToTasksComplete.getOrElse(eid, 0) + 1 + } + + // Update shuffle read/write + val metrics = taskEnd.taskMetrics + if (metrics != null) { + metrics.shuffleReadMetrics.foreach { shuffleRead => + executorToShuffleRead(eid) = + executorToShuffleRead.getOrElse(eid, 0L) + shuffleRead.remoteBytesRead + } + metrics.shuffleWriteMetrics.foreach { shuffleWrite => + executorToShuffleWrite(eid) = + executorToShuffleWrite.getOrElse(eid, 0L) + shuffleWrite.shuffleBytesWritten + } + } + } + } + + // This addresses executor ID inconsistencies in the local mode + private def formatExecutorId(execId: String) = storageStatusListener.formatExecutorId(execId) +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 73861ae6746da..c83e196c9c156 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -20,11 +20,12 @@ package org.apache.spark.ui.jobs import scala.collection.mutable import scala.xml.Node +import org.apache.spark.ui.UIUtils import org.apache.spark.util.Utils /** Page showing executor summary */ -private[ui] class ExecutorTable(stageId: Int, parent: JobProgressUI) { - private lazy val listener = parent.listener +private[ui] class ExecutorTable(stageId: Int, parent: JobProgressTab) { + private val listener = parent.listener def toNodeSeq: Seq[Node] = { listener.synchronized { @@ -69,7 +70,7 @@ private[ui] class ExecutorTable(stageId: Int, parent: JobProgressUI) { {k} {executorIdToAddress.getOrElse(k, "CANNOT FIND ADDRESS")} - {parent.formatDuration(v.taskTime)} + {UIUtils.formatDuration(v.taskTime)} {v.failedTasks + v.succeededTasks} {v.failedTasks} {v.succeededTasks} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index d10aa12b9ebca..0db4afa701b41 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -74,16 +74,20 @@ private[ui] class JobProgressListener(conf: SparkConf) extends SparkListener { // Remove by stageId, rather than by StageInfo, in case the StageInfo is from storage poolToActiveStages(stageIdToPool(stageId)).remove(stageId) activeStages.remove(stageId) - completedStages += stage - trimIfNecessary(completedStages) + if (stage.failureReason.isEmpty) { + completedStages += stage + trimIfNecessary(completedStages) + } else { + failedStages += stage + trimIfNecessary(failedStages) + } } /** If stages is too large, remove and garbage collect old stages */ private def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized { if (stages.size > retainedStages) { - val toRemove = retainedStages / 10 - stages.takeRight(toRemove).foreach( s => { - stageIdToTaskData.remove(s.stageId) + val toRemove = math.max(retainedStages / 10, 1) + stages.take(toRemove).foreach { s => stageIdToTime.remove(s.stageId) stageIdToShuffleRead.remove(s.stageId) stageIdToShuffleWrite.remove(s.stageId) @@ -92,10 +96,12 @@ private[ui] class JobProgressListener(conf: SparkConf) extends SparkListener { stageIdToTasksActive.remove(s.stageId) stageIdToTasksComplete.remove(s.stageId) stageIdToTasksFailed.remove(s.stageId) + stageIdToTaskData.remove(s.stageId) + stageIdToExecutorSummaries.remove(s.stageId) stageIdToPool.remove(s.stageId) - if (stageIdToDescription.contains(s.stageId)) {stageIdToDescription.remove(s.stageId)} - }) - stages.trimEnd(toRemove) + stageIdToDescription.remove(s.stageId) + } + stages.trimStart(toRemove) } } @@ -214,28 +220,12 @@ private[ui] class JobProgressListener(conf: SparkConf) extends SparkListener { } } - override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized { - jobEnd.jobResult match { - case JobFailed(_, stageId) => - activeStages.get(stageId).foreach { s => - // Remove by stageId, rather than by StageInfo, in case the StageInfo is from storage - activeStages.remove(s.stageId) - poolToActiveStages(stageIdToPool(stageId)).remove(s.stageId) - failedStages += s - trimIfNecessary(failedStages) - } - case _ => - } - } - override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { synchronized { - val schedulingModeName = - environmentUpdate.environmentDetails("Spark Properties").toMap.get("spark.scheduler.mode") - schedulingMode = schedulingModeName match { - case Some(name) => Some(SchedulingMode.withName(name)) - case None => None - } + schedulingMode = environmentUpdate + .environmentDetails("Spark Properties").toMap + .get("spark.scheduler.mode") + .map(SchedulingMode.withName) } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala similarity index 88% rename from core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala rename to core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala index 70d62b66a4829..34ff2ac34a7ca 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/IndexPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala @@ -22,16 +22,15 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, NodeSeq} import org.apache.spark.scheduler.Schedulable -import org.apache.spark.ui.Page._ -import org.apache.spark.ui.UIUtils +import org.apache.spark.ui.{WebUIPage, UIUtils} /** Page showing list of all ongoing and recently finished stages and pools */ -private[ui] class IndexPage(parent: JobProgressUI) { +private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") { private val appName = parent.appName private val basePath = parent.basePath private val live = parent.live private val sc = parent.sc - private lazy val listener = parent.listener + private val listener = parent.listener private lazy val isFairScheduler = parent.isFairScheduler def render(request: HttpServletRequest): Seq[Node] = { @@ -39,9 +38,10 @@ private[ui] class IndexPage(parent: JobProgressUI) { val activeStages = listener.activeStages.values.toSeq val completedStages = listener.completedStages.reverse.toSeq val failedStages = listener.failedStages.reverse.toSeq - val now = System.currentTimeMillis() + val now = System.currentTimeMillis - val activeStagesTable = new StageTable(activeStages.sortBy(_.submissionTime).reverse, parent) + val activeStagesTable = + new StageTable(activeStages.sortBy(_.submissionTime).reverse, parent, parent.killEnabled) val completedStagesTable = new StageTable(completedStages.sortBy(_.submissionTime).reverse, parent) val failedStagesTable = new StageTable(failedStages.sortBy(_.submissionTime).reverse, parent) @@ -57,7 +57,7 @@ private[ui] class IndexPage(parent: JobProgressUI) { // Total duration is not meaningful unless the UI is live
  • Total Duration: - {parent.formatDuration(now - sc.startTime)} + {UIUtils.formatDuration(now - sc.startTime)}
  • }}
  • @@ -92,7 +92,7 @@ private[ui] class IndexPage(parent: JobProgressUI) {

    Failed Stages ({failedStages.size})

    ++ failedStagesTable.toNodeSeq - UIUtils.headerSparkPage(content, basePath, appName, "Spark Stages", Stages) + UIUtils.headerSparkPage(content, basePath, appName, "Spark Stages", parent.headerTabs, parent) } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala new file mode 100644 index 0000000000000..3308c8c8a3d37 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui.jobs + +import javax.servlet.http.HttpServletRequest + +import org.apache.spark.SparkConf +import org.apache.spark.scheduler.SchedulingMode +import org.apache.spark.ui.{SparkUI, WebUITab} + +/** Web UI showing progress status of all jobs in the given SparkContext. */ +private[ui] class JobProgressTab(parent: SparkUI) extends WebUITab(parent, "stages") { + val appName = parent.appName + val basePath = parent.basePath + val live = parent.live + val sc = parent.sc + val conf = if (live) sc.conf else new SparkConf + val killEnabled = conf.getBoolean("spark.ui.killEnabled", true) + val listener = new JobProgressListener(conf) + + attachPage(new JobProgressPage(this)) + attachPage(new StagePage(this)) + attachPage(new PoolPage(this)) + parent.registerListener(listener) + + def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) + + def handleKillRequest(request: HttpServletRequest) = { + if (killEnabled) { + val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean + val stageId = Option(request.getParameter("id")).getOrElse("-1").toInt + if (stageId >= 0 && killFlag && listener.activeStages.contains(stageId)) { + sc.cancelStage(stageId) + } + // Do a quick pause here to give Spark time to kill the stage so it shows up as + // killed after the refresh. Note that this will block the serving thread so the + // time should be limited in duration. + Thread.sleep(100) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala deleted file mode 100644 index b2c67381cc3da..0000000000000 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.ui.jobs - -import javax.servlet.http.HttpServletRequest - -import org.eclipse.jetty.servlet.ServletContextHandler - -import org.apache.spark.SparkConf -import org.apache.spark.scheduler.SchedulingMode -import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.ui.SparkUI -import org.apache.spark.util.Utils - -/** Web UI showing progress status of all jobs in the given SparkContext. */ -private[ui] class JobProgressUI(parent: SparkUI) { - val appName = parent.appName - val basePath = parent.basePath - val live = parent.live - val sc = parent.sc - - lazy val listener = _listener.get - lazy val isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR) - - private val indexPage = new IndexPage(this) - private val stagePage = new StagePage(this) - private val poolPage = new PoolPage(this) - private var _listener: Option[JobProgressListener] = None - - def start() { - val conf = if (live) sc.conf else new SparkConf - _listener = Some(new JobProgressListener(conf)) - } - - def formatDuration(ms: Long) = Utils.msDurationToString(ms) - - def getHandlers = Seq[ServletContextHandler]( - createServletHandler("/stages/stage", - (request: HttpServletRequest) => stagePage.render(request), parent.securityManager, basePath), - createServletHandler("/stages/pool", - (request: HttpServletRequest) => poolPage.render(request), parent.securityManager, basePath), - createServletHandler("/stages", - (request: HttpServletRequest) => indexPage.render(request), parent.securityManager, basePath) - ) -} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index bd33182b70059..fd83d37583967 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -22,16 +22,15 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node import org.apache.spark.scheduler.{Schedulable, StageInfo} -import org.apache.spark.ui.Page._ -import org.apache.spark.ui.UIUtils +import org.apache.spark.ui.{WebUIPage, UIUtils} /** Page showing specific pool details */ -private[ui] class PoolPage(parent: JobProgressUI) { +private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") { private val appName = parent.appName private val basePath = parent.basePath private val live = parent.live private val sc = parent.sc - private lazy val listener = parent.listener + private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { listener.synchronized { @@ -51,8 +50,8 @@ private[ui] class PoolPage(parent: JobProgressUI) {

    Summary

    ++ poolTable.toNodeSeq ++

    {activeStages.size} Active Stages

    ++ activeStagesTable.toNodeSeq - UIUtils.headerSparkPage( - content, basePath, appName, "Fair Scheduler Pool: " + poolName, Stages) + UIUtils.headerSparkPage(content, basePath, appName, "Fair Scheduler Pool: " + poolName, + parent.headerTabs, parent) } } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala index c5c8d8668740b..f4b68f241966d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala @@ -24,10 +24,9 @@ import org.apache.spark.scheduler.{Schedulable, StageInfo} import org.apache.spark.ui.UIUtils /** Table showing list of pools */ -private[ui] class PoolTable(pools: Seq[Schedulable], parent: JobProgressUI) { +private[ui] class PoolTable(pools: Seq[Schedulable], parent: JobProgressTab) { private val basePath = parent.basePath - private val poolToActiveStages = listener.poolToActiveStages - private lazy val listener = parent.listener + private val listener = parent.listener def toNodeSeq: Seq[Node] = { listener.synchronized { @@ -48,7 +47,7 @@ private[ui] class PoolTable(pools: Seq[Schedulable], parent: JobProgressUI) { SchedulingMode - {rows.map(r => makeRow(r, poolToActiveStages))} + {rows.map(r => makeRow(r, listener.poolToActiveStages))} } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 0c55f2ee7e944..4bce472036f7d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -22,15 +22,14 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.ui.Page._ -import org.apache.spark.ui.{WebUI, UIUtils} +import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.{Utils, Distribution} /** Page showing statistics and task list for a given stage */ -private[ui] class StagePage(parent: JobProgressUI) { +private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") { private val appName = parent.appName private val basePath = parent.basePath - private lazy val listener = parent.listener + private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { listener.synchronized { @@ -42,8 +41,8 @@ private[ui] class StagePage(parent: JobProgressUI) {

    Summary Metrics

    No tasks have started yet

    Tasks

    No tasks have started yet
  • - return UIUtils.headerSparkPage( - content, basePath, appName, "Details for Stage %s".format(stageId), Stages) + return UIUtils.headerSparkPage(content, basePath, appName, + "Details for Stage %s".format(stageId), parent.headerTabs, parent) } val tasks = listener.stageIdToTaskData(stageId).values.toSeq.sortBy(_.taskInfo.launchTime) @@ -58,7 +57,7 @@ private[ui] class StagePage(parent: JobProgressUI) { val hasBytesSpilled = memoryBytesSpilled > 0 && diskBytesSpilled > 0 var activeTime = 0L - val now = System.currentTimeMillis() + val now = System.currentTimeMillis val tasksActive = listener.stageIdToTasksActive(stageId).values tasksActive.foreach(activeTime += _.timeRunning(now)) @@ -68,7 +67,7 @@ private[ui] class StagePage(parent: JobProgressUI) {
    • Total task time across all tasks: - {parent.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)} + {UIUtils.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)}
    • {if (hasShuffleRead)
    • @@ -119,13 +118,13 @@ private[ui] class StagePage(parent: JobProgressUI) { } val serializationQuantiles = "Result serialization time" +: Distribution(serializationTimes). - get.getQuantiles().map(ms => parent.formatDuration(ms.toLong)) + get.getQuantiles().map(ms => UIUtils.formatDuration(ms.toLong)) val serviceTimes = validTasks.map { case TaskUIData(_, metrics, _) => metrics.get.executorRunTime.toDouble } val serviceQuantiles = "Duration" +: Distribution(serviceTimes).get.getQuantiles() - .map(ms => parent.formatDuration(ms.toLong)) + .map(ms => UIUtils.formatDuration(ms.toLong)) val gettingResultTimes = validTasks.map { case TaskUIData(info, _, _) => if (info.gettingResultTime > 0) { @@ -136,7 +135,7 @@ private[ui] class StagePage(parent: JobProgressUI) { } val gettingResultQuantiles = "Time spent fetching task results" +: Distribution(gettingResultTimes).get.getQuantiles().map { millis => - parent.formatDuration(millis.toLong) + UIUtils.formatDuration(millis.toLong) } // The scheduler delay includes the network delay to send the task to the worker // machine and to send back the result (but not the time to fetch the task result, @@ -153,7 +152,7 @@ private[ui] class StagePage(parent: JobProgressUI) { } val schedulerDelayQuantiles = "Scheduler delay" +: Distribution(schedulerDelays).get.getQuantiles().map { millis => - parent.formatDuration(millis.toLong) + UIUtils.formatDuration(millis.toLong) } def getQuantileCols(data: Seq[Double]) = @@ -204,8 +203,8 @@ private[ui] class StagePage(parent: JobProgressUI) {

      Aggregated Metrics by Executor

      ++ executorTable.toNodeSeq ++

      Tasks

      ++ taskTable - UIUtils.headerSparkPage( - content, basePath, appName, "Details for Stage %d".format(stageId), Stages) + UIUtils.headerSparkPage(content, basePath, appName, "Details for Stage %d".format(stageId), + parent.headerTabs, parent) } } @@ -217,8 +216,8 @@ private[ui] class StagePage(parent: JobProgressUI) { taskData match { case TaskUIData(info, metrics, exception) => val duration = if (info.status == "RUNNING") info.timeRunning(System.currentTimeMillis()) else metrics.map(_.executorRunTime).getOrElse(1L) - val formatDuration = if (info.status == "RUNNING") parent.formatDuration(duration) - else metrics.map(m => parent.formatDuration(m.executorRunTime)).getOrElse("") + val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration) + else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("") val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) @@ -233,8 +232,8 @@ private[ui] class StagePage(parent: JobProgressUI) { val maybeWriteTime = metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleWriteTime) val writeTimeSortable = maybeWriteTime.map(_.toString).getOrElse("") - val writeTimeReadable = maybeWriteTime.map( t => t / (1000 * 1000)).map { ms => - if (ms == 0) "" else parent.formatDuration(ms) + val writeTimeReadable = maybeWriteTime.map(t => t / (1000 * 1000)).map { ms => + if (ms == 0) "" else UIUtils.formatDuration(ms) }.getOrElse("") val maybeMemoryBytesSpilled = metrics.map(_.memoryBytesSpilled) @@ -252,15 +251,15 @@ private[ui] class StagePage(parent: JobProgressUI) { {info.status} {info.taskLocality} {info.host} - {WebUI.formatDate(new Date(info.launchTime))} + {UIUtils.formatDate(new Date(info.launchTime))} {formatDuration} - {if (gcTime > 0) parent.formatDuration(gcTime) else ""} + {if (gcTime > 0) UIUtils.formatDuration(gcTime) else ""} - {if (serializationTime > 0) parent.formatDuration(serializationTime) else ""} + {if (serializationTime > 0) UIUtils.formatDuration(serializationTime) else ""} {if (shuffleRead) { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala index ac61568af52d2..8c5b1f55fd2dc 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala @@ -23,13 +23,17 @@ import scala.collection.mutable.HashMap import scala.xml.Node import org.apache.spark.scheduler.{StageInfo, TaskInfo} -import org.apache.spark.ui.{WebUI, UIUtils} +import org.apache.spark.ui.UIUtils import org.apache.spark.util.Utils /** Page showing list of all ongoing and recently finished stages */ -private[ui] class StageTable(stages: Seq[StageInfo], parent: JobProgressUI) { +private[ui] class StageTable( + stages: Seq[StageInfo], + parent: JobProgressTab, + killEnabled: Boolean = false) { + private val basePath = parent.basePath - private lazy val listener = parent.listener + private val listener = parent.listener private lazy val isFairScheduler = parent.isFairScheduler def toNodeSeq: Seq[Node] = { @@ -71,24 +75,37 @@ private[ui] class StageTable(stages: Seq[StageInfo], parent: JobProgressUI) { } - /** Render an HTML row that represents a stage */ - private def stageRow(s: StageInfo): Seq[Node] = { - val poolName = listener.stageIdToPool.get(s.stageId) + private def makeDescription(s: StageInfo): Seq[Node] = { + // scalastyle:off + val killLink = if (killEnabled) { + + (kill) + + } + // scalastyle:on + val nameLink = {s.name} - val description = listener.stageIdToDescription.get(s.stageId) - .map(d =>
      {d}
      {nameLink}
      ).getOrElse(nameLink) + + listener.stageIdToDescription.get(s.stageId) + .map(d =>
      {d}
      {nameLink} {killLink}
      ) + .getOrElse(
      {killLink}{nameLink}
      ) + } + + /** Render an HTML row that represents a stage */ + private def stageRow(s: StageInfo): Seq[Node] = { + val poolName = listener.stageIdToPool.get(s.stageId) val submissionTime = s.submissionTime match { - case Some(t) => WebUI.formatDate(new Date(t)) + case Some(t) => UIUtils.formatDate(new Date(t)) case None => "Unknown" } val finishTime = s.completionTime.getOrElse(System.currentTimeMillis) val duration = s.submissionTime.map { t => if (finishTime > t) finishTime - t else System.currentTimeMillis - t } - val formattedDuration = duration.map(d => parent.formatDuration(d)).getOrElse("Unknown") + val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") val startedTasks = listener.stageIdToTasksActive.getOrElse(s.stageId, HashMap[Long, TaskInfo]()).size val completedTasks = listener.stageIdToTasksComplete.getOrElse(s.stageId, 0) @@ -118,7 +135,7 @@ private[ui] class StageTable(stages: Seq[StageInfo], parent: JobProgressUI) { }} - {description} + {makeDescription(s)} {submissionTime} {formattedDuration} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index 3f42eba4ece00..d07f1c9b20fcf 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -22,22 +22,22 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node import org.apache.spark.storage.{BlockId, BlockStatus, StorageStatus, StorageUtils} -import org.apache.spark.ui.Page._ -import org.apache.spark.ui.UIUtils +import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils /** Page showing storage details for a given RDD */ -private[ui] class RDDPage(parent: BlockManagerUI) { +private[ui] class RddPage(parent: StorageTab) extends WebUIPage("rdd") { private val appName = parent.appName private val basePath = parent.basePath - private lazy val listener = parent.listener + private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { val rddId = request.getParameter("id").toInt val storageStatusList = listener.storageStatusList val rddInfo = listener.rddInfoList.find(_.id == rddId).getOrElse { // Rather than crashing, render an "RDD Not Found" page - return UIUtils.headerSparkPage(Seq[Node](), basePath, appName, "RDD Not Found", Storage) + return UIUtils.headerSparkPage(Seq[Node](), basePath, appName, "RDD Not Found", + parent.headerTabs, parent) } // Worker table @@ -95,8 +95,8 @@ private[ui] class RDDPage(parent: BlockManagerUI) { ; - UIUtils.headerSparkPage( - content, basePath, appName, "RDD Storage Info for " + rddInfo.name, Storage) + UIUtils.headerSparkPage(content, basePath, appName, "RDD Storage Info for " + rddInfo.name, + parent.headerTabs, parent) } /** Header fields for the worker table */ diff --git a/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala similarity index 88% rename from core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala rename to core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index b2732de51058a..b66edd91f56c0 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/IndexPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -22,20 +22,19 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node import org.apache.spark.storage.RDDInfo -import org.apache.spark.ui.Page._ -import org.apache.spark.ui.UIUtils +import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils /** Page showing list of RDD's currently stored in the cluster */ -private[ui] class IndexPage(parent: BlockManagerUI) { +private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") { private val appName = parent.appName private val basePath = parent.basePath - private lazy val listener = parent.listener + private val listener = parent.listener def render(request: HttpServletRequest): Seq[Node] = { val rdds = listener.rddInfoList val content = UIUtils.listingTable(rddHeader, rddRow, rdds) - UIUtils.headerSparkPage(content, basePath, appName, "Storage ", Storage) + UIUtils.headerSparkPage(content, basePath, appName, "Storage ", parent.headerTabs, parent) } /** Header fields for the RDD table */ @@ -45,6 +44,7 @@ private[ui] class IndexPage(parent: BlockManagerUI) { "Cached Partitions", "Fraction Cached", "Size in Memory", + "Size in Tachyon", "Size on Disk") /** Render an HTML row representing an RDD */ @@ -60,6 +60,7 @@ private[ui] class IndexPage(parent: BlockManagerUI) { {rdd.numCachedPartitions} {"%.0f%%".format(rdd.numCachedPartitions * 100.0 / rdd.numPartitions)} {Utils.bytesToString(rdd.memSize)} + {Utils.bytesToString(rdd.tachyonSize)} {Utils.bytesToString(rdd.diskSize)} } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala similarity index 76% rename from core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala rename to core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index a7b24ff695214..56429f6c07fcd 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -17,44 +17,27 @@ package org.apache.spark.ui.storage -import javax.servlet.http.HttpServletRequest - import scala.collection.mutable -import org.eclipse.jetty.servlet.ServletContextHandler - import org.apache.spark.ui._ -import org.apache.spark.ui.JettyUtils._ import org.apache.spark.scheduler._ import org.apache.spark.storage.{RDDInfo, StorageStatusListener, StorageUtils} /** Web UI showing storage status of all RDD's in the given SparkContext. */ -private[ui] class BlockManagerUI(parent: SparkUI) { +private[ui] class StorageTab(parent: SparkUI) extends WebUITab(parent, "storage") { val appName = parent.appName val basePath = parent.basePath + val listener = new StorageListener(parent.storageStatusListener) - private val indexPage = new IndexPage(this) - private val rddPage = new RDDPage(this) - private var _listener: Option[BlockManagerListener] = None - - lazy val listener = _listener.get - - def start() { - _listener = Some(new BlockManagerListener(parent.storageStatusListener)) - } - - def getHandlers = Seq[ServletContextHandler]( - createServletHandler("/storage/rdd", - (request: HttpServletRequest) => rddPage.render(request), parent.securityManager, basePath), - createServletHandler("/storage", - (request: HttpServletRequest) => indexPage.render(request), parent.securityManager, basePath) - ) + attachPage(new StoragePage(this)) + attachPage(new RddPage(this)) + parent.registerListener(listener) } /** * A SparkListener that prepares information to be displayed on the BlockManagerUI */ -private[ui] class BlockManagerListener(storageStatusListener: StorageStatusListener) +private[ui] class StorageListener(storageStatusListener: StorageStatusListener) extends SparkListener { private val _rddInfoMap = mutable.Map[Int, RDDInfo]() diff --git a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala index c3692f2fd929b..b9f4a5d720b93 100644 --- a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala +++ b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala @@ -28,7 +28,7 @@ import scala.collection.generic.Growable * class and modifies it such that only the top K elements are retained. * The top K elements are defined by an implicit Ordering[A]. */ -class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A]) +private[spark] class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A]) extends Iterable[A] with Growable[A] with Serializable { private val underlying = new JPriorityQueue[A](maxSize, ord) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index cdbbc65292188..2d05e09b10948 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -45,7 +45,7 @@ private[spark] object ClosureCleaner extends Logging { private def isClosure(cls: Class[_]): Boolean = { cls.getName.contains("$anonfun$") } - + // Get a list of the classes of the outer objects of a given closure object, obj; // the outer objects are defined as any closures that obj is nested within, plus // possibly the class that the outermost closure is in, if any. We stop searching @@ -63,7 +63,7 @@ private[spark] object ClosureCleaner extends Logging { } Nil } - + // Get a list of the outer objects for a given closure object. private def getOuterObjects(obj: AnyRef): List[AnyRef] = { for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { @@ -76,7 +76,7 @@ private[spark] object ClosureCleaner extends Logging { } Nil } - + private def getInnerClasses(obj: AnyRef): List[Class[_]] = { val seen = Set[Class[_]](obj.getClass) var stack = List[Class[_]](obj.getClass) @@ -92,7 +92,7 @@ private[spark] object ClosureCleaner extends Logging { } return (seen - obj.getClass).toList } - + private def createNullValue(cls: Class[_]): AnyRef = { if (cls.isPrimitive) { new java.lang.Byte(0: Byte) // Should be convertible to any primitive type @@ -100,13 +100,13 @@ private[spark] object ClosureCleaner extends Logging { null } } - + def clean(func: AnyRef) { // TODO: cache outerClasses / innerClasses / accessedFields val outerClasses = getOuterClasses(func) val innerClasses = getInnerClasses(func) val outerObjects = getOuterObjects(func) - + val accessedFields = Map[Class[_], Set[String]]() for (cls <- outerClasses) accessedFields(cls) = Set[String]() @@ -143,7 +143,7 @@ private[spark] object ClosureCleaner extends Logging { field.set(outer, value) } } - + if (outer != null) { // logInfo("2: Setting $outer on " + func.getClass + " to " + outer); val field = func.getClass.getDeclaredField("$outer") @@ -151,7 +151,7 @@ private[spark] object ClosureCleaner extends Logging { field.set(func, outer) } } - + private def instantiateClass(cls: Class[_], outer: AnyRef, inInterpreter: Boolean): AnyRef = { // logInfo("Creating a " + cls + " with outer = " + outer) if (!inInterpreter) { @@ -192,7 +192,7 @@ class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor } } } - + override def visitMethodInsn(op: Int, owner: String, name: String, desc: String) { // Check for calls a getter method for a variable in an interpreter wrapper object. @@ -209,12 +209,12 @@ class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) { var myName: String = null - + override def visit(version: Int, access: Int, name: String, sig: String, superName: String, interfaces: Array[String]) { myName = name } - + override def visitMethod(access: Int, name: String, desc: String, sig: String, exceptions: Array[String]): MethodVisitor = { new MethodVisitor(ASM4) { diff --git a/core/src/main/scala/org/apache/spark/util/CollectionsUtil.scala b/core/src/main/scala/org/apache/spark/util/CollectionsUtil.scala index db3db87e6618e..93235031f3ad5 100644 --- a/core/src/main/scala/org/apache/spark/util/CollectionsUtil.scala +++ b/core/src/main/scala/org/apache/spark/util/CollectionsUtil.scala @@ -22,7 +22,7 @@ import java.util import scala.Array import scala.reflect._ -object CollectionsUtils { +private[spark] object CollectionsUtils { def makeBinarySearch[K <% Ordered[K] : ClassTag] : (Array[K], K) => Int = { classTag[K] match { case ClassTag.Float => diff --git a/core/src/main/scala/org/apache/spark/util/Distribution.scala b/core/src/main/scala/org/apache/spark/util/Distribution.scala index 5b347555fe708..a465298c8c5ab 100644 --- a/core/src/main/scala/org/apache/spark/util/Distribution.scala +++ b/core/src/main/scala/org/apache/spark/util/Distribution.scala @@ -29,7 +29,7 @@ import scala.collection.immutable.IndexedSeq * * Assumes you are giving it a non-empty set of data */ -class Distribution(val data: Array[Double], val startIdx: Int, val endIdx: Int) { +private[spark] class Distribution(val data: Array[Double], val startIdx: Int, val endIdx: Int) { require(startIdx < endIdx) def this(data: Traversable[Double]) = this(data.toArray, 0, data.size) java.util.Arrays.sort(data, startIdx, endIdx) @@ -69,7 +69,7 @@ class Distribution(val data: Array[Double], val startIdx: Int, val endIdx: Int) } } -object Distribution { +private[spark] object Distribution { def apply(data: Traversable[Double]): Option[Distribution] = { if (data.size > 0) { diff --git a/core/src/main/scala/org/apache/spark/util/FileLogger.scala b/core/src/main/scala/org/apache/spark/util/FileLogger.scala index a0c07e32fdc98..68a12e8ed67d7 100644 --- a/core/src/main/scala/org/apache/spark/util/FileLogger.scala +++ b/core/src/main/scala/org/apache/spark/util/FileLogger.scala @@ -17,12 +17,11 @@ package org.apache.spark.util -import java.io._ +import java.io.{FileOutputStream, BufferedOutputStream, PrintWriter, IOException} import java.net.URI import java.text.SimpleDateFormat import java.util.Date -import it.unimi.dsi.fastutil.io.FastBufferedOutputStream import org.apache.hadoop.fs.{FSDataOutputStream, Path} import org.apache.spark.{Logging, SparkConf} @@ -36,7 +35,7 @@ import org.apache.spark.io.CompressionCodec * @param compress Whether to compress output * @param overwrite Whether to overwrite existing files */ -class FileLogger( +private[spark] class FileLogger( logDir: String, conf: SparkConf = new SparkConf, outputBufferSize: Int = 8 * 1024, // 8 KB @@ -49,7 +48,7 @@ class FileLogger( } private val fileSystem = Utils.getHadoopFileSystem(new URI(logDir)) - private var fileIndex = 0 + var fileIndex = 0 // Only used if compression is enabled private lazy val compressionCodec = CompressionCodec.createCodec(conf) @@ -57,10 +56,9 @@ class FileLogger( // Only defined if the file system scheme is not local private var hadoopDataStream: Option[FSDataOutputStream] = None - private var writer: Option[PrintWriter] = { - createLogDir() - Some(createWriter()) - } + private var writer: Option[PrintWriter] = None + + createLogDir() /** * Create a logging directory with the given path. @@ -84,8 +82,8 @@ class FileLogger( /** * Create a new writer for the file identified by the given path. */ - private def createWriter(): PrintWriter = { - val logPath = logDir + "/" + fileIndex + private def createWriter(fileName: String): PrintWriter = { + val logPath = logDir + "/" + fileName val uri = new URI(logPath) /* The Hadoop LocalFileSystem (r1.0.4) has known issues with syncing (HADOOP-7844). @@ -101,7 +99,7 @@ class FileLogger( hadoopDataStream.get } - val bstream = new FastBufferedOutputStream(dstream, outputBufferSize) + val bstream = new BufferedOutputStream(dstream, outputBufferSize) val cstream = if (compress) compressionCodec.compressedOutputStream(bstream) else bstream new PrintWriter(cstream) } @@ -147,13 +145,17 @@ class FileLogger( } /** - * Start a writer for a new file if one does not already exit. + * Start a writer for a new file, closing the existing one if it exists. + * @param fileName Name of the new file, defaulting to the file index if not provided. */ - def start() { - writer.getOrElse { - fileIndex += 1 - writer = Some(createWriter()) + def newFile(fileName: String = "") { + fileIndex += 1 + writer.foreach(_.close()) + val name = fileName match { + case "" => fileIndex.toString + case _ => fileName } + writer = Some(createWriter(name)) } /** diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 346f2b7856791..465835ea7fe29 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -62,6 +62,10 @@ private[spark] object JsonProtocol { blockManagerRemovedToJson(blockManagerRemoved) case unpersistRDD: SparkListenerUnpersistRDD => unpersistRDDToJson(unpersistRDD) + case applicationStart: SparkListenerApplicationStart => + applicationStartToJson(applicationStart) + case applicationEnd: SparkListenerApplicationEnd => + applicationEndToJson(applicationEnd) // Not used, but keeps compiler happy case SparkListenerShutdown => JNothing @@ -84,30 +88,27 @@ private[spark] object JsonProtocol { def taskStartToJson(taskStart: SparkListenerTaskStart): JValue = { val taskInfo = taskStart.taskInfo - val taskInfoJson = if (taskInfo != null) taskInfoToJson(taskInfo) else JNothing ("Event" -> Utils.getFormattedClassName(taskStart)) ~ ("Stage ID" -> taskStart.stageId) ~ - ("Task Info" -> taskInfoJson) + ("Task Info" -> taskInfoToJson(taskInfo)) } def taskGettingResultToJson(taskGettingResult: SparkListenerTaskGettingResult): JValue = { val taskInfo = taskGettingResult.taskInfo - val taskInfoJson = if (taskInfo != null) taskInfoToJson(taskInfo) else JNothing ("Event" -> Utils.getFormattedClassName(taskGettingResult)) ~ - ("Task Info" -> taskInfoJson) + ("Task Info" -> taskInfoToJson(taskInfo)) } def taskEndToJson(taskEnd: SparkListenerTaskEnd): JValue = { val taskEndReason = taskEndReasonToJson(taskEnd.reason) val taskInfo = taskEnd.taskInfo - val taskInfoJson = if (taskInfo != null) taskInfoToJson(taskInfo) else JNothing val taskMetrics = taskEnd.taskMetrics val taskMetricsJson = if (taskMetrics != null) taskMetricsToJson(taskMetrics) else JNothing ("Event" -> Utils.getFormattedClassName(taskEnd)) ~ ("Stage ID" -> taskEnd.stageId) ~ ("Task Type" -> taskEnd.taskType) ~ ("Task End Reason" -> taskEndReason) ~ - ("Task Info" -> taskInfoJson) ~ + ("Task Info" -> taskInfoToJson(taskInfo)) ~ ("Task Metrics" -> taskMetricsJson) } @@ -157,6 +158,18 @@ private[spark] object JsonProtocol { ("RDD ID" -> unpersistRDD.rddId) } + def applicationStartToJson(applicationStart: SparkListenerApplicationStart): JValue = { + ("Event" -> Utils.getFormattedClassName(applicationStart)) ~ + ("App Name" -> applicationStart.appName) ~ + ("Timestamp" -> applicationStart.time) ~ + ("User" -> applicationStart.sparkUser) + } + + def applicationEndToJson(applicationEnd: SparkListenerApplicationEnd): JValue = { + ("Event" -> Utils.getFormattedClassName(applicationEnd)) ~ + ("Timestamp" -> applicationEnd.time) + } + /** ------------------------------------------------------------------- * * JSON serialization methods for classes SparkListenerEvents depend on | @@ -166,12 +179,14 @@ private[spark] object JsonProtocol { val rddInfo = rddInfoToJson(stageInfo.rddInfo) val submissionTime = stageInfo.submissionTime.map(JInt(_)).getOrElse(JNothing) val completionTime = stageInfo.completionTime.map(JInt(_)).getOrElse(JNothing) + val failureReason = stageInfo.failureReason.map(JString(_)).getOrElse(JNothing) ("Stage ID" -> stageInfo.stageId) ~ ("Stage Name" -> stageInfo.name) ~ ("Number of Tasks" -> stageInfo.numTasks) ~ ("RDD Info" -> rddInfo) ~ ("Submission Time" -> submissionTime) ~ ("Completion Time" -> completionTime) ~ + ("Failure Reason" -> failureReason) ~ ("Emitted Task Size Warning" -> stageInfo.emittedTaskSizeWarning) } @@ -195,7 +210,7 @@ private[spark] object JsonProtocol { taskMetrics.shuffleWriteMetrics.map(shuffleWriteMetricsToJson).getOrElse(JNothing) val updatedBlocks = taskMetrics.updatedBlocks.map { blocks => JArray(blocks.toList.map { case (id, status) => - ("Block ID" -> blockIdToJson(id)) ~ + ("Block ID" -> id.toString) ~ ("Status" -> blockStatusToJson(status)) }) }.getOrElse(JNothing) @@ -259,9 +274,7 @@ private[spark] object JsonProtocol { val json = jobResult match { case JobSucceeded => Utils.emptyJson case jobFailed: JobFailed => - val exception = exceptionToJson(jobFailed.exception) - ("Exception" -> exception) ~ - ("Failed Stage ID" -> jobFailed.failedStageId) + JObject("Exception" -> exceptionToJson(jobFailed.exception)) } ("Result" -> result) ~ json } @@ -274,49 +287,23 @@ private[spark] object JsonProtocol { ("Number of Partitions" -> rddInfo.numPartitions) ~ ("Number of Cached Partitions" -> rddInfo.numCachedPartitions) ~ ("Memory Size" -> rddInfo.memSize) ~ + ("Tachyon Size" -> rddInfo.tachyonSize) ~ ("Disk Size" -> rddInfo.diskSize) } def storageLevelToJson(storageLevel: StorageLevel): JValue = { ("Use Disk" -> storageLevel.useDisk) ~ ("Use Memory" -> storageLevel.useMemory) ~ + ("Use Tachyon" -> storageLevel.useOffHeap) ~ ("Deserialized" -> storageLevel.deserialized) ~ ("Replication" -> storageLevel.replication) } - def blockIdToJson(blockId: BlockId): JValue = { - val blockType = Utils.getFormattedClassName(blockId) - val json: JObject = blockId match { - case rddBlockId: RDDBlockId => - ("RDD ID" -> rddBlockId.rddId) ~ - ("Split Index" -> rddBlockId.splitIndex) - case shuffleBlockId: ShuffleBlockId => - ("Shuffle ID" -> shuffleBlockId.shuffleId) ~ - ("Map ID" -> shuffleBlockId.mapId) ~ - ("Reduce ID" -> shuffleBlockId.reduceId) - case broadcastBlockId: BroadcastBlockId => - "Broadcast ID" -> broadcastBlockId.broadcastId - case broadcastHelperBlockId: BroadcastHelperBlockId => - ("Broadcast Block ID" -> blockIdToJson(broadcastHelperBlockId.broadcastId)) ~ - ("Helper Type" -> broadcastHelperBlockId.hType) - case taskResultBlockId: TaskResultBlockId => - "Task ID" -> taskResultBlockId.taskId - case streamBlockId: StreamBlockId => - ("Stream ID" -> streamBlockId.streamId) ~ - ("Unique ID" -> streamBlockId.uniqueId) - case tempBlockId: TempBlockId => - val uuid = UUIDToJson(tempBlockId.id) - "Temp ID" -> uuid - case testBlockId: TestBlockId => - "Test ID" -> testBlockId.id - } - ("Type" -> blockType) ~ json - } - def blockStatusToJson(blockStatus: BlockStatus): JValue = { val storageLevel = storageLevelToJson(blockStatus.storageLevel) ("Storage Level" -> storageLevel) ~ ("Memory Size" -> blockStatus.memSize) ~ + ("Tachyon Size" -> blockStatus.tachyonSize) ~ ("Disk Size" -> blockStatus.diskSize) } @@ -372,6 +359,8 @@ private[spark] object JsonProtocol { val blockManagerAdded = Utils.getFormattedClassName(SparkListenerBlockManagerAdded) val blockManagerRemoved = Utils.getFormattedClassName(SparkListenerBlockManagerRemoved) val unpersistRDD = Utils.getFormattedClassName(SparkListenerUnpersistRDD) + val applicationStart = Utils.getFormattedClassName(SparkListenerApplicationStart) + val applicationEnd = Utils.getFormattedClassName(SparkListenerApplicationEnd) (json \ "Event").extract[String] match { case `stageSubmitted` => stageSubmittedFromJson(json) @@ -385,6 +374,8 @@ private[spark] object JsonProtocol { case `blockManagerAdded` => blockManagerAddedFromJson(json) case `blockManagerRemoved` => blockManagerRemovedFromJson(json) case `unpersistRDD` => unpersistRDDFromJson(json) + case `applicationStart` => applicationStartFromJson(json) + case `applicationEnd` => applicationEndFromJson(json) } } @@ -456,6 +447,17 @@ private[spark] object JsonProtocol { SparkListenerUnpersistRDD((json \ "RDD ID").extract[Int]) } + def applicationStartFromJson(json: JValue): SparkListenerApplicationStart = { + val appName = (json \ "App Name").extract[String] + val time = (json \ "Timestamp").extract[Long] + val sparkUser = (json \ "User").extract[String] + SparkListenerApplicationStart(appName, time, sparkUser) + } + + def applicationEndFromJson(json: JValue): SparkListenerApplicationEnd = { + SparkListenerApplicationEnd((json \ "Timestamp").extract[Long]) + } + /** --------------------------------------------------------------------- * * JSON deserialization methods for classes SparkListenerEvents depend on | @@ -468,11 +470,13 @@ private[spark] object JsonProtocol { val rddInfo = rddInfoFromJson(json \ "RDD Info") val submissionTime = Utils.jsonOption(json \ "Submission Time").map(_.extract[Long]) val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long]) + val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String]) val emittedTaskSizeWarning = (json \ "Emitted Task Size Warning").extract[Boolean] val stageInfo = new StageInfo(stageId, stageName, numTasks, rddInfo) stageInfo.submissionTime = submissionTime stageInfo.completionTime = completionTime + stageInfo.failureReason = failureReason stageInfo.emittedTaskSizeWarning = emittedTaskSizeWarning stageInfo } @@ -498,6 +502,9 @@ private[spark] object JsonProtocol { } def taskMetricsFromJson(json: JValue): TaskMetrics = { + if (json == JNothing) { + return TaskMetrics.empty + } val metrics = new TaskMetrics metrics.hostname = (json \ "Host Name").extract[String] metrics.executorDeserializeTime = (json \ "Executor Deserialize Time").extract[Long] @@ -513,7 +520,7 @@ private[spark] object JsonProtocol { Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson) metrics.updatedBlocks = Utils.jsonOption(json \ "Updated Blocks").map { value => value.extract[List[JValue]].map { block => - val id = blockIdFromJson(block \ "Block ID") + val id = BlockId((block \ "Block ID").extract[String]) val status = blockStatusFromJson(block \ "Status") (id, status) } @@ -587,8 +594,7 @@ private[spark] object JsonProtocol { case `jobSucceeded` => JobSucceeded case `jobFailed` => val exception = exceptionFromJson(json \ "Exception") - val failedStageId = (json \ "Failed Stage ID").extract[Int] - new JobFailed(exception, failedStageId) + new JobFailed(exception) } } @@ -599,11 +605,13 @@ private[spark] object JsonProtocol { val numPartitions = (json \ "Number of Partitions").extract[Int] val numCachedPartitions = (json \ "Number of Cached Partitions").extract[Int] val memSize = (json \ "Memory Size").extract[Long] + val tachyonSize = (json \ "Tachyon Size").extract[Long] val diskSize = (json \ "Disk Size").extract[Long] val rddInfo = new RDDInfo(rddId, name, numPartitions, storageLevel) rddInfo.numCachedPartitions = numCachedPartitions rddInfo.memSize = memSize + rddInfo.tachyonSize = tachyonSize rddInfo.diskSize = diskSize rddInfo } @@ -611,60 +619,18 @@ private[spark] object JsonProtocol { def storageLevelFromJson(json: JValue): StorageLevel = { val useDisk = (json \ "Use Disk").extract[Boolean] val useMemory = (json \ "Use Memory").extract[Boolean] + val useTachyon = (json \ "Use Tachyon").extract[Boolean] val deserialized = (json \ "Deserialized").extract[Boolean] val replication = (json \ "Replication").extract[Int] - StorageLevel(useDisk, useMemory, deserialized, replication) - } - - def blockIdFromJson(json: JValue): BlockId = { - val rddBlockId = Utils.getFormattedClassName(RDDBlockId) - val shuffleBlockId = Utils.getFormattedClassName(ShuffleBlockId) - val broadcastBlockId = Utils.getFormattedClassName(BroadcastBlockId) - val broadcastHelperBlockId = Utils.getFormattedClassName(BroadcastHelperBlockId) - val taskResultBlockId = Utils.getFormattedClassName(TaskResultBlockId) - val streamBlockId = Utils.getFormattedClassName(StreamBlockId) - val tempBlockId = Utils.getFormattedClassName(TempBlockId) - val testBlockId = Utils.getFormattedClassName(TestBlockId) - - (json \ "Type").extract[String] match { - case `rddBlockId` => - val rddId = (json \ "RDD ID").extract[Int] - val splitIndex = (json \ "Split Index").extract[Int] - new RDDBlockId(rddId, splitIndex) - case `shuffleBlockId` => - val shuffleId = (json \ "Shuffle ID").extract[Int] - val mapId = (json \ "Map ID").extract[Int] - val reduceId = (json \ "Reduce ID").extract[Int] - new ShuffleBlockId(shuffleId, mapId, reduceId) - case `broadcastBlockId` => - val broadcastId = (json \ "Broadcast ID").extract[Long] - new BroadcastBlockId(broadcastId) - case `broadcastHelperBlockId` => - val broadcastBlockId = - blockIdFromJson(json \ "Broadcast Block ID").asInstanceOf[BroadcastBlockId] - val hType = (json \ "Helper Type").extract[String] - new BroadcastHelperBlockId(broadcastBlockId, hType) - case `taskResultBlockId` => - val taskId = (json \ "Task ID").extract[Long] - new TaskResultBlockId(taskId) - case `streamBlockId` => - val streamId = (json \ "Stream ID").extract[Int] - val uniqueId = (json \ "Unique ID").extract[Long] - new StreamBlockId(streamId, uniqueId) - case `tempBlockId` => - val tempId = UUIDFromJson(json \ "Temp ID") - new TempBlockId(tempId) - case `testBlockId` => - val testId = (json \ "Test ID").extract[String] - new TestBlockId(testId) - } + StorageLevel(useDisk, useMemory, useTachyon, deserialized, replication) } def blockStatusFromJson(json: JValue): BlockStatus = { val storageLevel = storageLevelFromJson(json \ "Storage Level") val memorySize = (json \ "Memory Size").extract[Long] val diskSize = (json \ "Disk Size").extract[Long] - BlockStatus(storageLevel, memorySize, diskSize) + val tachyonSize = (json \ "Tachyon Size").extract[Long] + BlockStatus(storageLevel, memorySize, diskSize, tachyonSize) } diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 0448919e09161..7ebed5105b9fd 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -62,8 +62,8 @@ private[spark] class MetadataCleaner( private[spark] object MetadataCleanerType extends Enumeration { - val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK, - SHUFFLE_MAP_TASK, BLOCK_MANAGER, SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value + val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, BLOCK_MANAGER, + SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value type MetadataCleanerType = Value @@ -78,15 +78,16 @@ private[spark] object MetadataCleaner { conf.getInt("spark.cleaner.ttl", -1) } - def getDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType): Int = - { - conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString) - .toInt + def getDelaySeconds( + conf: SparkConf, + cleanerType: MetadataCleanerType.MetadataCleanerType): Int = { + conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString).toInt } - def setDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType, - delay: Int) - { + def setDelaySeconds( + conf: SparkConf, + cleanerType: MetadataCleanerType.MetadataCleanerType, + delay: Int) { conf.set(MetadataCleanerType.systemProperty(cleanerType), delay.toString) } diff --git a/core/src/main/scala/org/apache/spark/util/MutablePair.scala b/core/src/main/scala/org/apache/spark/util/MutablePair.scala index a6b39247a54ca..74fa77b68de0b 100644 --- a/core/src/main/scala/org/apache/spark/util/MutablePair.scala +++ b/core/src/main/scala/org/apache/spark/util/MutablePair.scala @@ -17,13 +17,17 @@ package org.apache.spark.util +import org.apache.spark.annotation.DeveloperApi + /** + * :: DeveloperApi :: * A tuple of 2 elements. This can be used as an alternative to Scala's Tuple2 when we want to * minimize object allocation. * * @param _1 Element 1 of this MutablePair * @param _2 Element 2 of this MutablePair */ +@DeveloperApi case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/* , AnyRef */) T1, @specialized(Int, Long, Double, Char, Boolean/* , AnyRef */) T2] (var _1: T1, var _2: T2) diff --git a/core/src/main/scala/org/apache/spark/util/NextIterator.scala b/core/src/main/scala/org/apache/spark/util/NextIterator.scala index 8266e5e495efc..e5c732a5a559b 100644 --- a/core/src/main/scala/org/apache/spark/util/NextIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/NextIterator.scala @@ -19,7 +19,7 @@ package org.apache.spark.util /** Provides a basic/boilerplate Iterator implementation. */ private[spark] abstract class NextIterator[U] extends Iterator[U] { - + private var gotNext = false private var nextValue: U = _ private var closed = false @@ -34,7 +34,7 @@ private[spark] abstract class NextIterator[U] extends Iterator[U] { * This convention is required because `null` may be a valid value, * and using `Option` seems like it might create unnecessary Some/None * instances, given some iterators might be called in a tight loop. - * + * * @return U, or set 'finished' when done */ protected def getNext(): U diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/MatrixRow.scala b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala similarity index 70% rename from mllib/src/main/scala/org/apache/spark/mllib/linalg/MatrixRow.scala rename to core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala index 2608a67bfe260..3abc12681fe9a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/MatrixRow.scala +++ b/core/src/main/scala/org/apache/spark/util/ParentClassLoader.scala @@ -15,12 +15,18 @@ * limitations under the License. */ -package org.apache.spark.mllib.linalg +package org.apache.spark.util /** - * Class that represents a row of a dense matrix - * - * @param i row index (0 indexing used) - * @param data entries of the row + * A class loader which makes findClass accesible to the child */ -case class MatrixRow(val i: Int, val data: Array[Double]) +private[spark] class ParentClassLoader(parent: ClassLoader) extends ClassLoader(parent) { + + override def findClass(name: String) = { + super.findClass(name) + } + + override def loadClass(name: String): Class[_] = { + super.loadClass(name) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index b955612ca7749..08465575309c6 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -27,9 +27,8 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable.ArrayBuffer -import it.unimi.dsi.fastutil.ints.IntOpenHashSet - import org.apache.spark.Logging +import org.apache.spark.util.collection.OpenHashSet /** * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in @@ -207,7 +206,7 @@ private[spark] object SizeEstimator extends Logging { // Estimate the size of a large array by sampling elements without replacement. var size = 0.0 val rand = new Random(42) - val drawn = new IntOpenHashSet(ARRAY_SAMPLE_SIZE) + val drawn = new OpenHashSet[Int](ARRAY_SAMPLE_SIZE) for (i <- 0 until ARRAY_SAMPLE_SIZE) { var index = 0 do { diff --git a/core/src/main/scala/org/apache/spark/util/StatCounter.scala b/core/src/main/scala/org/apache/spark/util/StatCounter.scala index 732748a7ff82b..d80eed455c427 100644 --- a/core/src/main/scala/org/apache/spark/util/StatCounter.scala +++ b/core/src/main/scala/org/apache/spark/util/StatCounter.scala @@ -62,10 +62,10 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable { if (n == 0) { mu = other.mu m2 = other.m2 - n = other.n + n = other.n maxValue = other.maxValue minValue = other.minValue - } else if (other.n != 0) { + } else if (other.n != 0) { val delta = other.mu - mu if (other.n * 10 < n) { mu = mu + (delta * other.n) / (n + other.n) diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index ddbd084ed7f01..8de75ba9a9c92 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -17,48 +17,54 @@ package org.apache.spark.util +import java.util.Set +import java.util.Map.Entry import java.util.concurrent.ConcurrentHashMap -import scala.collection.JavaConversions -import scala.collection.immutable -import scala.collection.mutable.Map +import scala.collection.{JavaConversions, mutable} import org.apache.spark.Logging +private[spark] case class TimeStampedValue[V](value: V, timestamp: Long) + /** * This is a custom implementation of scala.collection.mutable.Map which stores the insertion * timestamp along with each key-value pair. If specified, the timestamp of each pair can be * updated every time it is accessed. Key-value pairs whose timestamp are older than a particular * threshold time can then be removed using the clearOldValues method. This is intended to * be a drop-in replacement of scala.collection.mutable.HashMap. - * @param updateTimeStampOnGet When enabled, the timestamp of a pair will be - * updated when it is accessed + * + * @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed */ -class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false) - extends Map[A, B]() with Logging { - val internalMap = new ConcurrentHashMap[A, (B, Long)]() +private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false) + extends mutable.Map[A, B]() with Logging { + + private val internalMap = new ConcurrentHashMap[A, TimeStampedValue[B]]() def get(key: A): Option[B] = { val value = internalMap.get(key) if (value != null && updateTimeStampOnGet) { - internalMap.replace(key, value, (value._1, currentTime)) + internalMap.replace(key, value, TimeStampedValue(value.value, currentTime)) } - Option(value).map(_._1) + Option(value).map(_.value) } def iterator: Iterator[(A, B)] = { - val jIterator = internalMap.entrySet().iterator() - JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue._1)) + val jIterator = getEntrySet.iterator + JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue.value)) } - override def + [B1 >: B](kv: (A, B1)): Map[A, B1] = { + def getEntrySet: Set[Entry[A, TimeStampedValue[B]]] = internalMap.entrySet + + override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = { val newMap = new TimeStampedHashMap[A, B1] - newMap.internalMap.putAll(this.internalMap) - newMap.internalMap.put(kv._1, (kv._2, currentTime)) + val oldInternalMap = this.internalMap.asInstanceOf[ConcurrentHashMap[A, TimeStampedValue[B1]]] + newMap.internalMap.putAll(oldInternalMap) + kv match { case (a, b) => newMap.internalMap.put(a, TimeStampedValue(b, currentTime)) } newMap } - override def - (key: A): Map[A, B] = { + override def - (key: A): mutable.Map[A, B] = { val newMap = new TimeStampedHashMap[A, B] newMap.internalMap.putAll(this.internalMap) newMap.internalMap.remove(key) @@ -66,17 +72,10 @@ class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false) } override def += (kv: (A, B)): this.type = { - internalMap.put(kv._1, (kv._2, currentTime)) + kv match { case (a, b) => internalMap.put(a, TimeStampedValue(b, currentTime)) } this } - // Should we return previous value directly or as Option ? - def putIfAbsent(key: A, value: B): Option[B] = { - val prev = internalMap.putIfAbsent(key, (value, currentTime)) - if (prev != null) Some(prev._1) else None - } - - override def -= (key: A): this.type = { internalMap.remove(key) this @@ -87,53 +86,65 @@ class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false) } override def apply(key: A): B = { - val value = internalMap.get(key) - if (value == null) throw new NoSuchElementException() - value._1 + get(key).getOrElse { throw new NoSuchElementException() } } - override def filter(p: ((A, B)) => Boolean): Map[A, B] = { - JavaConversions.mapAsScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p) + override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = { + JavaConversions.mapAsScalaConcurrentMap(internalMap) + .map { case (k, TimeStampedValue(v, t)) => (k, v) } + .filter(p) } - override def empty: Map[A, B] = new TimeStampedHashMap[A, B]() + override def empty: mutable.Map[A, B] = new TimeStampedHashMap[A, B]() override def size: Int = internalMap.size override def foreach[U](f: ((A, B)) => U) { - val iterator = internalMap.entrySet().iterator() - while(iterator.hasNext) { - val entry = iterator.next() - val kv = (entry.getKey, entry.getValue._1) + val it = getEntrySet.iterator + while(it.hasNext) { + val entry = it.next() + val kv = (entry.getKey, entry.getValue.value) f(kv) } } - def toMap: immutable.Map[A, B] = iterator.toMap + def putIfAbsent(key: A, value: B): Option[B] = { + val prev = internalMap.putIfAbsent(key, TimeStampedValue(value, currentTime)) + Option(prev).map(_.value) + } + + def putAll(map: Map[A, B]) { + map.foreach { case (k, v) => update(k, v) } + } + + def toMap: Map[A, B] = iterator.toMap - /** - * Removes old key-value pairs that have timestamp earlier than `threshTime`, - * calling the supplied function on each such entry before removing. - */ def clearOldValues(threshTime: Long, f: (A, B) => Unit) { - val iterator = internalMap.entrySet().iterator() - while (iterator.hasNext) { - val entry = iterator.next() - if (entry.getValue._2 < threshTime) { - f(entry.getKey, entry.getValue._1) + val it = getEntrySet.iterator + while (it.hasNext) { + val entry = it.next() + if (entry.getValue.timestamp < threshTime) { + f(entry.getKey, entry.getValue.value) logDebug("Removing key " + entry.getKey) - iterator.remove() + it.remove() } } } - /** - * Removes old key-value pairs that have timestamp earlier than `threshTime` - */ + /** Removes old key-value pairs that have timestamp earlier than `threshTime`. */ def clearOldValues(threshTime: Long) { clearOldValues(threshTime, (_, _) => ()) } - private def currentTime: Long = System.currentTimeMillis() + private def currentTime: Long = System.currentTimeMillis + // For testing + + def getTimeStampedValue(key: A): Option[TimeStampedValue[B]] = { + Option(internalMap.get(key)) + } + + def getTimestamp(key: A): Option[Long] = { + getTimeStampedValue(key).map(_.timestamp) + } } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala index 19bece86b36b4..7cd8f28b12dd6 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala @@ -22,7 +22,7 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConversions import scala.collection.mutable.Set -class TimeStampedHashSet[A] extends Set[A] { +private[spark] class TimeStampedHashSet[A] extends Set[A] { val internalMap = new ConcurrentHashMap[A, Long]() def contains(key: A): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala new file mode 100644 index 0000000000000..b65017d6806c6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.lang.ref.WeakReference +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.mutable + +import org.apache.spark.Logging + +/** + * A wrapper of TimeStampedHashMap that ensures the values are weakly referenced and timestamped. + * + * If the value is garbage collected and the weak reference is null, get() will return a + * non-existent value. These entries are removed from the map periodically (every N inserts), as + * their values are no longer strongly reachable. Further, key-value pairs whose timestamps are + * older than a particular threshold can be removed using the clearOldValues method. + * + * TimeStampedWeakValueHashMap exposes a scala.collection.mutable.Map interface, which allows it + * to be a drop-in replacement for Scala HashMaps. Internally, it uses a Java ConcurrentHashMap, + * so all operations on this HashMap are thread-safe. + * + * @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed. + */ +private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boolean = false) + extends mutable.Map[A, B]() with Logging { + + import TimeStampedWeakValueHashMap._ + + private val internalMap = new TimeStampedHashMap[A, WeakReference[B]](updateTimeStampOnGet) + private val insertCount = new AtomicInteger(0) + + /** Return a map consisting only of entries whose values are still strongly reachable. */ + private def nonNullReferenceMap = internalMap.filter { case (_, ref) => ref.get != null } + + def get(key: A): Option[B] = internalMap.get(key) + + def iterator: Iterator[(A, B)] = nonNullReferenceMap.iterator + + override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = { + val newMap = new TimeStampedWeakValueHashMap[A, B1] + val oldMap = nonNullReferenceMap.asInstanceOf[mutable.Map[A, WeakReference[B1]]] + newMap.internalMap.putAll(oldMap.toMap) + newMap.internalMap += kv + newMap + } + + override def - (key: A): mutable.Map[A, B] = { + val newMap = new TimeStampedWeakValueHashMap[A, B] + newMap.internalMap.putAll(nonNullReferenceMap.toMap) + newMap.internalMap -= key + newMap + } + + override def += (kv: (A, B)): this.type = { + internalMap += kv + if (insertCount.incrementAndGet() % CLEAR_NULL_VALUES_INTERVAL == 0) { + clearNullValues() + } + this + } + + override def -= (key: A): this.type = { + internalMap -= key + this + } + + override def update(key: A, value: B) = this += ((key, value)) + + override def apply(key: A): B = internalMap.apply(key) + + override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = nonNullReferenceMap.filter(p) + + override def empty: mutable.Map[A, B] = new TimeStampedWeakValueHashMap[A, B]() + + override def size: Int = internalMap.size + + override def foreach[U](f: ((A, B)) => U) = nonNullReferenceMap.foreach(f) + + def putIfAbsent(key: A, value: B): Option[B] = internalMap.putIfAbsent(key, value) + + def toMap: Map[A, B] = iterator.toMap + + /** Remove old key-value pairs with timestamps earlier than `threshTime`. */ + def clearOldValues(threshTime: Long) = internalMap.clearOldValues(threshTime) + + /** Remove entries with values that are no longer strongly reachable. */ + def clearNullValues() { + val it = internalMap.getEntrySet.iterator + while (it.hasNext) { + val entry = it.next() + if (entry.getValue.value.get == null) { + logDebug("Removing key " + entry.getKey + " because it is no longer strongly reachable.") + it.remove() + } + } + } + + // For testing + + def getTimestamp(key: A): Option[Long] = { + internalMap.getTimeStampedValue(key).map(_.timestamp) + } + + def getReference(key: A): Option[WeakReference[B]] = { + internalMap.getTimeStampedValue(key).map(_.value) + } +} + +/** + * Helper methods for converting to and from WeakReferences. + */ +private object TimeStampedWeakValueHashMap { + + // Number of inserts after which entries with null references are removed + val CLEAR_NULL_VALUES_INTERVAL = 100 + + /* Implicit conversion methods to WeakReferences. */ + + implicit def toWeakReference[V](v: V): WeakReference[V] = new WeakReference[V](v) + + implicit def toWeakReferenceTuple[K, V](kv: (K, V)): (K, WeakReference[V]) = { + kv match { case (k, v) => (k, toWeakReference(v)) } + } + + implicit def toWeakReferenceFunction[K, V, R](p: ((K, V)) => R): ((K, WeakReference[V])) => R = { + (kv: (K, WeakReference[V])) => p(kv) + } + + /* Implicit conversion methods from WeakReferences. */ + + implicit def fromWeakReference[V](ref: WeakReference[V]): V = ref.get + + implicit def fromWeakReferenceOption[V](v: Option[WeakReference[V]]): Option[V] = { + v match { + case Some(ref) => Option(fromWeakReference(ref)) + case None => None + } + } + + implicit def fromWeakReferenceTuple[K, V](kv: (K, WeakReference[V])): (K, V) = { + kv match { case (k, v) => (k, fromWeakReference(v)) } + } + + implicit def fromWeakReferenceIterator[K, V]( + it: Iterator[(K, WeakReference[V])]): Iterator[(K, V)] = { + it.map(fromWeakReferenceTuple) + } + + implicit def fromWeakReferenceMap[K, V]( + map: mutable.Map[K, WeakReference[V]]) : mutable.Map[K, V] = { + mutable.Map(map.mapValues(fromWeakReference).toSeq: _*) + } +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 62ee704d580c2..a3af4e7b91692 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -33,16 +33,20 @@ import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} import org.json4s._ +import tachyon.client.{TachyonFile,TachyonFS} import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} + /** * Various utility methods used by Spark. */ private[spark] object Utils extends Logging { + val osName = System.getProperty("os.name") + /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -112,6 +116,21 @@ private[spark] object Utils extends Logging { } } + /** + * Get the ClassLoader which loaded Spark. + */ + def getSparkClassLoader = getClass.getClassLoader + + /** + * Get the Context ClassLoader on this thread or, if not present, the ClassLoader that + * loaded Spark. + * + * This should be used whenever passing a ClassLoader to Class.ForName or finding the currently + * active loader when setting up ClassLoader delegation chains. + */ + def getContextOrSparkClassLoader = + Option(Thread.currentThread().getContextClassLoader).getOrElse(getSparkClassLoader) + /** * Primitive often used when writing {@link java.nio.ByteBuffer} to {@link java.io.DataOutput}. */ @@ -150,6 +169,7 @@ private[spark] object Utils extends Logging { } private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() + private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() // Register the path to be deleted via shutdown hook def registerShutdownDeleteDir(file: File) { @@ -159,6 +179,14 @@ private[spark] object Utils extends Logging { } } + // Register the tachyon path to be deleted via shutdown hook + def registerShutdownDeleteDir(tachyonfile: TachyonFile) { + val absolutePath = tachyonfile.getPath() + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths += absolutePath + } + } + // Is the path already registered to be deleted via a shutdown hook ? def hasShutdownDeleteDir(file: File): Boolean = { val absolutePath = file.getAbsolutePath() @@ -167,6 +195,14 @@ private[spark] object Utils extends Logging { } } + // Is the path already registered to be deleted via a shutdown hook ? + def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = { + val absolutePath = file.getPath() + shutdownDeletePaths.synchronized { + shutdownDeletePaths.contains(absolutePath) + } + } + // Note: if file is child of some registered path, while not equal to it, then return true; // else false. This is to ensure that two shutdown hooks do not try to delete each others // paths - resulting in IOException and incomplete cleanup. @@ -183,6 +219,22 @@ private[spark] object Utils extends Logging { retval } + // Note: if file is child of some registered path, while not equal to it, then return true; + // else false. This is to ensure that two shutdown hooks do not try to delete each others + // paths - resulting in Exception and incomplete cleanup. + def hasRootAsShutdownDeleteDir(file: TachyonFile): Boolean = { + val absolutePath = file.getPath() + val retval = shutdownDeletePaths.synchronized { + shutdownDeletePaths.find { path => + !absolutePath.equals(path) && absolutePath.startsWith(path) + }.isDefined + } + if (retval) { + logInfo("path = " + file + ", already present as root for deletion.") + } + retval + } + /** Create a temporary directory inside the given parent directory */ def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = { var attempts = 0 @@ -461,10 +513,10 @@ private[spark] object Utils extends Logging { private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]() def parseHostPort(hostPort: String): (String, Int) = { - { - // Check cache first. - val cached = hostPortParseResults.get(hostPort) - if (cached != null) return cached + // Check cache first. + val cached = hostPortParseResults.get(hostPort) + if (cached != null) { + return cached } val indx: Int = hostPort.lastIndexOf(':') @@ -521,9 +573,10 @@ private[spark] object Utils extends Logging { /** * Delete a file or directory and its contents recursively. + * Don't follow directories if they are symlinks. */ def deleteRecursively(file: File) { - if (file.isDirectory) { + if ((file.isDirectory) && !isSymlink(file)) { for (child <- listFilesSafely(file)) { deleteRecursively(child) } @@ -536,6 +589,49 @@ private[spark] object Utils extends Logging { } } + /** + * Delete a file or directory and its contents recursively. + */ + def deleteRecursively(dir: TachyonFile, client: TachyonFS) { + if (!client.delete(dir.getPath(), true)) { + throw new IOException("Failed to delete the tachyon dir: " + dir) + } + } + + /** + * Check to see if file is a symbolic link. + */ + def isSymlink(file: File): Boolean = { + if (file == null) throw new NullPointerException("File must not be null") + if (osName.startsWith("Windows")) return false + val fileInCanonicalDir = if (file.getParent() == null) { + file + } else { + new File(file.getParentFile().getCanonicalFile(), file.getName()) + } + + if (fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile())) { + return false + } else { + return true + } + } + + /** + * Finds all the files in a directory whose last modified time is older than cutoff seconds. + * @param dir must be the path to a directory, or IllegalArgumentException is thrown + * @param cutoff measured in seconds. Files older than this are returned. + */ + def findOldFiles(dir: File, cutoff: Long): Seq[File] = { + val currentTimeMillis = System.currentTimeMillis + if (dir.isDirectory) { + val files = listFilesSafely(dir) + files.filter { file => file.lastModified < (currentTimeMillis - cutoff * 1000) } + } else { + throw new IllegalArgumentException(dir + " is not a directory!") + } + } + /** * Convert a Java memory parameter passed to -Xmx (such as 300m or 1g) to a number of megabytes. */ @@ -898,6 +994,26 @@ private[spark] object Utils extends Logging { count } + /** + * Creates a symlink. Note jdk1.7 has Files.createSymbolicLink but not used here + * for jdk1.6 support. Supports windows by doing copy, everything else uses "ln -sf". + * @param src absolute path to the source + * @param dst relative path for the destination + */ + def symlink(src: File, dst: File) { + if (!src.isAbsolute()) { + throw new IOException("Source must be absolute") + } + if (dst.isAbsolute()) { + throw new IOException("Destination must be relative") + } + val linkCmd = if (osName.startsWith("Windows")) "copy" else "ln -sf" + import scala.sys.process._ + (linkCmd + " " + src.getAbsolutePath() + " " + dst.getPath()) lines_! ProcessLogger(line => + (logInfo(line))) + } + + /** Return the class name of the given object, removing all dollar signs */ def getFormattedClassName(obj: AnyRef) = { obj.getClass.getSimpleName.replace("$", "") @@ -920,4 +1036,11 @@ private[spark] object Utils extends Logging { def getHadoopFileSystem(path: URI): FileSystem = { FileSystem.get(path, SparkHadoopUtil.get.newConfiguration()) } + + /** + * Return a Hadoop FileSystem with the scheme encoded in the given path. + */ + def getHadoopFileSystem(path: String): FileSystem = { + getHadoopFileSystem(new URI(path)) + } } diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala index dc4b8f253f259..1a647fa1c9d84 100644 --- a/core/src/main/scala/org/apache/spark/util/Vector.scala +++ b/core/src/main/scala/org/apache/spark/util/Vector.scala @@ -21,6 +21,7 @@ import scala.util.Random import org.apache.spark.util.random.XORShiftRandom +@deprecated("Use Vectors.dense from Spark's mllib.linalg package instead.", "1.0.0") class Vector(val elements: Array[Double]) extends Serializable { def length = elements.length @@ -135,7 +136,7 @@ object Vector { def ones(length: Int) = Vector(length, _ => 1) /** - * Creates this [[org.apache.spark.util.Vector]] of given length containing random numbers + * Creates this [[org.apache.spark.util.Vector]] of given length containing random numbers * between 0.0 and 1.0. Optional scala.util.Random number generator can be provided. */ def random(length: Int, random: Random = new XORShiftRandom()) = diff --git a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala index b8c852b4ff5c7..ad38250ad339f 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala @@ -19,7 +19,12 @@ package org.apache.spark.util.collection import java.util.{Arrays, Comparator} +import com.google.common.hash.Hashing + +import org.apache.spark.annotation.DeveloperApi + /** + * :: DeveloperApi :: * A simple open hash table optimized for the append-only use case, where keys * are never removed, but the value for each key may be changed. * @@ -29,9 +34,9 @@ import java.util.{Arrays, Comparator} * * TODO: Cache the hash values of each key? java.util.HashMap does that. */ -private[spark] -class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, - V)] with Serializable { +@DeveloperApi +class AppendOnlyMap[K, V](initialCapacity: Int = 64) + extends Iterable[(K, V)] with Serializable { require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") require(initialCapacity >= 1, "Invalid initial capacity") @@ -196,11 +201,8 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, /** * Re-hash a value to deal better with hash functions that don't differ in the lower bits. - * We use the Murmur Hash 3 finalization step that's also used in fastutil. */ - private def rehash(h: Int): Int = { - it.unimi.dsi.fastutil.HashCommon.murmurHash3(h) - } + private def rehash(h: Int): Int = Hashing.murmur3_32().hashInt(h).asInt() /** Double the table's size and re-hash everything */ protected def growTable() { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index caa06d5b445b4..d615767284c0b 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -17,20 +17,21 @@ package org.apache.spark.util.collection -import java.io._ +import java.io.{InputStream, BufferedInputStream, FileInputStream, File, Serializable, EOFException} import java.util.Comparator import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import com.google.common.io.ByteStreams -import it.unimi.dsi.fastutil.io.FastBufferedInputStream import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.Serializer import org.apache.spark.storage.{BlockId, BlockManager} /** + * :: DeveloperApi :: * An append-only map that spills sorted content to disk when there is insufficient space for it * to grow. * @@ -55,8 +56,8 @@ import org.apache.spark.storage.{BlockId, BlockManager} * `spark.shuffle.safetyFraction` specifies an additional margin of safety as a fraction of * this threshold, in case map size estimation is not sufficiently accurate. */ - -private[spark] class ExternalAppendOnlyMap[K, V, C]( +@DeveloperApi +class ExternalAppendOnlyMap[K, V, C]( createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C, @@ -348,7 +349,7 @@ private[spark] class ExternalAppendOnlyMap[K, V, C]( private class DiskMapIterator(file: File, blockId: BlockId, batchSizes: ArrayBuffer[Long]) extends Iterator[(K, C)] { private val fileStream = new FileInputStream(file) - private val bufferedStream = new FastBufferedInputStream(fileStream, fileBufferSize) + private val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize) // An intermediate stream that reads from exactly one batch // This guards against pre-fetching and other arbitrary behavior of higher level streams diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala index c26f23d50024a..b8de4ff9aa494 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala @@ -19,15 +19,19 @@ package org.apache.spark.util.collection import scala.reflect.ClassTag +import org.apache.spark.annotation.DeveloperApi + /** + * :: DeveloperApi :: * A fast hash map implementation for nullable keys. This hash map supports insertions and updates, * but not deletions. This map is about 5X faster than java.util.HashMap, while using much less * space overhead. * * Under the hood, it uses our OpenHashSet implementation. */ +@DeveloperApi private[spark] -class OpenHashMap[K >: Null : ClassTag, @specialized(Long, Int, Double) V: ClassTag]( +class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag]( initialCapacity: Int) extends Iterable[(K, V)] with Serializable { diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 148c12e64d2ce..19af4f8cbe428 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -18,6 +18,7 @@ package org.apache.spark.util.collection import scala.reflect._ +import com.google.common.hash.Hashing /** * A simple, fast hash set optimized for non-null insertion-only use case, where keys are never @@ -256,9 +257,8 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag]( /** * Re-hash a value to deal better with hash functions that don't differ in the lower bits. - * We use the Murmur Hash 3 finalization step that's also used in fastutil. */ - private def hashcode(h: Int): Int = it.unimi.dsi.fastutil.HashCommon.murmurHash3(h) + private def hashcode(h: Int): Int = Hashing.murmur3_32().hashInt(h).asInt() private def nextPowerOf2(n: Int): Int = { val highBit = Integer.highestOneBit(n) diff --git a/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala b/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala index 98569143ee1e3..70f3dd62b9b19 100644 --- a/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala @@ -17,9 +17,13 @@ package org.apache.spark.util.random +import org.apache.spark.annotation.DeveloperApi + /** + * :: DeveloperApi :: * A class with pseudorandom behavior. */ +@DeveloperApi trait Pseudorandom { /** Set random seed. */ def setSeed(seed: Long) diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index 0f1fca4813ba9..37a6b04f5200f 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -22,7 +22,10 @@ import java.util.Random import cern.jet.random.Poisson import cern.jet.random.engine.DRand +import org.apache.spark.annotation.DeveloperApi + /** + * :: DeveloperApi :: * A pseudorandom sampler. It is possible to change the sampled item type. For example, we might * want to add weights for stratified sampling or importance sampling. Should only use * transformations that are tied to the sampler and cannot be applied after sampling. @@ -30,6 +33,7 @@ import cern.jet.random.engine.DRand * @tparam T item type * @tparam U sampled item type */ +@DeveloperApi trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable { /** take a random sample */ @@ -40,6 +44,7 @@ trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable } /** + * :: DeveloperApi :: * A sampler based on Bernoulli trials. * * @param lb lower bound of the acceptance range @@ -47,6 +52,7 @@ trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable * @param complement whether to use the complement of the range specified, default to false * @tparam T item type */ +@DeveloperApi class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false) (implicit random: Random = new XORShiftRandom) extends RandomSampler[T, T] { @@ -67,11 +73,13 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false) } /** + * :: DeveloperApi :: * A sampler based on values drawn from Poisson distribution. * * @param poisson a Poisson random number generator * @tparam T item type */ +@DeveloperApi class PoissonSampler[T](mean: Double) (implicit var poisson: Poisson = new Poisson(mean, new DRand)) extends RandomSampler[T, T] { diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala index 8a4cdea2fa7b1..7f220383f9f8b 100644 --- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala @@ -25,28 +25,28 @@ import scala.util.hashing.MurmurHash3 import org.apache.spark.util.Utils.timeIt /** - * This class implements a XORShift random number generator algorithm + * This class implements a XORShift random number generator algorithm * Source: * Marsaglia, G. (2003). Xorshift RNGs. Journal of Statistical Software, Vol. 8, Issue 14. * @see Paper * This implementation is approximately 3.5 times faster than * {@link java.util.Random java.util.Random}, partly because of the algorithm, but also due - * to renouncing thread safety. JDK's implementation uses an AtomicLong seed, this class + * to renouncing thread safety. JDK's implementation uses an AtomicLong seed, this class * uses a regular Long. We can forgo thread safety since we use a new instance of the RNG * for each thread. */ private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) { - + def this() = this(System.nanoTime) private var seed = XORShiftRandom.hashSeed(init) // we need to just override next - this will be called by nextInt, nextDouble, // nextGaussian, nextLong, etc. - override protected def next(bits: Int): Int = { + override protected def next(bits: Int): Int = { var nextSeed = seed ^ (seed << 21) nextSeed ^= (nextSeed >>> 35) - nextSeed ^= (nextSeed << 4) + nextSeed ^= (nextSeed << 4) seed = nextSeed (nextSeed & ((1L << bits) -1)).asInstanceOf[Int] } @@ -89,7 +89,7 @@ private[spark] object XORShiftRandom { val million = 1e6.toInt val javaRand = new JavaRandom(seed) val xorRand = new XORShiftRandom(seed) - + // this is just to warm up the JIT - we're not timing anything timeIt(1e6.toInt) { javaRand.nextInt() @@ -97,9 +97,9 @@ private[spark] object XORShiftRandom { } val iters = timeIt(numIters)(_) - + /* Return results as a map instead of just printing to screen - in case the user wants to do something with them */ + in case the user wants to do something with them */ Map("javaTime" -> iters {javaRand.nextInt()}, "xorTime" -> iters {xorRand.nextInt()}) diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index c6b65c7348ae0..8d2e9f1846343 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -17,13 +17,14 @@ package org.apache.spark; -import java.io.File; -import java.io.IOException; -import java.io.Serializable; +import java.io.*; +import java.lang.StringBuilder; import java.util.*; import scala.Tuple2; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; import com.google.common.base.Optional; import com.google.common.base.Charsets; import com.google.common.io.Files; @@ -181,6 +182,14 @@ public void call(String s) { Assert.assertEquals(2, foreachCalls); } + @Test + public void toLocalIterator() { + List correct = Arrays.asList(1, 2, 3, 4); + JavaRDD rdd = sc.parallelize(correct); + List result = Lists.newArrayList(rdd.toLocalIterator()); + Assert.assertTrue(correct.equals(result)); + } + @SuppressWarnings("unchecked") @Test public void lookup() { @@ -190,7 +199,7 @@ public void lookup() { new Tuple2("Oranges", "Citrus") )); Assert.assertEquals(2, categories.lookup("Oranges").size()); - Assert.assertEquals(2, categories.groupByKey().lookup("Oranges").get(0).size()); + Assert.assertEquals(2, Iterables.size(categories.groupByKey().lookup("Oranges").get(0))); } @Test @@ -202,15 +211,15 @@ public Boolean call(Integer x) { return x % 2 == 0; } }; - JavaPairRDD> oddsAndEvens = rdd.groupBy(isOdd); + JavaPairRDD> oddsAndEvens = rdd.groupBy(isOdd); Assert.assertEquals(2, oddsAndEvens.count()); - Assert.assertEquals(2, oddsAndEvens.lookup(true).get(0).size()); // Evens - Assert.assertEquals(5, oddsAndEvens.lookup(false).get(0).size()); // Odds + Assert.assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0))); // Evens + Assert.assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds oddsAndEvens = rdd.groupBy(isOdd, 1); Assert.assertEquals(2, oddsAndEvens.count()); - Assert.assertEquals(2, oddsAndEvens.lookup(true).get(0).size()); // Evens - Assert.assertEquals(5, oddsAndEvens.lookup(false).get(0).size()); // Odds + Assert.assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0))); // Evens + Assert.assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds } @SuppressWarnings("unchecked") @@ -225,9 +234,9 @@ public void cogroup() { new Tuple2("Oranges", 2), new Tuple2("Apples", 3) )); - JavaPairRDD, List>> cogrouped = categories.cogroup(prices); - Assert.assertEquals("[Fruit, Citrus]", cogrouped.lookup("Oranges").get(0)._1().toString()); - Assert.assertEquals("[2]", cogrouped.lookup("Oranges").get(0)._2().toString()); + JavaPairRDD, Iterable>> cogrouped = categories.cogroup(prices); + Assert.assertEquals("[Fruit, Citrus]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._1())); + Assert.assertEquals("[2]", Iterables.toString(cogrouped.lookup("Oranges").get(0)._2())); cogrouped.collect(); } @@ -599,6 +608,32 @@ public void textFiles() throws IOException { Assert.assertEquals(expected, readRDD.collect()); } + @Test + public void wholeTextFiles() throws IOException { + byte[] content1 = "spark is easy to use.\n".getBytes(); + byte[] content2 = "spark is also easy to use.\n".getBytes(); + + File tempDir = Files.createTempDir(); + String tempDirName = tempDir.getAbsolutePath(); + DataOutputStream ds = new DataOutputStream(new FileOutputStream(tempDirName + "/part-00000")); + ds.write(content1); + ds.close(); + ds = new DataOutputStream(new FileOutputStream(tempDirName + "/part-00001")); + ds.write(content2); + ds.close(); + + HashMap container = new HashMap(); + container.put(tempDirName+"/part-00000", new Text(content1).toString()); + container.put(tempDirName+"/part-00001", new Text(content2).toString()); + + JavaPairRDD readRDD = sc.wholeTextFiles(tempDirName, 3); + List> result = readRDD.collect(); + + for (Tuple2 res : result) { + Assert.assertEquals(res._2(), container.get(res._1())); + } + } + @Test public void textFilesCompressed() throws IOException { File tempDir = Files.createTempDir(); diff --git a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala index d2e303d81c4c8..c645e4cbe8132 100644 --- a/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/AkkaUtilsSuite.scala @@ -37,7 +37,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { val securityManager = new SecurityManager(conf); val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, securityManager = securityManager) System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext System.setProperty("spark.hostPort", hostname + ":" + boundPort) @@ -54,14 +54,14 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { assert(securityManagerBad.isAuthenticationEnabled() === true) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf, securityManager = securityManagerBad) - val slaveTracker = new MapOutputTracker(conf) + val slaveTracker = new MapOutputTrackerWorker(conf) val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") val timeout = AkkaUtils.lookupTimeout(conf) - intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + intercept[akka.actor.ActorNotFound] { + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) } actorSystem.shutdown() @@ -75,7 +75,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { val securityManager = new SecurityManager(conf); val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, securityManager = securityManager) System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext System.setProperty("spark.hostPort", hostname + ":" + boundPort) @@ -91,9 +91,9 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { badconf.set("spark.authenticate.secret", "good") val securityManagerBad = new SecurityManager(badconf); - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = badconf, securityManager = securityManagerBad) - val slaveTracker = new MapOutputTracker(conf) + val slaveTracker = new MapOutputTrackerWorker(conf) val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") val timeout = AkkaUtils.lookupTimeout(conf) @@ -127,7 +127,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { val securityManager = new SecurityManager(conf); val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, securityManager = securityManager) System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext System.setProperty("spark.hostPort", hostname + ":" + boundPort) @@ -147,7 +147,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = goodconf, securityManager = securityManagerGood) - val slaveTracker = new MapOutputTracker(conf) + val slaveTracker = new MapOutputTrackerWorker(conf) val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") val timeout = AkkaUtils.lookupTimeout(conf) @@ -180,7 +180,7 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { val securityManager = new SecurityManager(conf); val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, securityManager = securityManager) System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext System.setProperty("spark.hostPort", hostname + ":" + boundPort) @@ -200,12 +200,12 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = badconf, securityManager = securityManagerBad) - val slaveTracker = new MapOutputTracker(conf) + val slaveTracker = new MapOutputTrackerWorker(conf) val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") val timeout = AkkaUtils.lookupTimeout(conf) - intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + intercept[akka.actor.ActorNotFound] { + slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) } actorSystem.shutdown() diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index 96ba3929c1685..c9936256a5b95 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -19,68 +19,297 @@ package org.apache.spark import org.scalatest.FunSuite -class BroadcastSuite extends FunSuite with LocalSparkContext { +import org.apache.spark.storage._ +import org.apache.spark.broadcast.{Broadcast, HttpBroadcast} +import org.apache.spark.storage.BroadcastBlockId +class BroadcastSuite extends FunSuite with LocalSparkContext { - override def afterEach() { - super.afterEach() - System.clearProperty("spark.broadcast.factory") - } + private val httpConf = broadcastConf("HttpBroadcastFactory") + private val torrentConf = broadcastConf("TorrentBroadcastFactory") test("Using HttpBroadcast locally") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") - sc = new SparkContext("local", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === Set((1, 10), (2, 10))) + sc = new SparkContext("local", "test", httpConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === Set((1, 10), (2, 10))) } test("Accessing HttpBroadcast variables from multiple threads") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") - sc = new SparkContext("local[10]", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet) + sc = new SparkContext("local[10]", "test", httpConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet) } test("Accessing HttpBroadcast variables in a local cluster") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") val numSlaves = 4 - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", httpConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) } test("Using TorrentBroadcast locally") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") - sc = new SparkContext("local", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === Set((1, 10), (2, 10))) + sc = new SparkContext("local", "test", torrentConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === Set((1, 10), (2, 10))) } test("Accessing TorrentBroadcast variables from multiple threads") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") - sc = new SparkContext("local[10]", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet) + sc = new SparkContext("local[10]", "test", torrentConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet) } test("Accessing TorrentBroadcast variables in a local cluster") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") val numSlaves = 4 - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", torrentConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + } + + test("Unpersisting HttpBroadcast on executors only in local mode") { + testUnpersistHttpBroadcast(distributed = false, removeFromDriver = false) + } + + test("Unpersisting HttpBroadcast on executors and driver in local mode") { + testUnpersistHttpBroadcast(distributed = false, removeFromDriver = true) + } + + test("Unpersisting HttpBroadcast on executors only in distributed mode") { + testUnpersistHttpBroadcast(distributed = true, removeFromDriver = false) + } + + test("Unpersisting HttpBroadcast on executors and driver in distributed mode") { + testUnpersistHttpBroadcast(distributed = true, removeFromDriver = true) + } + + test("Unpersisting TorrentBroadcast on executors only in local mode") { + testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = false) + } + + test("Unpersisting TorrentBroadcast on executors and driver in local mode") { + testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = true) + } + + test("Unpersisting TorrentBroadcast on executors only in distributed mode") { + testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = false) + } + + test("Unpersisting TorrentBroadcast on executors and driver in distributed mode") { + testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = true) + } + /** + * Verify the persistence of state associated with an HttpBroadcast in either local mode or + * local-cluster mode (when distributed = true). + * + * This test creates a broadcast variable, uses it on all executors, and then unpersists it. + * In between each step, this test verifies that the broadcast blocks and the broadcast file + * are present only on the expected nodes. + */ + private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) { + val numSlaves = if (distributed) 2 else 0 + + def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id)) + + // Verify that the broadcast file is created, and blocks are persisted only on the driver + def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { + assert(blockIds.size === 1) + val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true) + assert(statuses.size === 1) + statuses.head match { case (bm, status) => + assert(bm.executorId === "", "Block should only be on the driver") + assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) + assert(status.memSize > 0, "Block should be in memory store on the driver") + assert(status.diskSize === 0, "Block should not be in disk store on the driver") + } + if (distributed) { + // this file is only generated in distributed mode + assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!") + } + } + + // Verify that blocks are persisted in both the executors and the driver + def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { + assert(blockIds.size === 1) + val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true) + assert(statuses.size === numSlaves + 1) + statuses.foreach { case (_, status) => + assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) + assert(status.memSize > 0, "Block should be in memory store") + assert(status.diskSize === 0, "Block should not be in disk store") + } + } + + // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver + // is true. In the latter case, also verify that the broadcast file is deleted on the driver. + def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { + assert(blockIds.size === 1) + val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true) + val expectedNumBlocks = if (removeFromDriver) 0 else 1 + val possiblyNot = if (removeFromDriver) "" else " not" + assert(statuses.size === expectedNumBlocks, + "Block should%s be unpersisted on the driver".format(possiblyNot)) + if (distributed && removeFromDriver) { + // this file is only generated in distributed mode + assert(!HttpBroadcast.getFile(blockIds.head.broadcastId).exists, + "Broadcast file should%s be deleted".format(possiblyNot)) + } + } + + testUnpersistBroadcast(distributed, numSlaves, httpConf, getBlockIds, afterCreation, + afterUsingBroadcast, afterUnpersist, removeFromDriver) + } + + /** + * Verify the persistence of state associated with an TorrentBroadcast in a local-cluster. + * + * This test creates a broadcast variable, uses it on all executors, and then unpersists it. + * In between each step, this test verifies that the broadcast blocks are present only on the + * expected nodes. + */ + private def testUnpersistTorrentBroadcast(distributed: Boolean, removeFromDriver: Boolean) { + val numSlaves = if (distributed) 2 else 0 + + def getBlockIds(id: Long) = { + val broadcastBlockId = BroadcastBlockId(id) + val metaBlockId = BroadcastBlockId(id, "meta") + // Assume broadcast value is small enough to fit into 1 piece + val pieceBlockId = BroadcastBlockId(id, "piece0") + if (distributed) { + // the metadata and piece blocks are generated only in distributed mode + Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId) + } else { + Seq[BroadcastBlockId](broadcastBlockId) + } + } + + // Verify that blocks are persisted only on the driver + def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { + blockIds.foreach { blockId => + val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true) + assert(statuses.size === 1) + statuses.head match { case (bm, status) => + assert(bm.executorId === "", "Block should only be on the driver") + assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) + assert(status.memSize > 0, "Block should be in memory store on the driver") + assert(status.diskSize === 0, "Block should not be in disk store on the driver") + } + } + } + + // Verify that blocks are persisted in both the executors and the driver + def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { + blockIds.foreach { blockId => + val statuses = bmm.getBlockStatus(blockId, askSlaves = true) + if (blockId.field == "meta") { + // Meta data is only on the driver + assert(statuses.size === 1) + statuses.head match { case (bm, _) => assert(bm.executorId === "") } + } else { + // Other blocks are on both the executors and the driver + assert(statuses.size === numSlaves + 1, + blockId + " has " + statuses.size + " statuses: " + statuses.mkString(",")) + statuses.foreach { case (_, status) => + assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) + assert(status.memSize > 0, "Block should be in memory store") + assert(status.diskSize === 0, "Block should not be in disk store") + } + } + } + } + + // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver + // is true. + def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { + val expectedNumBlocks = if (removeFromDriver) 0 else 1 + val possiblyNot = if (removeFromDriver) "" else " not" + blockIds.foreach { blockId => + val statuses = bmm.getBlockStatus(blockId, askSlaves = true) + assert(statuses.size === expectedNumBlocks, + "Block should%s be unpersisted on the driver".format(possiblyNot)) + } + } + + testUnpersistBroadcast(distributed, numSlaves, torrentConf, getBlockIds, afterCreation, + afterUsingBroadcast, afterUnpersist, removeFromDriver) + } + + /** + * This test runs in 4 steps: + * + * 1) Create broadcast variable, and verify that all state is persisted on the driver. + * 2) Use the broadcast variable on all executors, and verify that all state is persisted + * on both the driver and the executors. + * 3) Unpersist the broadcast, and verify that all state is removed where they should be. + * 4) [Optional] If removeFromDriver is false, we verify that the broadcast is re-usable. + */ + private def testUnpersistBroadcast( + distributed: Boolean, + numSlaves: Int, // used only when distributed = true + broadcastConf: SparkConf, + getBlockIds: Long => Seq[BroadcastBlockId], + afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, + afterUsingBroadcast: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, + afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, + removeFromDriver: Boolean) { + + sc = if (distributed) { + new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf) + } else { + new SparkContext("local", "test", broadcastConf) + } + val blockManagerMaster = sc.env.blockManager.master + val list = List[Int](1, 2, 3, 4) + + // Create broadcast variable + val broadcast = sc.broadcast(list) + val blocks = getBlockIds(broadcast.id) + afterCreation(blocks, blockManagerMaster) + + // Use broadcast variable on all executors + val partitions = 10 + assert(partitions > numSlaves) + val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet) + afterUsingBroadcast(blocks, blockManagerMaster) + + // Unpersist broadcast + if (removeFromDriver) { + broadcast.destroy(blocking = true) + } else { + broadcast.unpersist(blocking = true) + } + afterUnpersist(blocks, blockManagerMaster) + + // If the broadcast is removed from driver, all subsequent uses of the broadcast variable + // should throw SparkExceptions. Otherwise, the result should be the same as before. + if (removeFromDriver) { + // Using this variable on the executors crashes them, which hangs the test. + // Instead, crash the driver by directly accessing the broadcast value. + intercept[SparkException] { broadcast.value } + intercept[SparkException] { broadcast.unpersist() } + intercept[SparkException] { broadcast.destroy(blocking = true) } + } else { + val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet) + } } + /** Helper method to create a SparkConf that uses the given broadcast factory. */ + private def broadcastConf(factoryName: String): SparkConf = { + val conf = new SparkConf + conf.set("spark.broadcast.factory", "org.apache.spark.broadcast.%s".format(factoryName)) + conf + } } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala new file mode 100644 index 0000000000000..e50981cf6fb20 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -0,0 +1,415 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.lang.ref.WeakReference + +import scala.collection.mutable.{HashSet, SynchronizedSet} +import scala.util.Random + +import org.scalatest.{BeforeAndAfter, FunSuite} +import org.scalatest.concurrent.Eventually +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.{BlockId, BroadcastBlockId, RDDBlockId, ShuffleBlockId} + +class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { + + implicit val defaultTimeout = timeout(10000 millis) + val conf = new SparkConf() + .setMaster("local[2]") + .setAppName("ContextCleanerSuite") + .set("spark.cleaner.referenceTracking.blocking", "true") + + before { + sc = new SparkContext(conf) + } + + after { + if (sc != null) { + sc.stop() + sc = null + } + } + + + test("cleanup RDD") { + val rdd = newRDD.persist() + val collected = rdd.collect().toList + val tester = new CleanerTester(sc, rddIds = Seq(rdd.id)) + + // Explicit cleanup + cleaner.doCleanupRDD(rdd.id, blocking = true) + tester.assertCleanup() + + // Verify that RDDs can be re-executed after cleaning up + assert(rdd.collect().toList === collected) + } + + test("cleanup shuffle") { + val (rdd, shuffleDeps) = newRDDWithShuffleDependencies + val collected = rdd.collect().toList + val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId)) + + // Explicit cleanup + shuffleDeps.foreach(s => cleaner.doCleanupShuffle(s.shuffleId, blocking = true)) + tester.assertCleanup() + + // Verify that shuffles can be re-executed after cleaning up + assert(rdd.collect().toList === collected) + } + + test("cleanup broadcast") { + val broadcast = newBroadcast + val tester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) + + // Explicit cleanup + cleaner.doCleanupBroadcast(broadcast.id, blocking = true) + tester.assertCleanup() + } + + test("automatically cleanup RDD") { + var rdd = newRDD.persist() + rdd.count() + + // Test that GC does not cause RDD cleanup due to a strong reference + val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) + runGC() + intercept[Exception] { + preGCTester.assertCleanup()(timeout(1000 millis)) + } + + // Test that GC causes RDD cleanup after dereferencing the RDD + val postGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) + rdd = null // Make RDD out of scope + runGC() + postGCTester.assertCleanup() + } + + test("automatically cleanup shuffle") { + var rdd = newShuffleRDD + rdd.count() + + // Test that GC does not cause shuffle cleanup due to a strong reference + val preGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) + runGC() + intercept[Exception] { + preGCTester.assertCleanup()(timeout(1000 millis)) + } + + // Test that GC causes shuffle cleanup after dereferencing the RDD + val postGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) + rdd = null // Make RDD out of scope, so that corresponding shuffle goes out of scope + runGC() + postGCTester.assertCleanup() + } + + test("automatically cleanup broadcast") { + var broadcast = newBroadcast + + // Test that GC does not cause broadcast cleanup due to a strong reference + val preGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) + runGC() + intercept[Exception] { + preGCTester.assertCleanup()(timeout(1000 millis)) + } + + // Test that GC causes broadcast cleanup after dereferencing the broadcast variable + val postGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) + broadcast = null // Make broadcast variable out of scope + runGC() + postGCTester.assertCleanup() + } + + test("automatically cleanup RDD + shuffle + broadcast") { + val numRdds = 100 + val numBroadcasts = 4 // Broadcasts are more costly + val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer + val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer + val rddIds = sc.persistentRdds.keys.toSeq + val shuffleIds = 0 until sc.newShuffleId + val broadcastIds = 0L until numBroadcasts + + val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + runGC() + intercept[Exception] { + preGCTester.assertCleanup()(timeout(1000 millis)) + } + + // Test that GC triggers the cleanup of all variables after the dereferencing them + val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + broadcastBuffer.clear() + rddBuffer.clear() + runGC() + postGCTester.assertCleanup() + } + + test("automatically cleanup RDD + shuffle + broadcast in distributed mode") { + sc.stop() + + val conf2 = new SparkConf() + .setMaster("local-cluster[2, 1, 512]") + .setAppName("ContextCleanerSuite") + .set("spark.cleaner.referenceTracking.blocking", "true") + sc = new SparkContext(conf2) + + val numRdds = 10 + val numBroadcasts = 4 // Broadcasts are more costly + val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer + val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer + val rddIds = sc.persistentRdds.keys.toSeq + val shuffleIds = 0 until sc.newShuffleId + val broadcastIds = 0L until numBroadcasts + + val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + runGC() + intercept[Exception] { + preGCTester.assertCleanup()(timeout(1000 millis)) + } + + // Test that GC triggers the cleanup of all variables after the dereferencing them + val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + broadcastBuffer.clear() + rddBuffer.clear() + runGC() + postGCTester.assertCleanup() + } + + //------ Helper functions ------ + + def newRDD = sc.makeRDD(1 to 10) + def newPairRDD = newRDD.map(_ -> 1) + def newShuffleRDD = newPairRDD.reduceByKey(_ + _) + def newBroadcast = sc.broadcast(1 to 100) + def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _]]) = { + def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = { + rdd.dependencies ++ rdd.dependencies.flatMap { dep => + getAllDependencies(dep.rdd) + } + } + val rdd = newShuffleRDD + + // Get all the shuffle dependencies + val shuffleDeps = getAllDependencies(rdd) + .filter(_.isInstanceOf[ShuffleDependency[_, _]]) + .map(_.asInstanceOf[ShuffleDependency[_, _]]) + (rdd, shuffleDeps) + } + + def randomRdd = { + val rdd: RDD[_] = Random.nextInt(3) match { + case 0 => newRDD + case 1 => newShuffleRDD + case 2 => newPairRDD.join(newPairRDD) + } + if (Random.nextBoolean()) rdd.persist() + rdd.count() + rdd + } + + def randomBroadcast = { + sc.broadcast(Random.nextInt(Int.MaxValue)) + } + + /** Run GC and make sure it actually has run */ + def runGC() { + val weakRef = new WeakReference(new Object()) + val startTime = System.currentTimeMillis + System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. + // Wait until a weak reference object has been GCed + while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { + System.gc() + Thread.sleep(200) + } + } + + def cleaner = sc.cleaner.get +} + + +/** Class to test whether RDDs, shuffles, etc. have been successfully cleaned. */ +class CleanerTester( + sc: SparkContext, + rddIds: Seq[Int] = Seq.empty, + shuffleIds: Seq[Int] = Seq.empty, + broadcastIds: Seq[Long] = Seq.empty) + extends Logging { + + val toBeCleanedRDDIds = new HashSet[Int] with SynchronizedSet[Int] ++= rddIds + val toBeCleanedShuffleIds = new HashSet[Int] with SynchronizedSet[Int] ++= shuffleIds + val toBeCleanedBroadcstIds = new HashSet[Long] with SynchronizedSet[Long] ++= broadcastIds + val isDistributed = !sc.isLocal + + val cleanerListener = new CleanerListener { + def rddCleaned(rddId: Int): Unit = { + toBeCleanedRDDIds -= rddId + logInfo("RDD "+ rddId + " cleaned") + } + + def shuffleCleaned(shuffleId: Int): Unit = { + toBeCleanedShuffleIds -= shuffleId + logInfo("Shuffle " + shuffleId + " cleaned") + } + + def broadcastCleaned(broadcastId: Long): Unit = { + toBeCleanedBroadcstIds -= broadcastId + logInfo("Broadcast" + broadcastId + " cleaned") + } + } + + val MAX_VALIDATION_ATTEMPTS = 10 + val VALIDATION_ATTEMPT_INTERVAL = 100 + + logInfo("Attempting to validate before cleanup:\n" + uncleanedResourcesToString) + preCleanupValidate() + sc.cleaner.get.attachListener(cleanerListener) + + /** Assert that all the stuff has been cleaned up */ + def assertCleanup()(implicit waitTimeout: Eventually.Timeout) { + try { + eventually(waitTimeout, interval(100 millis)) { + assert(isAllCleanedUp) + } + postCleanupValidate() + } finally { + logInfo("Resources left from cleaning up:\n" + uncleanedResourcesToString) + } + } + + /** Verify that RDDs, shuffles, etc. occupy resources */ + private def preCleanupValidate() { + assert(rddIds.nonEmpty || shuffleIds.nonEmpty || broadcastIds.nonEmpty, "Nothing to cleanup") + + // Verify the RDDs have been persisted and blocks are present + rddIds.foreach { rddId => + assert( + sc.persistentRdds.contains(rddId), + "RDD " + rddId + " have not been persisted, cannot start cleaner test" + ) + + assert( + !getRDDBlocks(rddId).isEmpty, + "Blocks of RDD " + rddId + " cannot be found in block manager, " + + "cannot start cleaner test" + ) + } + + // Verify the shuffle ids are registered and blocks are present + shuffleIds.foreach { shuffleId => + assert( + mapOutputTrackerMaster.containsShuffle(shuffleId), + "Shuffle " + shuffleId + " have not been registered, cannot start cleaner test" + ) + + assert( + !getShuffleBlocks(shuffleId).isEmpty, + "Blocks of shuffle " + shuffleId + " cannot be found in block manager, " + + "cannot start cleaner test" + ) + } + + // Verify that the broadcast blocks are present + broadcastIds.foreach { broadcastId => + assert( + !getBroadcastBlocks(broadcastId).isEmpty, + "Blocks of broadcast " + broadcastId + "cannot be found in block manager, " + + "cannot start cleaner test" + ) + } + } + + /** + * Verify that RDDs, shuffles, etc. do not occupy resources. Tests multiple times as there is + * as there is not guarantee on how long it will take clean up the resources. + */ + private def postCleanupValidate() { + // Verify the RDDs have been persisted and blocks are present + rddIds.foreach { rddId => + assert( + !sc.persistentRdds.contains(rddId), + "RDD " + rddId + " was not cleared from sc.persistentRdds" + ) + + assert( + getRDDBlocks(rddId).isEmpty, + "Blocks of RDD " + rddId + " were not cleared from block manager" + ) + } + + // Verify the shuffle ids are registered and blocks are present + shuffleIds.foreach { shuffleId => + assert( + !mapOutputTrackerMaster.containsShuffle(shuffleId), + "Shuffle " + shuffleId + " was not deregistered from map output tracker" + ) + + assert( + getShuffleBlocks(shuffleId).isEmpty, + "Blocks of shuffle " + shuffleId + " were not cleared from block manager" + ) + } + + // Verify that the broadcast blocks are present + broadcastIds.foreach { broadcastId => + assert( + getBroadcastBlocks(broadcastId).isEmpty, + "Blocks of broadcast " + broadcastId + " were not cleared from block manager" + ) + } + } + + private def uncleanedResourcesToString = { + s""" + |\tRDDs = ${toBeCleanedRDDIds.toSeq.sorted.mkString("[", ", ", "]")} + |\tShuffles = ${toBeCleanedShuffleIds.toSeq.sorted.mkString("[", ", ", "]")} + |\tBroadcasts = ${toBeCleanedBroadcstIds.toSeq.sorted.mkString("[", ", ", "]")} + """.stripMargin + } + + private def isAllCleanedUp = + toBeCleanedRDDIds.isEmpty && + toBeCleanedShuffleIds.isEmpty && + toBeCleanedBroadcstIds.isEmpty + + private def getRDDBlocks(rddId: Int): Seq[BlockId] = { + blockManager.master.getMatchingBlockIds( _ match { + case RDDBlockId(`rddId`, _) => true + case _ => false + }, askSlaves = true) + } + + private def getShuffleBlocks(shuffleId: Int): Seq[BlockId] = { + blockManager.master.getMatchingBlockIds( _ match { + case ShuffleBlockId(`shuffleId`, _, _) => true + case _ => false + }, askSlaves = true) + } + + private def getBroadcastBlocks(broadcastId: Long): Seq[BlockId] = { + blockManager.master.getMatchingBlockIds( _ match { + case BroadcastBlockId(`broadcastId`, _) => true + case _ => false + }, askSlaves = true) + } + + private def blockManager = sc.env.blockManager + private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] +} diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index 9cbdfc54a3dc8..7f59bdcce4cc7 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -39,7 +39,7 @@ class DriverSuite extends FunSuite with Timeouts { failAfter(60 seconds) { Utils.executeAndGetOutput( Seq("./bin/spark-class", "org.apache.spark.DriverWithoutCleanup", master), - new File(sparkHome), + new File(sparkHome), Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) } } diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index f3fb64d87a2fd..12dbebcb28644 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -72,7 +72,7 @@ class FailureSuite extends FunSuite with LocalSparkContext { throw new Exception("Intentional task failure") } } - (k, v(0) * v(0)) + (k, v.head * v.head) }.collect() FailureSuiteState.synchronized { assert(FailureSuiteState.tasksRun === 4) @@ -137,5 +137,3 @@ class FailureSuite extends FunSuite with LocalSparkContext { // TODO: Need to add tests with shuffle fetch failures. } - - diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index aee9ab9091dac..d651fbbac4e97 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -45,7 +45,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { val pw = new PrintWriter(textFile) pw.println("100") pw.close() - + val jarFile = new File(tmpDir, "test.jar") val jarStream = new FileOutputStream(jarFile) val jar = new JarOutputStream(jarStream, new java.util.jar.Manifest()) @@ -53,7 +53,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext { val jarEntry = new JarEntry(textFile.getName) jar.putNextEntry(jarEntry) - + val in = new FileInputStream(textFile) val buffer = new Array[Byte](10240) var nRead = 0 diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 01af94077144a..b9b668d3cc62a 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -106,7 +106,7 @@ class FileSuite extends FunSuite with LocalSparkContext { sc = new SparkContext("local", "test") val tempDir = Files.createTempDir() val outputDir = new File(tempDir, "output").getAbsolutePath - val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), "a" * x)) + val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), "a" * x)) nums.saveAsSequenceFile(outputDir) // Try reading the output back as a SequenceFile val output = sc.sequenceFile[IntWritable, Text](outputDir) diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index a5bd72eb0a122..6b2571cd9295e 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -57,12 +57,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { tracker.stop() } - test("master register and fetch") { + test("master register shuffle and fetch") { val actorSystem = ActorSystem("test") val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) tracker.registerShuffle(10, 2) + assert(tracker.containsShuffle(10)) val compressedSize1000 = MapOutputTracker.compressSize(1000L) val compressedSize10000 = MapOutputTracker.compressSize(10000L) val size1000 = MapOutputTracker.decompressSize(compressedSize1000) @@ -77,7 +78,25 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { tracker.stop() } - test("master register and unregister and fetch") { + test("master register and unregister shuffle") { + val actorSystem = ActorSystem("test") + val tracker = new MapOutputTrackerMaster(conf) + tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.registerShuffle(10, 2) + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val compressedSize10000 = MapOutputTracker.compressSize(10000L) + tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0), + Array(compressedSize1000, compressedSize10000))) + tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0), + Array(compressedSize10000, compressedSize1000))) + assert(tracker.containsShuffle(10)) + assert(tracker.getServerStatuses(10, 0).nonEmpty) + tracker.unregisterShuffle(10) + assert(!tracker.containsShuffle(10)) + assert(tracker.getServerStatuses(10, 0).isEmpty) + } + + test("master register shuffle and unregister map output and fetch") { val actorSystem = ActorSystem("test") val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = @@ -114,7 +133,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf, securityManager = new SecurityManager(conf)) - val slaveTracker = new MapOutputTracker(conf) + val slaveTracker = new MapOutputTrackerWorker(conf) val selection = slaveSystem.actorSelection( s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") val timeout = AkkaUtils.lookupTimeout(conf) diff --git a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala index 6e7fd55fa4bb1..867b28cc0d971 100644 --- a/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/PipedRDDSuite.scala @@ -17,8 +17,11 @@ package org.apache.spark -import org.scalatest.FunSuite +import java.io.File + +import com.google.common.io.Files +import org.scalatest.FunSuite import org.apache.spark.rdd.{HadoopRDD, PipedRDD, HadoopPartition} import org.apache.hadoop.mapred.{JobConf, TextInputFormat, FileSplit} @@ -82,7 +85,7 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext { (f: String => Unit) => { bl.value.map(f(_)); f("\u0001") }, - (i: Tuple2[String, Seq[String]], f: String => Unit) => { + (i: Tuple2[String, Iterable[String]], f: String => Unit) => { for (e <- i._2) { f(e + "_") } @@ -126,6 +129,29 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext { } } + test("basic pipe with separate working directory") { + if (testCommandAvailable("cat")) { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val piped = nums.pipe(Seq("cat"), separateWorkingDir = true) + val c = piped.collect() + assert(c.size === 4) + assert(c(0) === "1") + assert(c(1) === "2") + assert(c(2) === "3") + assert(c(3) === "4") + val pipedPwd = nums.pipe(Seq("pwd"), separateWorkingDir = true) + val collectPwd = pipedPwd.collect() + assert(collectPwd(0).contains("tasks/")) + val pipedLs = nums.pipe(Seq("ls"), separateWorkingDir = true).collect() + // make sure symlinks were created + assert(pipedLs.length > 0) + // clean up top level tasks directory + new File("tasks").delete() + } else { + assert(true) + } + } + test("test pipe exports map_input_file") { testExportInputFile("map_input_file") } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index b543471a5d35b..94fba102865b3 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -51,6 +51,14 @@ class SparkContextSchedulerCreationSuite } } + test("local-*") { + val sched = createTaskScheduler("local[*]") + sched.backend match { + case s: LocalBackend => assert(s.totalCores === Runtime.getRuntime.availableProcessors()) + case _ => fail() + } + } + test("local-n") { val sched = createTaskScheduler("local[5]") assert(sched.maxTaskFailures === 1) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index 0b5ed6d77034b..5e538d6fab2a1 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -45,4 +45,4 @@ class WorkerWatcherSuite extends FunSuite { actorRef.underlyingActor.receive(new DisassociatedEvent(null, otherAkkaAddress, false)) assert(!actorRef.underlyingActor.isShutDown) } -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorURLClassLoaderSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorURLClassLoaderSuite.scala new file mode 100644 index 0000000000000..e2050e95a1b88 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorURLClassLoaderSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import java.net.URLClassLoader + +import org.scalatest.FunSuite + +import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, TestUtils} +import org.apache.spark.util.Utils + +class ExecutorURLClassLoaderSuite extends FunSuite { + + val childClassNames = List("FakeClass1", "FakeClass2") + val parentClassNames = List("FakeClass1", "FakeClass2", "FakeClass3") + val urls = List(TestUtils.createJarWithClasses(childClassNames, "1")).toArray + val urls2 = List(TestUtils.createJarWithClasses(parentClassNames, "2")).toArray + + test("child first") { + val parentLoader = new URLClassLoader(urls2, null) + val classLoader = new ChildExecutorURLClassLoader(urls, parentLoader) + val fakeClass = classLoader.loadClass("FakeClass2").newInstance() + val fakeClassVersion = fakeClass.toString + assert(fakeClassVersion === "1") + } + + test("parent first") { + val parentLoader = new URLClassLoader(urls2, null) + val classLoader = new ExecutorURLClassLoader(urls, parentLoader) + val fakeClass = classLoader.loadClass("FakeClass1").newInstance() + val fakeClassVersion = fakeClass.toString + assert(fakeClassVersion === "2") + } + + test("child first can fall back") { + val parentLoader = new URLClassLoader(urls2, null) + val classLoader = new ChildExecutorURLClassLoader(urls, parentLoader) + val fakeClass = classLoader.loadClass("FakeClass3").newInstance() + val fakeClassVersion = fakeClass.toString + assert(fakeClassVersion === "2") + } + + test("child first can fail") { + val parentLoader = new URLClassLoader(urls2, null) + val classLoader = new ChildExecutorURLClassLoader(urls, parentLoader) + intercept[java.lang.ClassNotFoundException] { + classLoader.loadClass("FakeClassDoesNotExist").newInstance() + } + } + + test("driver sets context class loader in local mode") { + // Test the case where the driver program sets a context classloader and then runs a job + // in local mode. This is what happens when ./spark-submit is called with "local" as the + // master. + val original = Thread.currentThread().getContextClassLoader + + val className = "ClassForDriverTest" + val jar = TestUtils.createJarWithClasses(Seq(className)) + val contextLoader = new URLClassLoader(Array(jar), Utils.getContextOrSparkClassLoader) + Thread.currentThread().setContextClassLoader(contextLoader) + + val sc = new SparkContext("local", "driverLoaderTest") + + try { + sc.makeRDD(1 to 5, 2).mapPartitions { x => + val loader = Thread.currentThread().getContextClassLoader + Class.forName(className, true, loader).newInstance() + Seq().iterator + }.count() + } + catch { + case e: SparkException if e.getMessage.contains("ClassNotFoundException") => + fail("Local executor could not find class", e) + case t: Throwable => fail("Unexpected exception ", t) + } + + sc.stop() + Thread.currentThread().setContextClassLoader(original) + } +} diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala new file mode 100644 index 0000000000000..33d6de9a76405 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.input + +import java.io.DataOutputStream +import java.io.File +import java.io.FileOutputStream + +import scala.collection.immutable.IndexedSeq + +import com.google.common.io.Files + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.apache.hadoop.io.Text + +import org.apache.spark.SparkContext + +/** + * Tests the correctness of + * [[org.apache.spark.input.WholeTextFileRecordReader WholeTextFileRecordReader]]. A temporary + * directory is created as fake input. Temporal storage would be deleted in the end. + */ +class WholeTextFileRecordReaderSuite extends FunSuite with BeforeAndAfterAll { + private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + + // Set the block size of local file system to test whether files are split right or not. + sc.hadoopConfiguration.setLong("fs.local.block.size", 32) + } + + override def afterAll() { + sc.stop() + } + + private def createNativeFile(inputDir: File, fileName: String, contents: Array[Byte]) = { + val out = new DataOutputStream(new FileOutputStream(s"${inputDir.toString}/$fileName")) + out.write(contents, 0, contents.length) + out.close() + } + + /** + * This code will test the behaviors of WholeTextFileRecordReader based on local disk. There are + * three aspects to check: + * 1) Whether all files are read; + * 2) Whether paths are read correctly; + * 3) Does the contents be the same. + */ + test("Correctness of WholeTextFileRecordReader.") { + + val dir = Files.createTempDir() + println(s"Local disk address is ${dir.toString}.") + + WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) => + createNativeFile(dir, filename, contents) + } + + val res = sc.wholeTextFiles(dir.toString, 3).collect() + + assert(res.size === WholeTextFileRecordReaderSuite.fileNames.size, + "Number of files read out does not fit with the actual value.") + + for ((filename, contents) <- res) { + val shortName = filename.split('/').last + assert(WholeTextFileRecordReaderSuite.fileNames.contains(shortName), + s"Missing file name $filename.") + assert(contents === new Text(WholeTextFileRecordReaderSuite.files(shortName)).toString, + s"file $filename contents can not match.") + } + + dir.delete() + } +} + +/** + * Files to be tested are defined here. + */ +object WholeTextFileRecordReaderSuite { + private val testWords: IndexedSeq[Byte] = "Spark is easy to use.\n".map(_.toByte) + + private val fileNames = Array("part-00000", "part-00001", "part-00002") + private val fileLengths = Array(10, 100, 1000) + + private val files = fileLengths.zip(fileNames).map { case (upperBound, filename) => + filename -> Stream.continually(testWords.toList.toStream).flatten.take(upperBound).toArray + }.toMap +} diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index f9e994b13dfbc..8f3e6bd21b752 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -225,11 +225,12 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext { val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w'))) val joined = rdd1.groupWith(rdd2).collect() assert(joined.size === 4) - assert(joined.toSet === Set( - (1, (ArrayBuffer(1, 2), ArrayBuffer('x'))), - (2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))), - (3, (ArrayBuffer(1), ArrayBuffer())), - (4, (ArrayBuffer(), ArrayBuffer('w'))) + val joinedSet = joined.map(x => (x._1, (x._2._1.toList, x._2._2.toList))).toSet + assert(joinedSet === Set( + (1, (List(1, 2), List('x'))), + (2, (List(1), List('y', 'z'))), + (3, (List(1), List())), + (4, (List(), List('w'))) )) } @@ -447,4 +448,3 @@ class ConfigTestFormat() extends FakeFormat() with Configurable { super.getRecordWriter(p1) } } - diff --git a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala index a4381a8b974df..4df36558b6d4b 100644 --- a/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/ParallelCollectionSplitSuite.scala @@ -34,14 +34,14 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices(1).mkString(",") === "2") assert(slices(2).mkString(",") === "3") } - + test("one slice") { val data = Array(1, 2, 3) val slices = ParallelCollectionRDD.slice(data, 1) assert(slices.size === 1) assert(slices(0).mkString(",") === "1,2,3") } - + test("equal slices") { val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9) val slices = ParallelCollectionRDD.slice(data, 3) @@ -50,7 +50,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices(1).mkString(",") === "4,5,6") assert(slices(2).mkString(",") === "7,8,9") } - + test("non-equal slices") { val data = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10) val slices = ParallelCollectionRDD.slice(data, 3) @@ -77,14 +77,14 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices(1).mkString(",") === (33 to 66).mkString(",")) assert(slices(2).mkString(",") === (67 to 100).mkString(",")) } - + test("empty data") { val data = new Array[Int](0) val slices = ParallelCollectionRDD.slice(data, 5) assert(slices.size === 5) for (slice <- slices) assert(slice.size === 0) } - + test("zero slices") { val data = Array(1, 2, 3) intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, 0) } @@ -94,7 +94,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { val data = Array(1, 2, 3) intercept[IllegalArgumentException] { ParallelCollectionRDD.slice(data, -5) } } - + test("exclusive ranges sliced into ranges") { val data = 1 until 100 val slices = ParallelCollectionRDD.slice(data, 3) @@ -102,7 +102,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices.map(_.size).reduceLeft(_+_) === 99) assert(slices.forall(_.isInstanceOf[Range])) } - + test("inclusive ranges sliced into ranges") { val data = 1 to 100 val slices = ParallelCollectionRDD.slice(data, 3) @@ -124,7 +124,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(range.step === 1, "slice " + i + " step") } } - + test("random array tests") { val gen = for { d <- arbitrary[List[Int]] @@ -141,7 +141,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { } check(prop) } - + test("random exclusive range tests") { val gen = for { a <- Gen.choose(-100, 100) @@ -177,7 +177,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { } check(prop) } - + test("exclusive ranges of longs") { val data = 1L until 100L val slices = ParallelCollectionRDD.slice(data, 3) @@ -185,7 +185,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices.map(_.size).reduceLeft(_+_) === 99) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } - + test("inclusive ranges of longs") { val data = 1L to 100L val slices = ParallelCollectionRDD.slice(data, 3) @@ -193,7 +193,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices.map(_.size).reduceLeft(_+_) === 100) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } - + test("exclusive ranges of doubles") { val data = 1.0 until 100.0 by 1.0 val slices = ParallelCollectionRDD.slice(data, 3) @@ -201,7 +201,7 @@ class ParallelCollectionSplitSuite extends FunSuite with Checkers { assert(slices.map(_.size).reduceLeft(_+_) === 99) assert(slices.forall(_.isInstanceOf[NumericRange[_]])) } - + test("inclusive ranges of doubles") { val data = 1.0 to 100.0 by 1.0 val slices = ParallelCollectionRDD.slice(data, 3) diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index d6b5fdc7984b4..1901330d8b188 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -33,6 +33,7 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("basic operations") { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) assert(nums.collect().toList === List(1, 2, 3, 4)) + assert(nums.toLocalIterator.toList === List(1, 2, 3, 4)) val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2) assert(dups.distinct().count() === 4) assert(dups.distinct.count === 4) // Can distinct and count be called without parentheses? @@ -273,37 +274,42 @@ class RDDSuite extends FunSuite with SharedSparkContext { test("coalesced RDDs with locality, large scale (10K partitions)") { // large scale experiment import collection.mutable - val rnd = scala.util.Random val partitions = 10000 val numMachines = 50 val machines = mutable.ListBuffer[String]() - (1 to numMachines).foreach(machines += "m"+_) - - val blocks = (1 to partitions).map(i => - { (i, Array.fill(3)(machines(rnd.nextInt(machines.size))).toList) } ) - - val data2 = sc.makeRDD(blocks) - val coalesced2 = data2.coalesce(numMachines*2) - - // test that you get over 90% locality in each group - val minLocality = coalesced2.partitions - .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) - .foldLeft(1.0)((perc, loc) => math.min(perc,loc)) - assert(minLocality >= 0.90, "Expected 90% locality but got " + (minLocality*100.0).toInt + "%") - - // test that the groups are load balanced with 100 +/- 20 elements in each - val maxImbalance = coalesced2.partitions - .map(part => part.asInstanceOf[CoalescedRDDPartition].parents.size) - .foldLeft(0)((dev, curr) => math.max(math.abs(100-curr),dev)) - assert(maxImbalance <= 20, "Expected 100 +/- 20 per partition, but got " + maxImbalance) - - val data3 = sc.makeRDD(blocks).map(i => i*2) // derived RDD to test *current* pref locs - val coalesced3 = data3.coalesce(numMachines*2) - val minLocality2 = coalesced3.partitions - .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) - .foldLeft(1.0)((perc, loc) => math.min(perc,loc)) - assert(minLocality2 >= 0.90, "Expected 90% locality for derived RDD but got " + - (minLocality2*100.0).toInt + "%") + (1 to numMachines).foreach(machines += "m" + _) + val rnd = scala.util.Random + for (seed <- 1 to 5) { + rnd.setSeed(seed) + + val blocks = (1 to partitions).map { i => + (i, Array.fill(3)(machines(rnd.nextInt(machines.size))).toList) + } + + val data2 = sc.makeRDD(blocks) + val coalesced2 = data2.coalesce(numMachines * 2) + + // test that you get over 90% locality in each group + val minLocality = coalesced2.partitions + .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) + .foldLeft(1.0)((perc, loc) => math.min(perc, loc)) + assert(minLocality >= 0.90, "Expected 90% locality but got " + + (minLocality * 100.0).toInt + "%") + + // test that the groups are load balanced with 100 +/- 20 elements in each + val maxImbalance = coalesced2.partitions + .map(part => part.asInstanceOf[CoalescedRDDPartition].parents.size) + .foldLeft(0)((dev, curr) => math.max(math.abs(100 - curr), dev)) + assert(maxImbalance <= 20, "Expected 100 +/- 20 per partition, but got " + maxImbalance) + + val data3 = sc.makeRDD(blocks).map(i => i * 2) // derived RDD to test *current* pref locs + val coalesced3 = data3.coalesce(numMachines * 2) + val minLocality2 = coalesced3.partitions + .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) + .foldLeft(1.0)((perc, loc) => math.min(perc, loc)) + assert(minLocality2 >= 0.90, "Expected 90% locality for derived RDD but got " + + (minLocality2 * 100.0).toInt + "%") + } } test("zipped RDDs") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index c97543f57d8f3..db4df1d1212ff 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import scala.Tuple2 -import scala.collection.mutable.{HashMap, Map} +import scala.collection.mutable.{HashSet, HashMap, Map} import org.scalatest.{BeforeAndAfter, FunSuite} @@ -43,6 +43,10 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont val conf = new SparkConf /** Set of TaskSets the DAGScheduler has requested executed. */ val taskSets = scala.collection.mutable.Buffer[TaskSet]() + + /** Stages for which the DAGScheduler has called TaskScheduler.cancelTasks(). */ + val cancelledStages = new HashSet[Int]() + val taskScheduler = new TaskScheduler() { override def rootPool: Pool = null override def schedulingMode: SchedulingMode = SchedulingMode.NONE @@ -53,11 +57,28 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch) taskSets += taskSet } - override def cancelTasks(stageId: Int) {} + override def cancelTasks(stageId: Int) { + cancelledStages += stageId + } override def setDAGScheduler(dagScheduler: DAGScheduler) = {} override def defaultParallelism() = 2 } + /** Length of time to wait while draining listener events. */ + val WAIT_TIMEOUT_MILLIS = 10000 + val sparkListener = new SparkListener() { + val successfulStages = new HashSet[Int]() + val failedStages = new HashSet[Int]() + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { + val stageInfo = stageCompleted.stageInfo + if (stageInfo.failureReason.isEmpty) { + successfulStages += stageInfo.stageId + } else { + failedStages += stageInfo.stageId + } + } + } + var mapOutputTracker: MapOutputTrackerMaster = null var scheduler: DAGScheduler = null @@ -83,14 +104,18 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont /** The list of results that DAGScheduler has collected. */ val results = new HashMap[Int, Any]() var failure: Exception = _ - val listener = new JobListener() { + val jobListener = new JobListener() { override def taskSucceeded(index: Int, result: Any) = results.put(index, result) override def jobFailed(exception: Exception) = { failure = exception } } before { sc = new SparkContext("local", "DAGSchedulerSuite") + sparkListener.successfulStages.clear() + sparkListener.failedStages.clear() + sc.addSparkListener(sparkListener) taskSets.clear() + cancelledStages.clear() cacheLocations.clear() results.clear() mapOutputTracker = new MapOutputTrackerMaster(conf) @@ -174,15 +199,16 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont } } - /** Sends the rdd to the scheduler for scheduling. */ + /** Sends the rdd to the scheduler for scheduling and returns the job id. */ private def submit( rdd: RDD[_], partitions: Array[Int], func: (TaskContext, Iterator[_]) => _ = jobComputeFunc, allowLocal: Boolean = false, - listener: JobListener = listener) { + listener: JobListener = jobListener): Int = { val jobId = scheduler.nextJobId.getAndIncrement() runEvent(JobSubmitted(jobId, rdd, func, partitions, allowLocal, null, listener)) + return jobId } /** Sends TaskSetFailed to the scheduler. */ @@ -190,6 +216,11 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont runEvent(TaskSetFailed(taskSet, message)) } + /** Sends JobCancelled to the DAG scheduler. */ + private def cancel(jobId: Int) { + runEvent(JobCancelled(jobId)) + } + test("zero split job") { val rdd = makeRdd(0, Nil) var numResults = 0 @@ -218,7 +249,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont override def toString = "DAGSchedulerSuite Local RDD" } val jobId = scheduler.nextJobId.getAndIncrement() - runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, null, listener)) + runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, null, jobListener)) assert(results === Map(0 -> 42)) assertDataStructuresEmpty } @@ -248,7 +279,21 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont test("trivial job failure") { submit(makeRdd(1, Nil), Array(0)) failed(taskSets(0), "some failure") - assert(failure.getMessage === "Job aborted: some failure") + assert(failure.getMessage === "Job aborted due to stage failure: some failure") + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + assert(sparkListener.failedStages.contains(0)) + assert(sparkListener.failedStages.size === 1) + assertDataStructuresEmpty + } + + test("trivial job cancellation") { + val rdd = makeRdd(1, Nil) + val jobId = submit(rdd, Array(0)) + cancel(jobId) + assert(failure.getMessage === s"Job $jobId cancelled ") + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + assert(sparkListener.failedStages.contains(0)) + assert(sparkListener.failedStages.size === 1) assertDataStructuresEmpty } @@ -323,6 +368,82 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont assertDataStructuresEmpty } + test("run shuffle with map stage failure") { + val shuffleMapRdd = makeRdd(2, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) + val reduceRdd = makeRdd(2, List(shuffleDep)) + submit(reduceRdd, Array(0, 1)) + + // Fail the map stage. This should cause the entire job to fail. + val stageFailureMessage = "Exception failure in map stage" + failed(taskSets(0), stageFailureMessage) + assert(failure.getMessage === s"Job aborted due to stage failure: $stageFailureMessage") + + // Listener bus should get told about the map stage failing, but not the reduce stage + // (since the reduce stage hasn't been started yet). + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + assert(sparkListener.failedStages.contains(1)) + assert(sparkListener.failedStages.size === 1) + + assertDataStructuresEmpty + } + + /** + * Makes sure that failures of stage used by multiple jobs are correctly handled. + * + * This test creates the following dependency graph: + * + * shuffleMapRdd1 shuffleMapRDD2 + * | \ | + * | \ | + * | \ | + * | \ | + * reduceRdd1 reduceRdd2 + * + * We start both shuffleMapRdds and then fail shuffleMapRdd1. As a result, the job listeners for + * reduceRdd1 and reduceRdd2 should both be informed that the job failed. shuffleMapRDD2 should + * also be cancelled, because it is only used by reduceRdd2 and reduceRdd2 cannot complete + * without shuffleMapRdd1. + */ + test("failure of stage used by two jobs") { + val shuffleMapRdd1 = makeRdd(2, Nil) + val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, null) + val shuffleMapRdd2 = makeRdd(2, Nil) + val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, null) + + val reduceRdd1 = makeRdd(2, List(shuffleDep1)) + val reduceRdd2 = makeRdd(2, List(shuffleDep1, shuffleDep2)) + + // We need to make our own listeners for this test, since by default submit uses the same + // listener for all jobs, and here we want to capture the failure for each job separately. + class FailureRecordingJobListener() extends JobListener { + var failureMessage: String = _ + override def taskSucceeded(index: Int, result: Any) {} + override def jobFailed(exception: Exception) = { failureMessage = exception.getMessage } + } + val listener1 = new FailureRecordingJobListener() + val listener2 = new FailureRecordingJobListener() + + submit(reduceRdd1, Array(0, 1), listener=listener1) + submit(reduceRdd2, Array(0, 1), listener=listener2) + + val stageFailureMessage = "Exception failure in map stage" + failed(taskSets(0), stageFailureMessage) + + assert(cancelledStages.contains(1)) + + // Make sure the listeners got told about both failed stages. + assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + assert(sparkListener.successfulStages.isEmpty) + assert(sparkListener.failedStages.contains(1)) + assert(sparkListener.failedStages.contains(3)) + assert(sparkListener.failedStages.size === 2) + + assert(listener1.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage") + assert(listener2.failureMessage === s"Job aborted due to stage failure: $stageFailureMessage") + assertDataStructuresEmpty + } + test("run trivial shuffle with out-of-band failure and retry") { val shuffleMapRdd = makeRdd(2, Nil) val shuffleDep = new ShuffleDependency(shuffleMapRdd, null) @@ -428,7 +549,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont assert(scheduler.pendingTasks.isEmpty) assert(scheduler.activeJobs.isEmpty) assert(scheduler.failedStages.isEmpty) - assert(scheduler.stageIdToActiveJob.isEmpty) + assert(scheduler.jobIdToActiveJob.isEmpty) assert(scheduler.jobIdToStageIds.isEmpty) assert(scheduler.stageIdToJobIds.isEmpty) assert(scheduler.stageIdToStage.isEmpty) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 7c843772bc2e0..4cdccdda6f72e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.scheduler +import java.util.concurrent.Semaphore + import scala.collection.mutable import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite} @@ -72,6 +74,49 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc } } + test("bus.stop() waits for the event queue to completely drain") { + @volatile var drained = false + + // Tells the listener to stop blocking + val listenerWait = new Semaphore(1) + + // When stop has returned + val stopReturned = new Semaphore(1) + + class BlockingListener extends SparkListener { + override def onJobEnd(jobEnd: SparkListenerJobEnd) = { + listenerWait.acquire() + drained = true + } + } + + val bus = new LiveListenerBus + val blockingListener = new BlockingListener + + bus.addListener(blockingListener) + bus.start() + bus.post(SparkListenerJobEnd(0, JobSucceeded)) + + // the queue should not drain immediately + assert(!drained) + + new Thread("ListenerBusStopper") { + override def run() { + // stop() will block until notify() is called below + bus.stop() + stopReturned.release(1) + } + }.start() + + while (!bus.stopCalled) { + Thread.sleep(10) + } + + listenerWait.release() + stopReturned.acquire() + assert(drained) + } + test("basic creation of StageInfo") { val listener = new SaveStageAndTaskInfo sc.addSparkListener(listener) @@ -171,7 +216,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc test("onTaskGettingResult() called when result fetched remotely") { val listener = new SaveTaskEvents sc.addSparkListener(listener) - + // Make a task whose result is larger than the akka frame size System.setProperty("spark.akka.frameSize", "1") val akkaFrameSize = @@ -191,7 +236,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc test("onTaskGettingResult() not called when result sent directly") { val listener = new SaveTaskEvents sc.addSparkListener(listener) - + // Make a task whose result is larger than the akka frame size val result = sc.parallelize(Seq(1), 1).map(2 * _).reduce { case (x, y) => x } assert(result === 2) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 356e28dd19bc5..2fb750d9ee378 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -264,7 +264,7 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Loggin test("Scheduler does not always schedule tasks on the same workers") { sc = new SparkContext("local", "TaskSchedulerImplSuite") - val taskScheduler = new TaskSchedulerImpl(sc) + val taskScheduler = new TaskSchedulerImpl(sc) taskScheduler.initialize(new FakeSchedulerBackend) // Need to initialize a DAGScheduler for the taskScheduler to use for callbacks. val dagScheduler = new DAGScheduler(sc, taskScheduler) { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index e83cd55e73691..e10ec7d2624a0 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -28,7 +28,7 @@ import org.scalatest.concurrent.Timeouts._ import org.scalatest.matchers.ShouldMatchers._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} @@ -42,6 +42,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT var oldArch: String = null conf.set("spark.authenticate", "false") val securityMgr = new SecurityManager(conf) + val mapOutputTracker = new MapOutputTrackerMaster(conf) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test conf.set("spark.kryoserializer.buffer.mb", "1") @@ -96,9 +97,9 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("StorageLevel object caching") { - val level1 = StorageLevel(false, false, false, 3) - val level2 = StorageLevel(false, false, false, 3) // this should return the same object as level1 - val level3 = StorageLevel(false, false, false, 2) // this should return a different object + val level1 = StorageLevel(false, false, false, false, 3) + val level2 = StorageLevel(false, false, false, false, 3) // this should return the same object as level1 + val level3 = StorageLevel(false, false, false, false, 2) // this should return a different object assert(level2 === level1, "level2 is not same as level1") assert(level2.eq(level1), "level2 is not the same object as level1") assert(level3 != level1, "level3 is same as level1") @@ -130,7 +131,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("master + 1 manager interaction") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -160,9 +162,10 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("master + 2 managers interaction") { - store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer(conf), 2000, conf, - securityMgr) + securityMgr, mapOutputTracker) val peers = master.getPeers(store.blockManagerId, 1) assert(peers.size === 1, "master did not return the other manager as a peer") @@ -177,7 +180,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("removing block") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -225,7 +229,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("removing rdd") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -257,9 +262,82 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT master.getLocations(rdd(0, 1)) should have size 0 } + test("removing broadcast") { + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) + val driverStore = store + val executorStore = new BlockManager("executor", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + val a4 = new Array[Byte](400) + + val broadcast0BlockId = BroadcastBlockId(0) + val broadcast1BlockId = BroadcastBlockId(1) + val broadcast2BlockId = BroadcastBlockId(2) + val broadcast2BlockId2 = BroadcastBlockId(2, "_") + + // insert broadcast blocks in both the stores + Seq(driverStore, executorStore).foreach { case s => + s.putSingle(broadcast0BlockId, a1, StorageLevel.DISK_ONLY) + s.putSingle(broadcast1BlockId, a2, StorageLevel.DISK_ONLY) + s.putSingle(broadcast2BlockId, a3, StorageLevel.DISK_ONLY) + s.putSingle(broadcast2BlockId2, a4, StorageLevel.DISK_ONLY) + } + + // verify whether the blocks exist in both the stores + Seq(driverStore, executorStore).foreach { case s => + s.getLocal(broadcast0BlockId) should not be (None) + s.getLocal(broadcast1BlockId) should not be (None) + s.getLocal(broadcast2BlockId) should not be (None) + s.getLocal(broadcast2BlockId2) should not be (None) + } + + // remove broadcast 0 block only from executors + master.removeBroadcast(0, removeFromMaster = false, blocking = true) + + // only broadcast 0 block should be removed from the executor store + executorStore.getLocal(broadcast0BlockId) should be (None) + executorStore.getLocal(broadcast1BlockId) should not be (None) + executorStore.getLocal(broadcast2BlockId) should not be (None) + + // nothing should be removed from the driver store + driverStore.getLocal(broadcast0BlockId) should not be (None) + driverStore.getLocal(broadcast1BlockId) should not be (None) + driverStore.getLocal(broadcast2BlockId) should not be (None) + + // remove broadcast 0 block from the driver as well + master.removeBroadcast(0, removeFromMaster = true, blocking = true) + driverStore.getLocal(broadcast0BlockId) should be (None) + driverStore.getLocal(broadcast1BlockId) should not be (None) + + // remove broadcast 1 block from both the stores asynchronously + // and verify all broadcast 1 blocks have been removed + master.removeBroadcast(1, removeFromMaster = true, blocking = false) + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + driverStore.getLocal(broadcast1BlockId) should be (None) + executorStore.getLocal(broadcast1BlockId) should be (None) + } + + // remove broadcast 2 from both the stores asynchronously + // and verify all broadcast 2 blocks have been removed + master.removeBroadcast(2, removeFromMaster = true, blocking = false) + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + driverStore.getLocal(broadcast2BlockId) should be (None) + driverStore.getLocal(broadcast2BlockId2) should be (None) + executorStore.getLocal(broadcast2BlockId) should be (None) + executorStore.getLocal(broadcast2BlockId2) should be (None) + } + executorStore.stop() + driverStore.stop() + store = null + } + test("reregistration on heart beat") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) @@ -275,7 +353,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("reregistration on block update") { - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) @@ -294,7 +373,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("reregistration doesn't dead lock") { val heartBeat = PrivateMethod[Unit]('heartBeat) - store = new BlockManager("", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = List(new Array[Byte](400)) @@ -331,7 +411,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -350,7 +431,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU storage with serialization") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -369,7 +451,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU for partitions of same RDD") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -388,7 +471,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU for partitions of multiple RDDs") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) store.putSingle(rdd(0, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(0, 2), new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(1, 1), new Array[Byte](400), StorageLevel.MEMORY_ONLY) @@ -410,8 +494,29 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(store.memoryStore.contains(rdd(0, 3)), "rdd_0_3 was not in store") } + test("tachyon storage") { + // TODO Make the spark.test.tachyon.enable true after using tachyon 0.5.0 testing jar. + val tachyonUnitTestEnabled = conf.getBoolean("spark.test.tachyon.enable", false) + if (tachyonUnitTestEnabled) { + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) + val a1 = new Array[Byte](400) + val a2 = new Array[Byte](400) + val a3 = new Array[Byte](400) + store.putSingle("a1", a1, StorageLevel.OFF_HEAP) + store.putSingle("a2", a2, StorageLevel.OFF_HEAP) + store.putSingle("a3", a3, StorageLevel.OFF_HEAP) + assert(store.getSingle("a3").isDefined, "a3 was in store") + assert(store.getSingle("a2").isDefined, "a2 was in store") + assert(store.getSingle("a1").isDefined, "a1 was in store") + } else { + info("tachyon storage test disabled.") + } + } + test("on-disk storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -424,7 +529,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -439,7 +545,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with getLocalBytes") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -454,7 +561,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with serialization") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -469,7 +577,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("disk and memory storage with serialization and getLocalBytes") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -484,7 +593,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("LRU with mixed storage levels") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val a1 = new Array[Byte](400) val a2 = new Array[Byte](400) val a3 = new Array[Byte](400) @@ -506,7 +616,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("in-memory LRU with streams") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -530,7 +641,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("LRU with mixed storage levels and streams") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val list1 = List(new Array[Byte](200), new Array[Byte](200)) val list2 = List(new Array[Byte](200), new Array[Byte](200)) val list3 = List(new Array[Byte](200), new Array[Byte](200)) @@ -576,7 +688,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("overly large block") { - store = new BlockManager("", actorSystem, master, serializer, 500, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 500, conf, + securityMgr, mapOutputTracker) store.putSingle("a1", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.getSingle("a1") === None, "a1 was in store") store.putSingle("a2", new Array[Byte](1000), StorageLevel.MEMORY_AND_DISK) @@ -587,7 +700,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("block compression") { try { conf.set("spark.shuffle.compress", "true") - store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) <= 100, "shuffle_0_0_0 was not compressed") @@ -595,7 +709,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store = null conf.set("spark.shuffle.compress", "false") - store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec2", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store.putSingle(ShuffleBlockId(0, 0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(ShuffleBlockId(0, 0, 0)) >= 1000, "shuffle_0_0_0 was compressed") @@ -603,7 +718,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store = null conf.set("spark.broadcast.compress", "true") - store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec3", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(BroadcastBlockId(0)) <= 100, "broadcast_0 was not compressed") @@ -611,28 +727,32 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT store = null conf.set("spark.broadcast.compress", "false") - store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec4", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store.putSingle(BroadcastBlockId(0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(BroadcastBlockId(0)) >= 1000, "broadcast_0 was compressed") store.stop() store = null conf.set("spark.rdd.compress", "true") - store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec5", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(rdd(0, 0)) <= 100, "rdd_0_0 was not compressed") store.stop() store = null conf.set("spark.rdd.compress", "false") - store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec6", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store.putSingle(rdd(0, 0), new Array[Byte](1000), StorageLevel.MEMORY_ONLY_SER) assert(store.memoryStore.getSize(rdd(0, 0)) >= 1000, "rdd_0_0 was compressed") store.stop() store = null // Check that any other block types are also kept uncompressed - store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf, securityMgr) + store = new BlockManager("exec7", actorSystem, master, serializer, 2000, conf, + securityMgr, mapOutputTracker) store.putSingle("other_block", new Array[Byte](1000), StorageLevel.MEMORY_ONLY) assert(store.memoryStore.getSize("other_block") >= 1000, "other_block was compressed") store.stop() @@ -647,7 +767,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("block store put failure") { // Use Java serializer so we can create an unserializable error. store = new BlockManager("", actorSystem, master, new JavaSerializer(conf), 1200, conf, - securityMgr) + securityMgr, mapOutputTracker) // The put should fail since a1 is not serializable. class UnserializableClass @@ -663,7 +783,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT } test("updated block statuses") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) val list = List.fill(2)(new Array[Byte](200)) val bigList = List.fill(8)(new Array[Byte](200)) @@ -716,8 +837,83 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(!store.get("list5").isDefined, "list5 was in store") } + test("query block statuses") { + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) + val list = List.fill(2)(new Array[Byte](200)) + + // Tell master. By LRU, only list2 and list3 remains. + store.put("list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.put("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.put("list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + + // getLocations and getBlockStatus should yield the same locations + assert(store.master.getLocations("list1").size === 0) + assert(store.master.getLocations("list2").size === 1) + assert(store.master.getLocations("list3").size === 1) + assert(store.master.getBlockStatus("list1", askSlaves = false).size === 0) + assert(store.master.getBlockStatus("list2", askSlaves = false).size === 1) + assert(store.master.getBlockStatus("list3", askSlaves = false).size === 1) + assert(store.master.getBlockStatus("list1", askSlaves = true).size === 0) + assert(store.master.getBlockStatus("list2", askSlaves = true).size === 1) + assert(store.master.getBlockStatus("list3", askSlaves = true).size === 1) + + // This time don't tell master and see what happens. By LRU, only list5 and list6 remains. + store.put("list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) + store.put("list5", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + store.put("list6", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) + + // getLocations should return nothing because the master is not informed + // getBlockStatus without asking slaves should have the same result + // getBlockStatus with asking slaves, however, should return the actual block statuses + assert(store.master.getLocations("list4").size === 0) + assert(store.master.getLocations("list5").size === 0) + assert(store.master.getLocations("list6").size === 0) + assert(store.master.getBlockStatus("list4", askSlaves = false).size === 0) + assert(store.master.getBlockStatus("list5", askSlaves = false).size === 0) + assert(store.master.getBlockStatus("list6", askSlaves = false).size === 0) + assert(store.master.getBlockStatus("list4", askSlaves = true).size === 0) + assert(store.master.getBlockStatus("list5", askSlaves = true).size === 1) + assert(store.master.getBlockStatus("list6", askSlaves = true).size === 1) + } + + test("get matching blocks") { + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) + val list = List.fill(2)(new Array[Byte](10)) + + // insert some blocks + store.put("list1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.put("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.put("list3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + + // getLocations and getBlockStatus should yield the same locations + assert(store.master.getMatchingBlockIds(_.toString.contains("list"), askSlaves = false).size === 3) + assert(store.master.getMatchingBlockIds(_.toString.contains("list1"), askSlaves = false).size === 1) + + // insert some more blocks + store.put("newlist1", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.put("newlist2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + store.put("newlist3", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + + // getLocations and getBlockStatus should yield the same locations + assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = false).size === 1) + assert(store.master.getMatchingBlockIds(_.toString.contains("newlist"), askSlaves = true).size === 3) + + val blockIds = Seq(RDDBlockId(1, 0), RDDBlockId(1, 1), RDDBlockId(2, 0)) + blockIds.foreach { blockId => + store.put(blockId, list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + } + val matchedBlockIds = store.master.getMatchingBlockIds(_ match { + case RDDBlockId(1, _) => true + case _ => false + }, askSlaves = true) + assert(matchedBlockIds.toSet === Set(RDDBlockId(1, 0), RDDBlockId(1, 1))) + } + test("SPARK-1194 regression: fix the same-RDD rule for cache replacement") { - store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr) + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) store.putSingle(rdd(0, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY) store.putSingle(rdd(1, 0), new Array[Byte](400), StorageLevel.MEMORY_ONLY) // Access rdd_1_0 to ensure it's not least recently used. diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 62f9b3cc7b2c1..808ddfdcf45d8 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -59,8 +59,16 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach { val newFile = diskBlockManager.getFile(blockId) writeToFile(newFile, 10) assertSegmentEquals(blockId, blockId.name, 0, 10) - + assert(diskBlockManager.containsBlock(blockId)) newFile.delete() + assert(!diskBlockManager.containsBlock(blockId)) + } + + test("enumerating blocks") { + val ids = (1 to 100).map(i => TestBlockId("test_" + i)) + val files = ids.map(id => diskBlockManager.getFile(id)) + files.foreach(file => writeToFile(file, 10)) + assert(diskBlockManager.getAllBlocks.toSet === ids.toSet) } test("block appending") { diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 45c322427930d..b85c483ca2a08 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -18,23 +18,88 @@ package org.apache.spark.ui import java.net.ServerSocket +import javax.servlet.http.HttpServletRequest +import scala.io.Source import scala.util.{Failure, Success, Try} import org.eclipse.jetty.server.Server import org.eclipse.jetty.servlet.ServletContextHandler import org.scalatest.FunSuite +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ -import org.apache.spark.SparkConf +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.LocalSparkContext._ +import scala.xml.Node class UISuite extends FunSuite { + + test("basic ui visibility") { + withSpark(new SparkContext("local", "test")) { sc => + // test if the ui is visible, and all the expected tabs are visible + eventually(timeout(10 seconds), interval(50 milliseconds)) { + val html = Source.fromURL(sc.ui.appUIAddress).mkString + assert(!html.contains("random data that should not be present")) + assert(html.toLowerCase.contains("stages")) + assert(html.toLowerCase.contains("storage")) + assert(html.toLowerCase.contains("environment")) + assert(html.toLowerCase.contains("executors")) + } + } + } + + test("visibility at localhost:4040") { + withSpark(new SparkContext("local", "test")) { sc => + // test if visible from http://localhost:4040 + eventually(timeout(10 seconds), interval(50 milliseconds)) { + val html = Source.fromURL("http://localhost:4040").mkString + assert(html.toLowerCase.contains("stages")) + } + } + } + + test("attaching a new tab") { + withSpark(new SparkContext("local", "test")) { sc => + val sparkUI = sc.ui + + val newTab = new WebUITab(sparkUI, "foo") { + attachPage(new WebUIPage("") { + def render(request: HttpServletRequest): Seq[Node] = { + "html magic" + } + }) + } + sparkUI.attachTab(newTab) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + val html = Source.fromURL(sc.ui.appUIAddress).mkString + assert(!html.contains("random data that should not be present")) + + // check whether new page exists + assert(html.toLowerCase.contains("foo")) + + // check whether other pages still exist + assert(html.toLowerCase.contains("stages")) + assert(html.toLowerCase.contains("storage")) + assert(html.toLowerCase.contains("environment")) + assert(html.toLowerCase.contains("executors")) + } + + eventually(timeout(10 seconds), interval(50 milliseconds)) { + val html = Source.fromURL(sc.ui.appUIAddress.stripSuffix("/") + "/foo").mkString + // check whether new page exists + assert(html.contains("magic")) + } + } + } + test("jetty port increases under contention") { val startPort = 4040 val server = new Server(startPort) Try { server.start() } match { - case Success(s) => - case Failure(e) => + case Success(s) => + case Failure(e) => // Either case server port is busy hence setup for test complete } val serverInfo1 = JettyUtils.startJettyServer( @@ -60,4 +125,18 @@ class UISuite extends FunSuite { case Failure(e) => } } + + test("verify appUIAddress contains the scheme") { + withSpark(new SparkContext("local", "test")) { sc => + val uiAddress = sc.ui.appUIAddress + assert(uiAddress.equals("http://" + sc.ui.appUIHostPort)) + } + } + + test("verify appUIAddress contains the port") { + withSpark(new SparkContext("local", "test")) { sc => + val splitUIAddress = sc.ui.appUIAddress.split(':') + assert(splitUIAddress(2).toInt == sc.ui.boundPort) + } + } } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index d8a3e859f85cd..8c06a2d9aa4ab 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -18,16 +18,45 @@ package org.apache.spark.ui.jobs import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers -import org.apache.spark.{LocalSparkContext, SparkContext, Success} +import org.apache.spark.{LocalSparkContext, SparkConf, Success} import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics} import org.apache.spark.scheduler._ import org.apache.spark.util.Utils -class JobProgressListenerSuite extends FunSuite with LocalSparkContext { +class JobProgressListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers { + test("test LRU eviction of stages") { + val conf = new SparkConf() + conf.set("spark.ui.retainedStages", 5.toString) + val listener = new JobProgressListener(conf) + + def createStageStartEvent(stageId: Int) = { + val stageInfo = new StageInfo(stageId, stageId.toString, 0, null) + SparkListenerStageSubmitted(stageInfo) + } + + def createStageEndEvent(stageId: Int) = { + val stageInfo = new StageInfo(stageId, stageId.toString, 0, null) + SparkListenerStageCompleted(stageInfo) + } + + for (i <- 1 to 50) { + listener.onStageSubmitted(createStageStartEvent(i)) + listener.onStageCompleted(createStageEndEvent(i)) + } + + listener.completedStages.size should be (5) + listener.completedStages.filter(_.stageId == 50).size should be (1) + listener.completedStages.filter(_.stageId == 49).size should be (1) + listener.completedStages.filter(_.stageId == 48).size should be (1) + listener.completedStages.filter(_.stageId == 47).size should be (1) + listener.completedStages.filter(_.stageId == 46).size should be (1) + } + test("test executor id to summary") { - val sc = new SparkContext("local", "test") - val listener = new JobProgressListener(sc.conf) + val conf = new SparkConf() + val listener = new JobProgressListener(conf) val taskMetrics = new TaskMetrics() val shuffleReadMetrics = new ShuffleReadMetrics() diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 439e5644e20a3..d7e48e633e0ee 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -69,7 +69,7 @@ object TestObject { class TestClass extends Serializable { var x = 5 - + def getX = x def run(): Int = { diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 67c0a434c9b52..16470bb7bf60d 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.util -import java.util.{Properties, UUID} +import java.util.Properties import scala.collection.Map @@ -52,6 +52,8 @@ class JsonProtocolSuite extends FunSuite { val blockManagerRemoved = SparkListenerBlockManagerRemoved( BlockManagerId("Scarce", "to be counted...", 100, 200)) val unpersistRdd = SparkListenerUnpersistRDD(12345) + val applicationStart = SparkListenerApplicationStart("The winner of all", 42L, "Garfield") + val applicationEnd = SparkListenerApplicationEnd(42L) testEvent(stageSubmitted, stageSubmittedJsonString) testEvent(stageCompleted, stageCompletedJsonString) @@ -64,6 +66,8 @@ class JsonProtocolSuite extends FunSuite { testEvent(blockManagerAdded, blockManagerAddedJsonString) testEvent(blockManagerRemoved, blockManagerRemovedJsonString) testEvent(unpersistRdd, unpersistRDDJsonString) + testEvent(applicationStart, applicationStartJsonString) + testEvent(applicationEnd, applicationEndJsonString) } test("Dependent Classes") { @@ -89,7 +93,7 @@ class JsonProtocolSuite extends FunSuite { // JobResult val exception = new Exception("Out of Memory! Please restock film.") exception.setStackTrace(stackTrace) - val jobFailed = JobFailed(exception, 2) + val jobFailed = JobFailed(exception) testJobResult(JobSucceeded) testJobResult(jobFailed) @@ -108,11 +112,9 @@ class JsonProtocolSuite extends FunSuite { // BlockId testBlockId(RDDBlockId(1, 2)) testBlockId(ShuffleBlockId(1, 2, 3)) - testBlockId(BroadcastBlockId(1L)) - testBlockId(BroadcastHelperBlockId(BroadcastBlockId(2L), "Spark")) + testBlockId(BroadcastBlockId(1L, "insert_words_of_wisdom_here")) testBlockId(TaskResultBlockId(1L)) testBlockId(StreamBlockId(1, 2L)) - testBlockId(TempBlockId(UUID.randomUUID())) } @@ -168,8 +170,8 @@ class JsonProtocolSuite extends FunSuite { } private def testBlockId(blockId: BlockId) { - val newBlockId = JsonProtocol.blockIdFromJson(JsonProtocol.blockIdToJson(blockId)) - blockId == newBlockId + val newBlockId = BlockId(blockId.toString) + assert(blockId === newBlockId) } @@ -180,90 +182,96 @@ class JsonProtocolSuite extends FunSuite { private def assertEquals(event1: SparkListenerEvent, event2: SparkListenerEvent) { (event1, event2) match { case (e1: SparkListenerStageSubmitted, e2: SparkListenerStageSubmitted) => - assert(e1.properties == e2.properties) + assert(e1.properties === e2.properties) assertEquals(e1.stageInfo, e2.stageInfo) case (e1: SparkListenerStageCompleted, e2: SparkListenerStageCompleted) => assertEquals(e1.stageInfo, e2.stageInfo) case (e1: SparkListenerTaskStart, e2: SparkListenerTaskStart) => - assert(e1.stageId == e2.stageId) + assert(e1.stageId === e2.stageId) assertEquals(e1.taskInfo, e2.taskInfo) case (e1: SparkListenerTaskGettingResult, e2: SparkListenerTaskGettingResult) => assertEquals(e1.taskInfo, e2.taskInfo) case (e1: SparkListenerTaskEnd, e2: SparkListenerTaskEnd) => - assert(e1.stageId == e2.stageId) - assert(e1.taskType == e2.taskType) + assert(e1.stageId === e2.stageId) + assert(e1.taskType === e2.taskType) assertEquals(e1.reason, e2.reason) assertEquals(e1.taskInfo, e2.taskInfo) assertEquals(e1.taskMetrics, e2.taskMetrics) case (e1: SparkListenerJobStart, e2: SparkListenerJobStart) => - assert(e1.jobId == e2.jobId) - assert(e1.properties == e2.properties) - assertSeqEquals(e1.stageIds, e2.stageIds, (i1: Int, i2: Int) => assert(i1 == i2)) + assert(e1.jobId === e2.jobId) + assert(e1.properties === e2.properties) + assertSeqEquals(e1.stageIds, e2.stageIds, (i1: Int, i2: Int) => assert(i1 === i2)) case (e1: SparkListenerJobEnd, e2: SparkListenerJobEnd) => - assert(e1.jobId == e2.jobId) + assert(e1.jobId === e2.jobId) assertEquals(e1.jobResult, e2.jobResult) case (e1: SparkListenerEnvironmentUpdate, e2: SparkListenerEnvironmentUpdate) => assertEquals(e1.environmentDetails, e2.environmentDetails) case (e1: SparkListenerBlockManagerAdded, e2: SparkListenerBlockManagerAdded) => - assert(e1.maxMem == e2.maxMem) + assert(e1.maxMem === e2.maxMem) assertEquals(e1.blockManagerId, e2.blockManagerId) case (e1: SparkListenerBlockManagerRemoved, e2: SparkListenerBlockManagerRemoved) => assertEquals(e1.blockManagerId, e2.blockManagerId) case (e1: SparkListenerUnpersistRDD, e2: SparkListenerUnpersistRDD) => assert(e1.rddId == e2.rddId) + case (e1: SparkListenerApplicationStart, e2: SparkListenerApplicationStart) => + assert(e1.appName == e2.appName) + assert(e1.time == e2.time) + assert(e1.sparkUser == e2.sparkUser) + case (e1: SparkListenerApplicationEnd, e2: SparkListenerApplicationEnd) => + assert(e1.time == e2.time) case (SparkListenerShutdown, SparkListenerShutdown) => case _ => fail("Events don't match in types!") } } private def assertEquals(info1: StageInfo, info2: StageInfo) { - assert(info1.stageId == info2.stageId) - assert(info1.name == info2.name) - assert(info1.numTasks == info2.numTasks) - assert(info1.submissionTime == info2.submissionTime) - assert(info1.completionTime == info2.completionTime) - assert(info1.emittedTaskSizeWarning == info2.emittedTaskSizeWarning) + assert(info1.stageId === info2.stageId) + assert(info1.name === info2.name) + assert(info1.numTasks === info2.numTasks) + assert(info1.submissionTime === info2.submissionTime) + assert(info1.completionTime === info2.completionTime) + assert(info1.emittedTaskSizeWarning === info2.emittedTaskSizeWarning) assertEquals(info1.rddInfo, info2.rddInfo) } private def assertEquals(info1: RDDInfo, info2: RDDInfo) { - assert(info1.id == info2.id) - assert(info1.name == info2.name) - assert(info1.numPartitions == info2.numPartitions) - assert(info1.numCachedPartitions == info2.numCachedPartitions) - assert(info1.memSize == info2.memSize) - assert(info1.diskSize == info2.diskSize) + assert(info1.id === info2.id) + assert(info1.name === info2.name) + assert(info1.numPartitions === info2.numPartitions) + assert(info1.numCachedPartitions === info2.numCachedPartitions) + assert(info1.memSize === info2.memSize) + assert(info1.diskSize === info2.diskSize) assertEquals(info1.storageLevel, info2.storageLevel) } private def assertEquals(level1: StorageLevel, level2: StorageLevel) { - assert(level1.useDisk == level2.useDisk) - assert(level1.useMemory == level2.useMemory) - assert(level1.deserialized == level2.deserialized) - assert(level1.replication == level2.replication) + assert(level1.useDisk === level2.useDisk) + assert(level1.useMemory === level2.useMemory) + assert(level1.deserialized === level2.deserialized) + assert(level1.replication === level2.replication) } private def assertEquals(info1: TaskInfo, info2: TaskInfo) { - assert(info1.taskId == info2.taskId) - assert(info1.index == info2.index) - assert(info1.launchTime == info2.launchTime) - assert(info1.executorId == info2.executorId) - assert(info1.host == info2.host) - assert(info1.taskLocality == info2.taskLocality) - assert(info1.gettingResultTime == info2.gettingResultTime) - assert(info1.finishTime == info2.finishTime) - assert(info1.failed == info2.failed) - assert(info1.serializedSize == info2.serializedSize) + assert(info1.taskId === info2.taskId) + assert(info1.index === info2.index) + assert(info1.launchTime === info2.launchTime) + assert(info1.executorId === info2.executorId) + assert(info1.host === info2.host) + assert(info1.taskLocality === info2.taskLocality) + assert(info1.gettingResultTime === info2.gettingResultTime) + assert(info1.finishTime === info2.finishTime) + assert(info1.failed === info2.failed) + assert(info1.serializedSize === info2.serializedSize) } private def assertEquals(metrics1: TaskMetrics, metrics2: TaskMetrics) { - assert(metrics1.hostname == metrics2.hostname) - assert(metrics1.executorDeserializeTime == metrics2.executorDeserializeTime) - assert(metrics1.resultSize == metrics2.resultSize) - assert(metrics1.jvmGCTime == metrics2.jvmGCTime) - assert(metrics1.resultSerializationTime == metrics2.resultSerializationTime) - assert(metrics1.memoryBytesSpilled == metrics2.memoryBytesSpilled) - assert(metrics1.diskBytesSpilled == metrics2.diskBytesSpilled) + assert(metrics1.hostname === metrics2.hostname) + assert(metrics1.executorDeserializeTime === metrics2.executorDeserializeTime) + assert(metrics1.resultSize === metrics2.resultSize) + assert(metrics1.jvmGCTime === metrics2.jvmGCTime) + assert(metrics1.resultSerializationTime === metrics2.resultSerializationTime) + assert(metrics1.memoryBytesSpilled === metrics2.memoryBytesSpilled) + assert(metrics1.diskBytesSpilled === metrics2.diskBytesSpilled) assertOptionEquals( metrics1.shuffleReadMetrics, metrics2.shuffleReadMetrics, assertShuffleReadEquals) assertOptionEquals( @@ -272,31 +280,30 @@ class JsonProtocolSuite extends FunSuite { } private def assertEquals(metrics1: ShuffleReadMetrics, metrics2: ShuffleReadMetrics) { - assert(metrics1.shuffleFinishTime == metrics2.shuffleFinishTime) - assert(metrics1.totalBlocksFetched == metrics2.totalBlocksFetched) - assert(metrics1.remoteBlocksFetched == metrics2.remoteBlocksFetched) - assert(metrics1.localBlocksFetched == metrics2.localBlocksFetched) - assert(metrics1.fetchWaitTime == metrics2.fetchWaitTime) - assert(metrics1.remoteBytesRead == metrics2.remoteBytesRead) + assert(metrics1.shuffleFinishTime === metrics2.shuffleFinishTime) + assert(metrics1.totalBlocksFetched === metrics2.totalBlocksFetched) + assert(metrics1.remoteBlocksFetched === metrics2.remoteBlocksFetched) + assert(metrics1.localBlocksFetched === metrics2.localBlocksFetched) + assert(metrics1.fetchWaitTime === metrics2.fetchWaitTime) + assert(metrics1.remoteBytesRead === metrics2.remoteBytesRead) } private def assertEquals(metrics1: ShuffleWriteMetrics, metrics2: ShuffleWriteMetrics) { - assert(metrics1.shuffleBytesWritten == metrics2.shuffleBytesWritten) - assert(metrics1.shuffleWriteTime == metrics2.shuffleWriteTime) + assert(metrics1.shuffleBytesWritten === metrics2.shuffleBytesWritten) + assert(metrics1.shuffleWriteTime === metrics2.shuffleWriteTime) } private def assertEquals(bm1: BlockManagerId, bm2: BlockManagerId) { - assert(bm1.executorId == bm2.executorId) - assert(bm1.host == bm2.host) - assert(bm1.port == bm2.port) - assert(bm1.nettyPort == bm2.nettyPort) + assert(bm1.executorId === bm2.executorId) + assert(bm1.host === bm2.host) + assert(bm1.port === bm2.port) + assert(bm1.nettyPort === bm2.nettyPort) } private def assertEquals(result1: JobResult, result2: JobResult) { (result1, result2) match { case (JobSucceeded, JobSucceeded) => case (r1: JobFailed, r2: JobFailed) => - assert(r1.failedStageId == r2.failedStageId) assertEquals(r1.exception, r2.exception) case _ => fail("Job results don't match in types!") } @@ -307,13 +314,13 @@ class JsonProtocolSuite extends FunSuite { case (Success, Success) => case (Resubmitted, Resubmitted) => case (r1: FetchFailed, r2: FetchFailed) => - assert(r1.shuffleId == r2.shuffleId) - assert(r1.mapId == r2.mapId) - assert(r1.reduceId == r2.reduceId) + assert(r1.shuffleId === r2.shuffleId) + assert(r1.mapId === r2.mapId) + assert(r1.reduceId === r2.reduceId) assertEquals(r1.bmAddress, r2.bmAddress) case (r1: ExceptionFailure, r2: ExceptionFailure) => - assert(r1.className == r2.className) - assert(r1.description == r2.description) + assert(r1.className === r2.className) + assert(r1.description === r2.description) assertSeqEquals(r1.stackTrace, r2.stackTrace, assertStackTraceElementEquals) assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals) case (TaskResultLost, TaskResultLost) => @@ -329,13 +336,13 @@ class JsonProtocolSuite extends FunSuite { details2: Map[String, Seq[(String, String)]]) { details1.zip(details2).foreach { case ((key1, values1: Seq[(String, String)]), (key2, values2: Seq[(String, String)])) => - assert(key1 == key2) - values1.zip(values2).foreach { case (v1, v2) => assert(v1 == v2) } + assert(key1 === key2) + values1.zip(values2).foreach { case (v1, v2) => assert(v1 === v2) } } } private def assertEquals(exception1: Exception, exception2: Exception) { - assert(exception1.getMessage == exception2.getMessage) + assert(exception1.getMessage === exception2.getMessage) assertSeqEquals( exception1.getStackTrace, exception2.getStackTrace, @@ -344,11 +351,11 @@ class JsonProtocolSuite extends FunSuite { private def assertJsonStringEquals(json1: String, json2: String) { val formatJsonString = (json: String) => json.replaceAll("[\\s|]", "") - formatJsonString(json1) == formatJsonString(json2) + formatJsonString(json1) === formatJsonString(json2) } private def assertSeqEquals[T](seq1: Seq[T], seq2: Seq[T], assertEquals: (T, T) => Unit) { - assert(seq1.length == seq2.length) + assert(seq1.length === seq2.length) seq1.zip(seq2).foreach { case (t1, t2) => assertEquals(t1, t2) } @@ -389,11 +396,11 @@ class JsonProtocolSuite extends FunSuite { } private def assertBlockEquals(b1: (BlockId, BlockStatus), b2: (BlockId, BlockStatus)) { - assert(b1 == b2) + assert(b1 === b2) } private def assertStackTraceElementEquals(ste1: StackTraceElement, ste2: StackTraceElement) { - assert(ste1 == ste2) + assert(ste1 === ste2) } @@ -457,7 +464,7 @@ class JsonProtocolSuite extends FunSuite { t.shuffleWriteMetrics = Some(sw) // Make at most 6 blocks t.updatedBlocks = Some((1 to (e % 5 + 1)).map { i => - (RDDBlockId(e % i, f % i), BlockStatus(StorageLevel.MEMORY_AND_DISK_SER_2, a % i, b % i)) + (RDDBlockId(e % i, f % i), BlockStatus(StorageLevel.MEMORY_AND_DISK_SER_2, a % i, b % i, c%i)) }.toSeq) t } @@ -471,19 +478,19 @@ class JsonProtocolSuite extends FunSuite { """ {"Event":"SparkListenerStageSubmitted","Stage Info":{"Stage ID":100,"Stage Name": "greetings","Number of Tasks":200,"RDD Info":{"RDD ID":100,"Name":"mayor","Storage - Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1}, - "Number of Partitions":200,"Number of Cached Partitions":300,"Memory Size":400, - "Disk Size":500},"Emitted Task Size Warning":false},"Properties":{"France":"Paris", - "Germany":"Berlin","Russia":"Moscow","Ukraine":"Kiev"}} + Level":{"Use Disk":true,"Use Memory":true,"Use Tachyon":false,"Deserialized":true, + "Replication":1},"Number of Partitions":200,"Number of Cached Partitions":300, + "Memory Size":400,"Disk Size":500,"Tachyon Size":0},"Emitted Task Size Warning":false}, + "Properties":{"France":"Paris","Germany":"Berlin","Russia":"Moscow","Ukraine":"Kiev"}} """ private val stageCompletedJsonString = """ {"Event":"SparkListenerStageCompleted","Stage Info":{"Stage ID":101,"Stage Name": "greetings","Number of Tasks":201,"RDD Info":{"RDD ID":101,"Name":"mayor","Storage - Level":{"Use Disk":true,"Use Memory":true,"Deserialized":true,"Replication":1}, - "Number of Partitions":201,"Number of Cached Partitions":301,"Memory Size":401, - "Disk Size":501},"Emitted Task Size Warning":false}} + Level":{"Use Disk":true,"Use Memory":true,"Use Tachyon":false,"Deserialized":true, + "Replication":1},"Number of Partitions":201,"Number of Cached Partitions":301, + "Memory Size":401,"Disk Size":501,"Tachyon Size":0},"Emitted Task Size Warning":false}} """ private val taskStartJsonString = @@ -516,8 +523,8 @@ class JsonProtocolSuite extends FunSuite { 700,"Fetch Wait Time":900,"Remote Bytes Read":1000},"Shuffle Write Metrics": {"Shuffle Bytes Written":1200,"Shuffle Write Time":1500},"Updated Blocks": [{"Block ID":{"Type":"RDDBlockId","RDD ID":0,"Split Index":0},"Status": - {"Storage Level":{"Use Disk":true,"Use Memory":true,"Deserialized":false, - "Replication":2},"Memory Size":0,"Disk Size":0}}]}} + {"Storage Level":{"Use Disk":true,"Use Memory":true,"Use Tachyon":false, + "Deserialized":false,"Replication":2},"Memory Size":0,"Disk Size":0,"Tachyon Size":0}}]}} """ private val jobStartJsonString = @@ -556,4 +563,14 @@ class JsonProtocolSuite extends FunSuite { {"Event":"SparkListenerUnpersistRDD","RDD ID":12345} """ - } + private val applicationStartJsonString = + """ + {"Event":"SparkListenerApplicationStart","App Name":"The winner of all","Timestamp":42, + "User":"Garfield"} + """ + + private val applicationEndJsonString = + """ + {"Event":"SparkListenerApplicationEnd","Timestamp":42} + """ +} diff --git a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala index e1446cbc90bdb..32d74d0500b72 100644 --- a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala @@ -32,7 +32,7 @@ class NextIteratorSuite extends FunSuite with ShouldMatchers { i.hasNext should be === false intercept[NoSuchElementException] { i.next() } } - + test("two iterations") { val i = new StubIterator(Buffer(1, 2)) i.hasNext should be === true @@ -70,7 +70,7 @@ class NextIteratorSuite extends FunSuite with ShouldMatchers { class StubIterator(ints: Buffer[Int]) extends NextIterator[Int] { var closeCalled = 0 - + override def getNext() = { if (ints.size == 0) { finished = true diff --git a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala new file mode 100644 index 0000000000000..6a5653ed2fb54 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.lang.ref.WeakReference + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +import org.scalatest.FunSuite + +class TimeStampedHashMapSuite extends FunSuite { + + // Test the testMap function - a Scala HashMap should obviously pass + testMap(new mutable.HashMap[String, String]()) + + // Test TimeStampedHashMap basic functionality + testMap(new TimeStampedHashMap[String, String]()) + testMapThreadSafety(new TimeStampedHashMap[String, String]()) + + // Test TimeStampedWeakValueHashMap basic functionality + testMap(new TimeStampedWeakValueHashMap[String, String]()) + testMapThreadSafety(new TimeStampedWeakValueHashMap[String, String]()) + + test("TimeStampedHashMap - clearing by timestamp") { + // clearing by insertion time + val map = new TimeStampedHashMap[String, String](updateTimeStampOnGet = false) + map("k1") = "v1" + assert(map("k1") === "v1") + Thread.sleep(10) + val threshTime = System.currentTimeMillis + assert(map.getTimestamp("k1").isDefined) + assert(map.getTimestamp("k1").get < threshTime) + map.clearOldValues(threshTime) + assert(map.get("k1") === None) + + // clearing by modification time + val map1 = new TimeStampedHashMap[String, String](updateTimeStampOnGet = true) + map1("k1") = "v1" + map1("k2") = "v2" + assert(map1("k1") === "v1") + Thread.sleep(10) + val threshTime1 = System.currentTimeMillis + Thread.sleep(10) + assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime + assert(map1.getTimestamp("k1").isDefined) + assert(map1.getTimestamp("k1").get < threshTime1) + assert(map1.getTimestamp("k2").isDefined) + assert(map1.getTimestamp("k2").get >= threshTime1) + map1.clearOldValues(threshTime1) //should only clear k1 + assert(map1.get("k1") === None) + assert(map1.get("k2").isDefined) + } + + test("TimeStampedWeakValueHashMap - clearing by timestamp") { + // clearing by insertion time + val map = new TimeStampedWeakValueHashMap[String, String](updateTimeStampOnGet = false) + map("k1") = "v1" + assert(map("k1") === "v1") + Thread.sleep(10) + val threshTime = System.currentTimeMillis + assert(map.getTimestamp("k1").isDefined) + assert(map.getTimestamp("k1").get < threshTime) + map.clearOldValues(threshTime) + assert(map.get("k1") === None) + + // clearing by modification time + val map1 = new TimeStampedWeakValueHashMap[String, String](updateTimeStampOnGet = true) + map1("k1") = "v1" + map1("k2") = "v2" + assert(map1("k1") === "v1") + Thread.sleep(10) + val threshTime1 = System.currentTimeMillis + Thread.sleep(10) + assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime + assert(map1.getTimestamp("k1").isDefined) + assert(map1.getTimestamp("k1").get < threshTime1) + assert(map1.getTimestamp("k2").isDefined) + assert(map1.getTimestamp("k2").get >= threshTime1) + map1.clearOldValues(threshTime1) //should only clear k1 + assert(map1.get("k1") === None) + assert(map1.get("k2").isDefined) + } + + test("TimeStampedWeakValueHashMap - clearing weak references") { + var strongRef = new Object + val weakRef = new WeakReference(strongRef) + val map = new TimeStampedWeakValueHashMap[String, Object] + map("k1") = strongRef + map("k2") = "v2" + map("k3") = "v3" + assert(map("k1") === strongRef) + + // clear strong reference to "k1" + strongRef = null + val startTime = System.currentTimeMillis + System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. + System.runFinalization() // Make a best effort to call finalizer on all cleaned objects. + while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { + System.gc() + System.runFinalization() + Thread.sleep(100) + } + assert(map.getReference("k1").isDefined) + val ref = map.getReference("k1").get + assert(ref.get === null) + assert(map.get("k1") === None) + + // operations should only display non-null entries + assert(map.iterator.forall { case (k, v) => k != "k1" }) + assert(map.filter { case (k, v) => k != "k2" }.size === 1) + assert(map.filter { case (k, v) => k != "k2" }.head._1 === "k3") + assert(map.toMap.size === 2) + assert(map.toMap.forall { case (k, v) => k != "k1" }) + val buffer = new ArrayBuffer[String] + map.foreach { case (k, v) => buffer += v.toString } + assert(buffer.size === 2) + assert(buffer.forall(_ != "k1")) + val plusMap = map + (("k4", "v4")) + assert(plusMap.size === 3) + assert(plusMap.forall { case (k, v) => k != "k1" }) + val minusMap = map - "k2" + assert(minusMap.size === 1) + assert(minusMap.head._1 == "k3") + + // clear null values - should only clear k1 + map.clearNullValues() + assert(map.getReference("k1") === None) + assert(map.get("k1") === None) + assert(map.get("k2").isDefined) + assert(map.get("k2").get === "v2") + } + + /** Test basic operations of a Scala mutable Map. */ + def testMap(hashMapConstructor: => mutable.Map[String, String]) { + def newMap() = hashMapConstructor + val testMap1 = newMap() + val testMap2 = newMap() + val name = testMap1.getClass.getSimpleName + + test(name + " - basic test") { + // put, get, and apply + testMap1 += (("k1", "v1")) + assert(testMap1.get("k1").isDefined) + assert(testMap1.get("k1").get === "v1") + testMap1("k2") = "v2" + assert(testMap1.get("k2").isDefined) + assert(testMap1.get("k2").get === "v2") + assert(testMap1("k2") === "v2") + testMap1.update("k3", "v3") + assert(testMap1.get("k3").isDefined) + assert(testMap1.get("k3").get === "v3") + + // remove + testMap1.remove("k1") + assert(testMap1.get("k1").isEmpty) + testMap1.remove("k2") + intercept[NoSuchElementException] { + testMap1("k2") // Map.apply() causes exception + } + testMap1 -= "k3" + assert(testMap1.get("k3").isEmpty) + + // multi put + val keys = (1 to 100).map(_.toString) + val pairs = keys.map(x => (x, x * 2)) + assert((testMap2 ++ pairs).iterator.toSet === pairs.toSet) + testMap2 ++= pairs + + // iterator + assert(testMap2.iterator.toSet === pairs.toSet) + + // filter + val filtered = testMap2.filter { case (_, v) => v.toInt % 2 == 0 } + val evenPairs = pairs.filter { case (_, v) => v.toInt % 2 == 0 } + assert(filtered.iterator.toSet === evenPairs.toSet) + + // foreach + val buffer = new ArrayBuffer[(String, String)] + testMap2.foreach(x => buffer += x) + assert(testMap2.toSet === buffer.toSet) + + // multi remove + testMap2("k1") = "v1" + testMap2 --= keys + assert(testMap2.size === 1) + assert(testMap2.iterator.toSeq.head === ("k1", "v1")) + + // + + val testMap3 = testMap2 + (("k0", "v0")) + assert(testMap3.size === 2) + assert(testMap3.get("k1").isDefined) + assert(testMap3.get("k1").get === "v1") + assert(testMap3.get("k0").isDefined) + assert(testMap3.get("k0").get === "v0") + + // - + val testMap4 = testMap3 - "k0" + assert(testMap4.size === 1) + assert(testMap4.get("k1").isDefined) + assert(testMap4.get("k1").get === "v1") + } + } + + /** Test thread safety of a Scala mutable map. */ + def testMapThreadSafety(hashMapConstructor: => mutable.Map[String, String]) { + def newMap() = hashMapConstructor + val name = newMap().getClass.getSimpleName + val testMap = newMap() + @volatile var error = false + + def getRandomKey(m: mutable.Map[String, String]): Option[String] = { + val keys = testMap.keysIterator.toSeq + if (keys.nonEmpty) { + Some(keys(Random.nextInt(keys.size))) + } else { + None + } + } + + val threads = (1 to 25).map(i => new Thread() { + override def run() { + try { + for (j <- 1 to 1000) { + Random.nextInt(3) match { + case 0 => + testMap(Random.nextString(10)) = Random.nextDouble().toString // put + case 1 => + getRandomKey(testMap).map(testMap.get) // get + case 2 => + getRandomKey(testMap).map(testMap.remove) // remove + } + } + } catch { + case t: Throwable => + error = true + throw t + } + } + }) + + test(name + " - threading safety test") { + threads.map(_.start) + threads.map(_.join) + assert(!error) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 616214fb5e3a6..eb7fb6318262b 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.util import scala.util.Random -import java.io.{ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream} +import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream} import java.nio.{ByteBuffer, ByteOrder} import com.google.common.base.Charsets @@ -154,5 +154,18 @@ class UtilsSuite extends FunSuite { val iterator = Iterator.range(0, 5) assert(Utils.getIteratorSize(iterator) === 5L) } + + test("findOldFiles") { + // create some temporary directories and files + val parent: File = Utils.createTempDir() + val child1: File = Utils.createTempDir(parent.getCanonicalPath) // The parent directory has two child directories + val child2: File = Utils.createTempDir(parent.getCanonicalPath) + // set the last modified time of child1 to 10 secs old + child1.setLastModified(System.currentTimeMillis() - (1000 * 10)) + + val result = Utils.findOldFiles(parent, 5) // find files older than 5 secs + assert(result.size.equals(1)) + assert(result(0).getCanonicalPath.equals(child1.getCanonicalPath)) + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index fce1184d46364..cdebefb67510c 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -174,9 +174,9 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { assert(result1.toSet == Set[(Int, Int)]((0, 5), (1, 5))) // groupByKey - val result2 = rdd.groupByKey().collect() + val result2 = rdd.groupByKey().collect().map(x => (x._1, x._2.toList)).toSet assert(result2.toSet == Set[(Int, Seq[Int])] - ((0, ArrayBuffer[Int](1, 1, 1, 1, 1)), (1, ArrayBuffer[Int](1, 1, 1, 1, 1)))) + ((0, List[Int](1, 1, 1, 1, 1)), (1, List[Int](1, 1, 1, 1, 1)))) } test("simple cogroup") { diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala index 757476efdb789..39199a1a17ccd 100644 --- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala @@ -29,12 +29,12 @@ class XORShiftRandomSuite extends FunSuite with ShouldMatchers { val xorRand = new XORShiftRandom(seed) val hundMil = 1e8.toInt } - + /* - * This test is based on a chi-squared test for randomness. The values are hard-coded + * This test is based on a chi-squared test for randomness. The values are hard-coded * so as not to create Spark's dependency on apache.commons.math3 just to call one * method for calculating the exact p-value for a given number of random numbers - * and bins. In case one would want to move to a full-fledged test based on + * and bins. In case one would want to move to a full-fledged test based on * apache.commons.math3, the relevant class is here: * org.apache.commons.math3.stat.inference.ChiSquareTest */ @@ -49,19 +49,19 @@ class XORShiftRandomSuite extends FunSuite with ShouldMatchers { // populate bins based on modulus of the random number times(f.hundMil) {bins(math.abs(f.xorRand.nextInt) % 10) += 1} - /* since the seed is deterministic, until the algorithm is changed, we know the result will be - * exactly this: Array(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, - * 10000790, 10002286, 9998699), so the test will never fail at the prespecified (5%) - * significance level. However, should the RNG implementation change, the test should still - * pass at the same significance level. The chi-squared test done in R gave the following + /* since the seed is deterministic, until the algorithm is changed, we know the result will be + * exactly this: Array(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, + * 10000790, 10002286, 9998699), so the test will never fail at the prespecified (5%) + * significance level. However, should the RNG implementation change, the test should still + * pass at the same significance level. The chi-squared test done in R gave the following * results: * > chisq.test(c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, * 10000790, 10002286, 9998699)) * Chi-squared test for given probabilities - * data: c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, 10000790, + * data: c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, 10000790, * 10002286, 9998699) * X-squared = 11.975, df = 9, p-value = 0.2147 - * Note that the p-value was ~0.22. The test will fail if alpha < 0.05, which for 100 million + * Note that the p-value was ~0.22. The test will fail if alpha < 0.05, which for 100 million * random numbers * and 10 bins will happen at X-squared of ~16.9196. So, the test will fail if X-squared * is greater than or equal to that number. diff --git a/dev/audit-release/README.md b/dev/audit-release/README.md index 2437a98672177..38becda0eae92 100644 --- a/dev/audit-release/README.md +++ b/dev/audit-release/README.md @@ -4,7 +4,7 @@ run them locally by setting appropriate environment variables. ``` $ cd sbt_app_core -$ SCALA_VERSION=2.10.3 \ +$ SCALA_VERSION=2.10.4 \ SPARK_VERSION=1.0.0-SNAPSHOT \ SPARK_RELEASE_REPOSITORY=file:///home/patrick/.ivy2/local \ sbt run diff --git a/dev/audit-release/audit_release.py b/dev/audit-release/audit_release.py index 52c367d9b030d..fa2f02dfecc75 100755 --- a/dev/audit-release/audit_release.py +++ b/dev/audit-release/audit_release.py @@ -35,7 +35,7 @@ RELEASE_KEY = "9E4FE3AF" RELEASE_REPOSITORY = "https://repository.apache.org/content/repositories/orgapachespark-1006/" RELEASE_VERSION = "1.0.0" -SCALA_VERSION = "2.10.3" +SCALA_VERSION = "2.10.4" SCALA_BINARY_VERSION = "2.10" ## diff --git a/dev/audit-release/maven_app_core/pom.xml b/dev/audit-release/maven_app_core/pom.xml index 0b837c01751fe..76a381f8e17e0 100644 --- a/dev/audit-release/maven_app_core/pom.xml +++ b/dev/audit-release/maven_app_core/pom.xml @@ -49,7 +49,7 @@ maven-compiler-plugin - 2.3.2 + 3.1 diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index 995106f111443..bf1c5d7953bd2 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -49,14 +49,14 @@ mvn -DskipTests \ -Darguments="-DskipTests=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \ -Dusername=$GIT_USERNAME -Dpassword=$GIT_PASSWORD \ -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ - -Pyarn -Pspark-ganglia-lgpl \ + -Pyarn -Phive -Pspark-ganglia-lgpl\ -Dtag=$GIT_TAG -DautoVersionSubmodules=true \ --batch-mode release:prepare mvn -DskipTests \ -Darguments="-DskipTests=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \ -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ - -Pyarn -Pspark-ganglia-lgpl\ + -Pyarn -Phive -Pspark-ganglia-lgpl\ release:perform rm -rf spark diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index e8f78fc5f231a..7a61943e94814 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -87,11 +87,20 @@ def merge_pr(pr_num, target_ref): run_cmd("git fetch %s %s:%s" % (PUSH_REMOTE_NAME, target_ref, target_branch_name)) run_cmd("git checkout %s" % target_branch_name) - run_cmd(['git', 'merge', pr_branch_name, '--squash']) + had_conflicts = False + try: + run_cmd(['git', 'merge', pr_branch_name, '--squash']) + except Exception as e: + msg = "Error merging: %s\nWould you like to manually fix-up this merge?" % e + continue_maybe(msg) + msg = "Okay, please fix any conflicts and 'git add' conflicting files... Finished?" + continue_maybe(msg) + had_conflicts = True commit_authors = run_cmd(['git', 'log', 'HEAD..%s' % pr_branch_name, '--pretty=format:%an <%ae>']).split("\n") - distinct_authors = sorted(set(commit_authors), key=lambda x: commit_authors.count(x), reverse=True) + distinct_authors = sorted(set(commit_authors), key=lambda x: commit_authors.count(x), + reverse=True) primary_author = distinct_authors[0] commits = run_cmd(['git', 'log', 'HEAD..%s' % pr_branch_name, '--pretty=format:%h [%an] %s']).split("\n\n") @@ -105,6 +114,13 @@ def merge_pr(pr_num, target_ref): merge_message_flags += ["-m", authors] + if had_conflicts: + committer_name = run_cmd("git config --get user.name").strip() + committer_email = run_cmd("git config --get user.email").strip() + message = "This patch had conflicts when merged, resolved by\nCommitter: %s <%s>" % ( + committer_name, committer_email) + merge_message_flags += ["-m", message] + # The string "Closes #%s" string is required for GitHub to correctly close the PR merge_message_flags += ["-m", "Closes #%s from %s and squashes the following commits:" % (pr_num, pr_repo_desc)] @@ -186,8 +202,10 @@ def maybe_cherry_pick(pr_num, merge_hash, default_branch): maybe_cherry_pick(pr_num, merge_hash, latest_branch) sys.exit(0) -if bool(pr["mergeable"]) == False: - fail("Pull request %s is not mergeable in its current form" % pr_num) +if not bool(pr["mergeable"]): + msg = "Pull request %s is not mergeable in its current form.\n" % pr_num + \ + "Continue? (experts only!)" + continue_maybe(msg) print ("\n=== Pull Request #%s ===" % pr_num) print("title\t%s\nsource\t%s\ntarget\t%s\nurl\t%s" % ( diff --git a/dev/run-tests b/dev/run-tests index a6fcc40a5ba6e..6ad674a2ba127 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -26,13 +26,12 @@ rm -rf ./work # Fail fast set -e - +set -o pipefail if test -x "$JAVA_HOME/bin/java"; then declare java_cmd="$JAVA_HOME/bin/java" else declare java_cmd=java fi - JAVA_VERSION=$($java_cmd -version 2>&1 | sed 's/java version "\(.*\)\.\(.*\)\..*"/\1\2/; 1q') [ "$JAVA_VERSION" -ge 18 ] && echo "" || echo "[Warn] Java 8 tests will not run because JDK version is < 1.8." @@ -49,7 +48,9 @@ dev/scalastyle echo "=========================================================================" echo "Running Spark unit tests" echo "=========================================================================" -sbt/sbt assembly test | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" +# echo "q" is needed because sbt on encountering a build file with failure (either resolution or compilation) +# prompts the user for input either q, r, etc to quit or retry. This echo is there to make it not block. +echo -e "q\n" | sbt/sbt assembly test | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including" echo "=========================================================================" echo "Running PySpark tests" @@ -63,5 +64,4 @@ echo "=========================================================================" echo "Detecting binary incompatibilites with MiMa" echo "=========================================================================" ./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore -sbt/sbt mima-report-binary-issues | grep -v -e "info.*Resolving" - +echo -e "q\n" | sbt/sbt mima-report-binary-issues | grep -v -e "info.*Resolving" diff --git a/dev/scalastyle b/dev/scalastyle index 5a18f4d672825..19955b9aaaad3 100755 --- a/dev/scalastyle +++ b/dev/scalastyle @@ -17,8 +17,8 @@ # limitations under the License. # -sbt/sbt clean scalastyle > scalastyle.txt -ERRORS=$(cat scalastyle.txt | grep -e "error file") +echo -e "q\n" | sbt/sbt clean scalastyle > scalastyle.txt +ERRORS=$(cat scalastyle.txt | grep -e "\") if test ! -z "$ERRORS"; then echo -e "Scalastyle checks failed at following occurrences:\n$ERRORS" exit 1 diff --git a/docker/spark-test/base/Dockerfile b/docker/spark-test/base/Dockerfile index e543db6143e4d..5956d59130fbf 100644 --- a/docker/spark-test/base/Dockerfile +++ b/docker/spark-test/base/Dockerfile @@ -25,7 +25,7 @@ RUN apt-get update # install a few other useful packages plus Open Jdk 7 RUN apt-get install -y less openjdk-7-jre-headless net-tools vim-tiny sudo openssh-server -ENV SCALA_VERSION 2.10.3 +ENV SCALA_VERSION 2.10.4 ENV CDH_VERSION cdh4 ENV SCALA_HOME /opt/scala-$SCALA_VERSION ENV SPARK_HOME /opt/spark diff --git a/docs/_config.yml b/docs/_config.yml index aa5a5adbc1743..d585b8c5ea763 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -6,7 +6,7 @@ markdown: kramdown SPARK_VERSION: 1.0.0-SNAPSHOT SPARK_VERSION_SHORT: 1.0.0 SCALA_BINARY_VERSION: "2.10" -SCALA_VERSION: "2.10.3" +SCALA_VERSION: "2.10.4" MESOS_VERSION: 0.13.0 SPARK_ISSUE_TRACKER_URL: https://spark-project.atlassian.net SPARK_GITHUB_URL: https://github.com/apache/spark diff --git a/docs/_plugins/copy_api_dirs.rb b/docs/_plugins/copy_api_dirs.rb index 2245bcbc70f1e..bbd56d2fd13bb 100644 --- a/docs/_plugins/copy_api_dirs.rb +++ b/docs/_plugins/copy_api_dirs.rb @@ -24,7 +24,9 @@ external_projects = ["flume", "kafka", "mqtt", "twitter", "zeromq"] sql_projects = ["catalyst", "core", "hive"] - projects = core_projects + external_projects.map { |project_name| "external/" + project_name } + projects = core_projects + projects = projects + external_projects.map { |project_name| "external/" + project_name } + projects = projects + sql_projects.map { |project_name| "sql/" + project_name } puts "Moving to project root and building scaladoc." curr_dir = pwd @@ -42,24 +44,22 @@ source = "../" + project_name + "/target/scala-2.10/api" dest = "api/" + project_name - puts "echo making directory " + dest + puts "making directory " + dest mkdir_p dest # From the rubydoc: cp_r('src', 'dest') makes src/dest, but this doesn't. puts "cp -r " + source + "/. " + dest cp_r(source + "/.", dest) - end - - sql_projects.each do |project_name| - source = "../sql/" + project_name + "/target/scala-2.10/api/" - dest = "api/sql/" + project_name - puts "echo making directory " + dest - mkdir_p dest + # Append custom JavaScript + js = File.readlines("./js/api-docs.js") + js_file = dest + "/lib/template.js" + File.open(js_file, 'a') { |f| f.write("\n" + js.join()) } - # From the rubydoc: cp_r('src', 'dest') makes src/dest, but this doesn't. - puts "cp -r " + source + "/. " + dest - cp_r(source + "/.", dest) + # Append custom CSS + css = File.readlines("./css/api-docs.css") + css_file = dest + "/lib/template.css" + File.open(css_file, 'a') { |f| f.write("\n" + css.join()) } end # Build Epydoc for Python diff --git a/docs/building-with-maven.md b/docs/building-with-maven.md index 730a6e7932564..9cebaf12283fc 100644 --- a/docs/building-with-maven.md +++ b/docs/building-with-maven.md @@ -6,7 +6,7 @@ title: Building Spark with Maven * This will become a table of contents (this text will be scraped). {:toc} -Building Spark using Maven Requires Maven 3 (the build process is tested with Maven 3.0.4) and Java 1.6 or newer. +Building Spark using Maven requires Maven 3.0.4 or newer and Java 1.6 or newer. ## Setting up Maven's Memory Usage ## diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index b69e3416fb322..7f75ea44e4cea 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -56,7 +56,7 @@ The recommended way to launch a compiled Spark application is through the spark- bin directory), which takes care of setting up the classpath with Spark and its dependencies, as well as provides a layer over the different cluster managers and deploy modes that Spark supports. It's usage is - spark-submit `` `` + spark-submit `` `` Where options are any of: diff --git a/docs/configuration.md b/docs/configuration.md index 1ff0150567255..f3bfd036f4164 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -122,6 +122,21 @@ Apart from these, the following properties are also available, and may be useful spark.storage.memoryFraction. + + spark.tachyonStore.baseDir + System.getProperty("java.io.tmpdir") + + Directories of the Tachyon File System that store RDDs. The Tachyon file system's URL is set by spark.tachyonStore.url. + It can also be a comma-separated list of multiple directories on Tachyon file system. + + + + spark.tachyonStore.url + tachyon://localhost:19998 + + The URL of the underlying Tachyon file system in the TachyonStore. + + spark.mesos.coarse false @@ -161,13 +176,13 @@ Apart from these, the following properties are also available, and may be useful spark.ui.acls.enable false - Whether spark web ui acls should are enabled. If enabled, this checks to see if the user has + Whether spark web ui acls should are enabled. If enabled, this checks to see if the user has access permissions to view the web ui. See spark.ui.view.acls for more details. Also note this requires the user to be known, if the user comes across as null no checks are done. Filters can be used to authenticate and set the user. - + spark.ui.view.acls Empty @@ -175,6 +190,13 @@ Apart from these, the following properties are also available, and may be useful user that started the Spark job has view access. + + spark.ui.killEnabled + true + + Allows stages and corresponding jobs to be killed from the web ui. + + spark.shuffle.compress true @@ -276,10 +298,10 @@ Apart from these, the following properties are also available, and may be useful spark.serializer.objectStreamReset 10000 - When serializing using org.apache.spark.serializer.JavaSerializer, the serializer caches - objects to prevent writing redundant data, however that stops garbage collection of those - objects. By calling 'reset' you flush that info from the serializer, and allow old - objects to be collected. To turn off this periodic reset set it to a value of <= 0. + When serializing using org.apache.spark.serializer.JavaSerializer, the serializer caches + objects to prevent writing redundant data, however that stops garbage collection of those + objects. By calling 'reset' you flush that info from the serializer, and allow old + objects to be collected. To turn off this periodic reset set it to a value of <= 0. By default it will reset the serializer every 10,000 objects. @@ -333,6 +355,32 @@ Apart from these, the following properties are also available, and may be useful receives no heartbeats. + + spark.worker.cleanup.enabled + true + + Enable periodic cleanup of worker / application directories. Note that this only affects standalone + mode, as YARN works differently. + + + + spark.worker.cleanup.interval + 1800 (30 minutes) + + Controls the interval, in seconds, at which the worker cleans up old application work dirs + on the local machine. + + + + spark.worker.cleanup.appDataTtl + 7 * 24 * 3600 (7 days) + + The number of seconds to retain application work directories on each worker. This is a Time To Live + and should depend on the amount of available disk space you have. Application logs and jars are + downloaded to each application work dir. Over time, the work dirs can quickly fill up disk space, + especially if you run jobs very frequently. + + spark.akka.frameSize 10 @@ -375,7 +423,7 @@ Apart from these, the following properties are also available, and may be useful spark.akka.heartbeat.interval 1000 - This is set to a larger value to disable failure detector that comes inbuilt akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger interval value in seconds reduces network overhead and a smaller value ( ~ 1 s) might be more informative for akka's failure detector. Tune this in combination of `spark.akka.heartbeat.pauses` and `spark.akka.failure-detector.threshold` if you need to. Only positive use case for using failure detector can be, a sensistive failure detector can help evict rogue executors really quick. However this is usually not the case as gc pauses and network lags are expected in a real spark cluster. Apart from that enabling this leads to a lot of exchanges of heart beats between nodes leading to flooding the network with those. + This is set to a larger value to disable failure detector that comes inbuilt akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger interval value in seconds reduces network overhead and a smaller value ( ~ 1 s) might be more informative for akka's failure detector. Tune this in combination of `spark.akka.heartbeat.pauses` and `spark.akka.failure-detector.threshold` if you need to. Only positive use case for using failure detector can be, a sensistive failure detector can help evict rogue executors really quick. However this is usually not the case as gc pauses and network lags are expected in a real spark cluster. Apart from that enabling this leads to a lot of exchanges of heart beats between nodes leading to flooding the network with those. @@ -430,7 +478,7 @@ Apart from these, the following properties are also available, and may be useful spark.broadcast.blockSize 4096 - Size of each piece of a block in kilobytes for TorrentBroadcastFactory. + Size of each piece of a block in kilobytes for TorrentBroadcastFactory. Too large a value decreases parallelism during broadcast (makes it slower); however, if it is too small, BlockManager might take a performance hit. @@ -555,7 +603,16 @@ Apart from these, the following properties are also available, and may be useful the driver. - + + spark.files.userClassPathFirst + false + + (Experimental) Whether to give user-added jars precedence over Spark's own jars when + loading classes in Executors. This feature can be used to mitigate conflicts between + Spark's dependencies and user dependencies. It is currently an experimental feature. + + + spark.authenticate false @@ -563,7 +620,7 @@ Apart from these, the following properties are also available, and may be useful running on Yarn. - + spark.authenticate.secret None @@ -571,12 +628,12 @@ Apart from these, the following properties are also available, and may be useful not running on Yarn and authentication is enabled. - + spark.core.connection.auth.wait.timeout 30 Number of seconds for the connection to wait for authentication to occur before timing - out and giving up. + out and giving up. diff --git a/docs/css/api-docs.css b/docs/css/api-docs.css new file mode 100644 index 0000000000000..b2d1d7f869790 --- /dev/null +++ b/docs/css/api-docs.css @@ -0,0 +1,18 @@ +/* Dynamically injected style for the API docs */ + +.developer { + background-color: #44751E; +} + +.experimental { + background-color: #257080; +} + +.alphaComponent { + background-color: #bb0000; +} + +.badge { + font-family: Arial, san-serif; + float: right; +} diff --git a/docs/index.md b/docs/index.md index 7a13fa9a9a2b6..89ec5b05488a9 100644 --- a/docs/index.md +++ b/docs/index.md @@ -67,8 +67,6 @@ In addition, if you wish to run Spark on [YARN](running-on-yarn.html), set Note that on Windows, you need to set the environment variables on separate lines, e.g., `set SPARK_HADOOP_VERSION=1.2.1`. -For this version of Spark (0.8.1) Hadoop 2.2.x (or newer) users will have to build Spark and publish it locally. See [Launching Spark on YARN](running-on-yarn.html). This is needed because Hadoop 2.2 has non backwards compatible API changes. - # Where to Go from Here **Programming guides:** diff --git a/docs/js/api-docs.js b/docs/js/api-docs.js new file mode 100644 index 0000000000000..1414b6d0b81a1 --- /dev/null +++ b/docs/js/api-docs.js @@ -0,0 +1,26 @@ +/* Dynamically injected post-processing code for the API docs */ + +$(document).ready(function() { + var annotations = $("dt:contains('Annotations')").next("dd").children("span.name"); + addBadges(annotations, "AlphaComponent", ":: AlphaComponent ::", "Alpha Component"); + addBadges(annotations, "DeveloperApi", ":: DeveloperApi ::", "Developer API"); + addBadges(annotations, "Experimental", ":: Experimental ::", "Experimental"); +}); + +function addBadges(allAnnotations, name, tag, html) { + var annotations = allAnnotations.filter(":contains('" + name + "')") + var tags = $(".cmt:contains(" + tag + ")") + + // Remove identifier tags from comments + tags.each(function(index) { + var oldHTML = $(this).html(); + var newHTML = oldHTML.replace(tag, ""); + $(this).html(newHTML); + }); + + // Add badges to all containers + tags.prevAll("h4.signature") + .add(annotations.closest("div.fullcommenttop")) + .add(annotations.closest("div.fullcomment").prevAll("h4.signature")) + .prepend(html); +} diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md index 203d235bf9663..a5e0cc50809cf 100644 --- a/docs/mllib-guide.md +++ b/docs/mllib-guide.md @@ -38,6 +38,5 @@ depends on native Fortran routines. You may need to install the if it is not already present on your nodes. MLlib will throw a linking error if it cannot detect these libraries automatically. -To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.7 or newer -and Python 2.7. +To use MLlib in Python, you will need [NumPy](http://www.numpy.org) version 1.7 or newer. diff --git a/docs/monitoring.md b/docs/monitoring.md index 15bfb041780da..4c91c3a5929bf 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -12,17 +12,77 @@ displays useful information about the application. This includes: * A list of scheduler stages and tasks * A summary of RDD sizes and memory usage -* Information about the running executors * Environmental information. +* Information about the running executors You can access this interface by simply opening `http://:4040` in a web browser. -If multiple SparkContexts are running on the same host, they will bind to succesive ports +If multiple SparkContexts are running on the same host, they will bind to successive ports beginning with 4040 (4041, 4042, etc). -Spark's Standalone Mode cluster manager also has its own -[web UI](spark-standalone.html#monitoring-and-logging). +Note that this information is only available for the duration of the application by default. +To view the web UI after the fact, set `spark.eventLog.enabled` to true before starting the +application. This configures Spark to log Spark events that encode the information displayed +in the UI to persisted storage. -Note that in both of these UIs, the tables are sortable by clicking their headers, +## Viewing After the Fact + +Spark's Standalone Mode cluster manager also has its own +[web UI](spark-standalone.html#monitoring-and-logging). If an application has logged events over +the course of its lifetime, then the Standalone master's web UI will automatically re-render the +application's UI after the application has finished. + +If Spark is run on Mesos or YARN, it is still possible to reconstruct the UI of a finished +application through Spark's history server, provided that the application's event logs exist. +You can start a the history server by executing: + + ./sbin/start-history-server.sh + +The base logging directory must be supplied, and should contain sub-directories that each +represents an application's event logs. This creates a web interface at +`http://:18080` by default. The history server depends on the following variables: + + + + + + + + + + + +
      Environment VariableMeaning
      SPARK_DAEMON_MEMORYMemory to allocate to the history server. (default: 512m).
      SPARK_DAEMON_JAVA_OPTSJVM options for the history server (default: none).
      + +Further, the history server can be configured as follows: + + + + + + + + + + + + + + + + + + +
      Property NameDefaultMeaning
      spark.history.updateInterval10 + The period, in seconds, at which information displayed by this history server is updated. + Each update checks for any changes made to the event logs in persisted storage. +
      spark.history.retainedApplications250 + The number of application UIs to retain. If this cap is exceeded, then the oldest + applications will be removed. +
      spark.history.ui.port18080 + The port to which the web interface of the history server binds. +
      + +Note that in all of these UIs, the tables are sortable by clicking their headers, making it easy to identify slow tasks, data skew, etc. # Metrics diff --git a/docs/python-programming-guide.md b/docs/python-programming-guide.md index cbe7d820b455e..888631e7025b0 100644 --- a/docs/python-programming-guide.md +++ b/docs/python-programming-guide.md @@ -82,15 +82,16 @@ The Python shell can be used explore data interactively and is a simple way to l >>> help(pyspark) # Show all pyspark functions {% endhighlight %} -By default, the `bin/pyspark` shell creates SparkContext that runs applications locally on a single core. -To connect to a non-local cluster, or use multiple cores, set the `MASTER` environment variable. +By default, the `bin/pyspark` shell creates SparkContext that runs applications locally on all of +your machine's logical cores. +To connect to a non-local cluster, or to specify a number of cores, set the `MASTER` environment variable. For example, to use the `bin/pyspark` shell with a [standalone Spark cluster](spark-standalone.html): {% highlight bash %} $ MASTER=spark://IP:PORT ./bin/pyspark {% endhighlight %} -Or, to use four cores on the local machine: +Or, to use exactly four cores on the local machine: {% highlight bash %} $ MASTER=local[4] ./bin/pyspark @@ -152,7 +153,7 @@ Many of the methods also contain [doctests](http://docs.python.org/2/library/doc # Libraries [MLlib](mllib-guide.html) is also available in PySpark. To use it, you'll need -[NumPy](http://www.numpy.org) version 1.7 or newer, and Python 2.7. The [MLlib guide](mllib-guide.html) contains +[NumPy](http://www.numpy.org) version 1.7 or newer. The [MLlib guide](mllib-guide.html) contains some example applications. # Where to Go from Here diff --git a/docs/quick-start.md b/docs/quick-start.md index 13df6beea16e8..60e8b1ba0eb46 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -124,7 +124,7 @@ object SimpleApp { } {% endhighlight %} -This program just counts the number of lines containing 'a' and the number containing 'b' in the Spark README. Note that you'll need to replace $YOUR_SPARK_HOME with the location where Spark is installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkContext, we initialize a SparkContext as part of the proogram. We pass the SparkContext constructor four arguments, the type of scheduler we want to use (in this case, a local scheduler), a name for the application, the directory where Spark is installed, and a name for the jar file containing the application's code. The final two arguments are needed in a distributed setting, where Spark is running across several nodes, so we include them for completeness. Spark will automatically ship the jar files you list to slave nodes. +This program just counts the number of lines containing 'a' and the number containing 'b' in the Spark README. Note that you'll need to replace $YOUR_SPARK_HOME with the location where Spark is installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkContext, we initialize a SparkContext as part of the program. We pass the SparkContext constructor four arguments, the type of scheduler we want to use (in this case, a local scheduler), a name for the application, the directory where Spark is installed, and a name for the jar file containing the application's code. The final two arguments are needed in a distributed setting, where Spark is running across several nodes, so we include them for completeness. Spark will automatically ship the jar files you list to slave nodes. This file depends on the Spark API, so we'll also include an sbt configuration file, `simple.sbt` which explains that Spark is a dependency. This file also adds a repository that Spark depends on: diff --git a/docs/scala-programming-guide.md b/docs/scala-programming-guide.md index 99412733d4268..a07cd2e0a32a2 100644 --- a/docs/scala-programming-guide.md +++ b/docs/scala-programming-guide.md @@ -23,7 +23,7 @@ To write a Spark application, you need to add a dependency on Spark. If you use groupId = org.apache.spark artifactId = spark-core_{{site.SCALA_BINARY_VERSION}} - version = {{site.SPARK_VERSION}} + version = {{site.SPARK_VERSION}} In addition, if you wish to access an HDFS cluster, you need to add a dependency on `hadoop-client` for your version of HDFS: @@ -54,7 +54,7 @@ object for more advanced configuration. The `master` parameter is a string specifying a [Spark or Mesos cluster URL](#master-urls) to connect to, or a special "local" string to run in local mode, as described below. `appName` is a name for your application, which will be shown in the cluster web UI. Finally, the last two parameters are needed to deploy your code to a cluster if running in distributed mode, as described later. -In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `MASTER` environment variable, and you can add JARs to the classpath with the `ADD_JARS` variable. For example, to run `bin/spark-shell` on four cores, use +In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `MASTER` environment variable, and you can add JARs to the classpath with the `ADD_JARS` variable. For example, to run `bin/spark-shell` on exactly four cores, use {% highlight bash %} $ MASTER=local[4] ./bin/spark-shell @@ -73,18 +73,19 @@ The master URL passed to Spark can be in one of the following formats: - - -
      Master URLMeaning
      local Run Spark locally with one worker thread (i.e. no parallelism at all).
      local[K] Run Spark locally with K worker threads (ideally, set this to the number of cores on your machine). +
      local[K] Run Spark locally with K worker threads (ideally, set this to the number of cores on your machine). +
      local[*] Run Spark locally with as many worker threads as logical cores on your machine.
      spark://HOST:PORT Connect to the given Spark standalone - cluster master. The port must be whichever one your master is configured to use, which is 7077 by default. +
      spark://HOST:PORT Connect to the given Spark standalone + cluster master. The port must be whichever one your master is configured to use, which is 7077 by default.
      mesos://HOST:PORT Connect to the given Mesos cluster. - The host parameter is the hostname of the Mesos master. The port must be whichever one the master is configured to use, - which is 5050 by default. +
      mesos://HOST:PORT Connect to the given Mesos cluster. + The host parameter is the hostname of the Mesos master. The port must be whichever one the master is configured to use, + which is 5050 by default.
      -If no master URL is specified, the spark shell defaults to "local". +If no master URL is specified, the spark shell defaults to "local[*]". For running on YARN, Spark launches an instance of the standalone deploy cluster within YARN; see [running on YARN](running-on-yarn.html) for details. @@ -265,11 +266,25 @@ A complete list of actions is available in the [RDD API doc](api/core/index.html ## RDD Persistence -One of the most important capabilities in Spark is *persisting* (or *caching*) a dataset in memory across operations. When you persist an RDD, each node stores any slices of it that it computes in memory and reuses them in other actions on that dataset (or datasets derived from it). This allows future actions to be much faster (often by more than 10x). Caching is a key tool for building iterative algorithms with Spark and for interactive use from the interpreter. - -You can mark an RDD to be persisted using the `persist()` or `cache()` methods on it. The first time it is computed in an action, it will be kept in memory on the nodes. The cache is fault-tolerant -- if any partition of an RDD is lost, it will automatically be recomputed using the transformations that originally created it. - -In addition, each RDD can be stored using a different *storage level*, allowing you, for example, to persist the dataset on disk, or persist it in memory but as serialized Java objects (to save space), or even replicate it across nodes. These levels are chosen by passing a [`org.apache.spark.storage.StorageLevel`](api/core/index.html#org.apache.spark.storage.StorageLevel) object to `persist()`. The `cache()` method is a shorthand for using the default storage level, which is `StorageLevel.MEMORY_ONLY` (store deserialized objects in memory). The complete set of available storage levels is: +One of the most important capabilities in Spark is *persisting* (or *caching*) a dataset in memory +across operations. When you persist an RDD, each node stores any slices of it that it computes in +memory and reuses them in other actions on that dataset (or datasets derived from it). This allows +future actions to be much faster (often by more than 10x). Caching is a key tool for building +iterative algorithms with Spark and for interactive use from the interpreter. + +You can mark an RDD to be persisted using the `persist()` or `cache()` methods on it. The first time +it is computed in an action, it will be kept in memory on the nodes. The cache is fault-tolerant -- +if any partition of an RDD is lost, it will automatically be recomputed using the transformations +that originally created it. + +In addition, each RDD can be stored using a different *storage level*, allowing you, for example, to +persist the dataset on disk, or persist it in memory but as serialized Java objects (to save space), +or replicate it across nodes, or store the data in off-heap memory in [Tachyon](http://tachyon-project.org/). +These levels are chosen by passing a +[`org.apache.spark.storage.StorageLevel`](api/core/index.html#org.apache.spark.storage.StorageLevel) +object to `persist()`. The `cache()` method is a shorthand for using the default storage level, +which is `StorageLevel.MEMORY_ONLY` (store deserialized objects in memory). The complete set of +available storage levels is: @@ -292,8 +307,16 @@ In addition, each RDD can be stored using a different *storage level*, allowing - + + + + + @@ -307,30 +330,59 @@ In addition, each RDD can be stored using a different *storage level*, allowing ### Which Storage Level to Choose? -Spark's storage levels are meant to provide different tradeoffs between memory usage and CPU efficiency. -We recommend going through the following process to select one: - -* If your RDDs fit comfortably with the default storage level (`MEMORY_ONLY`), leave them that way. This is the most - CPU-efficient option, allowing operations on the RDDs to run as fast as possible. -* If not, try using `MEMORY_ONLY_SER` and [selecting a fast serialization library](tuning.html) to make the objects - much more space-efficient, but still reasonably fast to access. -* Don't spill to disk unless the functions that computed your datasets are expensive, or they filter a large - amount of the data. Otherwise, recomputing a partition is about as fast as reading it from disk. -* Use the replicated storage levels if you want fast fault recovery (e.g. if using Spark to serve requests from a web - application). *All* the storage levels provide full fault tolerance by recomputing lost data, but the replicated ones - let you continue running tasks on the RDD without waiting to recompute a lost partition. - -If you want to define your own storage level (say, with replication factor of 3 instead of 2), then use the function factor method `apply()` of the [`StorageLevel`](api/core/index.html#org.apache.spark.storage.StorageLevel$) singleton object. +Spark's storage levels are meant to provide different trade-offs between memory usage and CPU +efficiency. It allows uses to choose memory, disk, or Tachyon for storing data. We recommend going +through the following process to select one: + +* If your RDDs fit comfortably with the default storage level (`MEMORY_ONLY`), leave them that way. + This is the most CPU-efficient option, allowing operations on the RDDs to run as fast as possible. + +* If not, try using `MEMORY_ONLY_SER` and [selecting a fast serialization library](tuning.html) to +make the objects much more space-efficient, but still reasonably fast to access. You can also use +`OFF_HEAP` mode to store the data off the heap in [Tachyon](http://tachyon-project.org/). This will +significantly reduce JVM GC overhead. + +* Don't spill to disk unless the functions that computed your datasets are expensive, or they filter +a large amount of the data. Otherwise, recomputing a partition is about as fast as reading it from +disk. + +* Use the replicated storage levels if you want fast fault recovery (e.g. if using Spark to serve +requests from a web application). *All* the storage levels provide full fault tolerance by +recomputing lost data, but the replicated ones let you continue running tasks on the RDD without +waiting to recompute a lost partition. + +If you want to define your own storage level (say, with replication factor of 3 instead of 2), then +use the function factor method `apply()` of the +[`StorageLevel`](api/core/index.html#org.apache.spark.storage.StorageLevel$) singleton object. + +Spark has a block manager inside the Executors that let you chose memory, disk, or off-heap. The +latter is for storing RDDs off-heap outside the Executor JVM on top of the memory management system +[Tachyon](http://tachyon-project.org/). This mode has the following advantages: + +* Cached data will not be lost if individual executors crash. +* Executors can have a smaller memory footprint, allowing you to run more executors on the same +machine as the bulk of the memory will be inside Tachyon. +* Reduced GC overhead since data is stored in Tachyon. # Shared Variables -Normally, when a function passed to a Spark operation (such as `map` or `reduce`) is executed on a remote cluster node, it works on separate copies of all the variables used in the function. These variables are copied to each machine, and no updates to the variables on the remote machine are propagated back to the driver program. Supporting general, read-write shared variables across tasks would be inefficient. However, Spark does provide two limited types of *shared variables* for two common usage patterns: broadcast variables and accumulators. +Normally, when a function passed to a Spark operation (such as `map` or `reduce`) is executed on a +remote cluster node, it works on separate copies of all the variables used in the function. These +variables are copied to each machine, and no updates to the variables on the remote machine are +propagated back to the driver program. Supporting general, read-write shared variables across tasks +would be inefficient. However, Spark does provide two limited types of *shared variables* for two +common usage patterns: broadcast variables and accumulators. ## Broadcast Variables -Broadcast variables allow the programmer to keep a read-only variable cached on each machine rather than shipping a copy of it with tasks. They can be used, for example, to give every node a copy of a large input dataset in an efficient manner. Spark also attempts to distribute broadcast variables using efficient broadcast algorithms to reduce communication cost. +Broadcast variables allow the programmer to keep a read-only variable cached on each machine rather +than shipping a copy of it with tasks. They can be used, for example, to give every node a copy of a +large input dataset in an efficient manner. Spark also attempts to distribute broadcast variables +using efficient broadcast algorithms to reduce communication cost. -Broadcast variables are created from a variable `v` by calling `SparkContext.broadcast(v)`. The broadcast variable is a wrapper around `v`, and its value can be accessed by calling the `value` method. The interpreter session below shows this: +Broadcast variables are created from a variable `v` by calling `SparkContext.broadcast(v)`. The +broadcast variable is a wrapper around `v`, and its value can be accessed by calling the `value` +method. The interpreter session below shows this: {% highlight scala %} scala> val broadcastVar = sc.broadcast(Array(1, 2, 3)) @@ -340,13 +392,21 @@ scala> broadcastVar.value res0: Array[Int] = Array(1, 2, 3) {% endhighlight %} -After the broadcast variable is created, it should be used instead of the value `v` in any functions run on the cluster so that `v` is not shipped to the nodes more than once. In addition, the object `v` should not be modified after it is broadcast in order to ensure that all nodes get the same value of the broadcast variable (e.g. if the variable is shipped to a new node later). +After the broadcast variable is created, it should be used instead of the value `v` in any functions +run on the cluster so that `v` is not shipped to the nodes more than once. In addition, the object +`v` should not be modified after it is broadcast in order to ensure that all nodes get the same +value of the broadcast variable (e.g. if the variable is shipped to a new node later). ## Accumulators -Accumulators are variables that are only "added" to through an associative operation and can therefore be efficiently supported in parallel. They can be used to implement counters (as in MapReduce) or sums. Spark natively supports accumulators of numeric value types and standard mutable collections, and programmers can add support for new types. +Accumulators are variables that are only "added" to through an associative operation and can +therefore be efficiently supported in parallel. They can be used to implement counters (as in +MapReduce) or sums. Spark natively supports accumulators of numeric value types and standard mutable +collections, and programmers can add support for new types. -An accumulator is created from an initial value `v` by calling `SparkContext.accumulator(v)`. Tasks running on the cluster can then add to it using the `+=` operator. However, they cannot read its value. Only the driver program can read the accumulator's value, using its `value` method. +An accumulator is created from an initial value `v` by calling `SparkContext.accumulator(v)`. Tasks +running on the cluster can then add to it using the `+=` operator. However, they cannot read its +value. Only the driver program can read the accumulator's value, using its `value` method. The interpreter session below shows an accumulator being used to add up the elements of an array: diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index b6f21a5dc62c3..a59393e1424de 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -8,6 +8,10 @@ title: Spark SQL Programming Guide {:toc} # Overview + +
      +
      + Spark SQL allows relational queries expressed in SQL, HiveQL, or Scala to be executed using Spark. At the core of this component is a new type of RDD, [SchemaRDD](api/sql/core/index.html#org.apache.spark.sql.SchemaRDD). SchemaRDDs are composed @@ -18,11 +22,27 @@ file, or by running HiveQL against data stored in [Apache Hive](http://hive.apac **All of the examples on this page use sample data included in the Spark distribution and can be run in the spark-shell.** +
      + +
      +Spark SQL allows relational queries expressed in SQL, HiveQL, or Scala to be executed using +Spark. At the core of this component is a new type of RDD, +[JavaSchemaRDD](api/sql/core/index.html#org.apache.spark.sql.api.java.JavaSchemaRDD). JavaSchemaRDDs are composed +[Row](api/sql/catalyst/index.html#org.apache.spark.sql.api.java.Row) objects along with +a schema that describes the data types of each column in the row. A JavaSchemaRDD is similar to a table +in a traditional relational database. A JavaSchemaRDD can be created from an existing RDD, parquet +file, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/). +
      +
      + *************************************************************************************************** # Getting Started -The entry point into all relational functionallity in Spark is the +
      +
      + +The entry point into all relational functionality in Spark is the [SQLContext](api/sql/core/index.html#org.apache.spark.sql.SQLContext) class, or one of its decendents. To create a basic SQLContext, all you need is a SparkContext. @@ -34,8 +54,30 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc) import sqlContext._ {% endhighlight %} +
      + +
      + +The entry point into all relational functionality in Spark is the +[JavaSQLContext](api/sql/core/index.html#org.apache.spark.sql.api.java.JavaSQLContext) class, or one +of its decendents. To create a basic JavaSQLContext, all you need is a JavaSparkContext. + +{% highlight java %} +JavaSparkContext ctx = ...; // An existing JavaSparkContext. +JavaSQLContext sqlCtx = new org.apache.spark.sql.api.java.JavaSQLContext(ctx); +{% endhighlight %} + +
      + +
      + ## Running SQL on RDDs -One type of table that is supported by Spark SQL is an RDD of Scala case classetees. The case class + +
      + +
      + +One type of table that is supported by Spark SQL is an RDD of Scala case classes. The case class defines the schema of the table. The names of the arguments to the case class are read using reflection and become the names of the columns. Case classes can also be nested or contain complex types such as Sequences or Arrays. This RDD can be implicitly converted to a SchemaRDD and then be @@ -60,7 +102,83 @@ val teenagers = sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") teenagers.map(t => "Name: " + t(0)).collect().foreach(println) {% endhighlight %} -**Note that Spark SQL currently uses a very basic SQL parser, and the keywords are case sensitive.** +
      + +
      + +One type of table that is supported by Spark SQL is an RDD of [JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly). The BeanInfo +defines the schema of the table. Currently, Spark SQL does not support JavaBeans that contain +nested or contain complex types such as Lists or Arrays. You can create a JavaBean by creating a +class that implements Serializable and has getters and setters for all of its fields. + +{% highlight java %} + +public static class Person implements Serializable { + private String name; + private int age; + + String getName() { + return name; + } + + void setName(String name) { + this.name = name; + } + + int getAge() { + return age; + } + + void setAge(int age) { + this.age = age; + } +} + +{% endhighlight %} + + +A schema can be applied to an existing RDD by calling `applySchema` and providing the Class object +for the JavaBean. + +{% highlight java %} +JavaSQLContext ctx = new org.apache.spark.sql.api.java.JavaSQLContext(sc) + +// Load a text file and convert each line to a JavaBean. +JavaRDD people = ctx.textFile("examples/src/main/resources/people.txt").map( + new Function() { + public Person call(String line) throws Exception { + String[] parts = line.split(","); + + Person person = new Person(); + person.setName(parts[0]); + person.setAge(Integer.parseInt(parts[1].trim())); + + return person; + } + }); + +// Apply a schema to an RDD of JavaBeans and register it as a table. +JavaSchemaRDD schemaPeople = sqlCtx.applySchema(people, Person.class); +schemaPeople.registerAsTable("people"); + +// SQL can be run over RDDs that have been registered as tables. +JavaSchemaRDD teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") + +// The results of SQL queries are SchemaRDDs and support all the normal RDD operations. +// The columns of a row in the result can be accessed by ordinal. +List teenagerNames = teenagers.map(new Function() { + public String call(Row row) { + return "Name: " + row.getString(0); + } +}).collect(); + +{% endhighlight %} + +
      + +
      + +**Note that Spark SQL currently uses a very basic SQL parser.** Users that want a more complete dialect of SQL should look at the HiveQL support provided by `HiveContext`. @@ -70,17 +188,21 @@ Parquet is a columnar format that is supported by many other data processing sys provides support for both reading and writing parquet files that automatically preserves the schema of the original data. Using the data from the above example: +
      + +
      + {% highlight scala %} val sqlContext = new org.apache.spark.sql.SQLContext(sc) import sqlContext._ -val people: RDD[Person] // An RDD of case class objects, from the previous example. +val people: RDD[Person] = ... // An RDD of case class objects, from the previous example. // The RDD is implicitly converted to a SchemaRDD, allowing it to be stored using parquet. people.saveAsParquetFile("people.parquet") // Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. -// The result of loading a parquet file is also a SchemaRDD. +// The result of loading a parquet file is also a JavaSchemaRDD. val parquetFile = sqlContext.parquetFile("people.parquet") //Parquet files can also be registered as tables and then used in SQL statements. @@ -89,15 +211,43 @@ val teenagers = sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19" teenagers.collect().foreach(println) {% endhighlight %} +
      + +
      + +{% highlight java %} + +JavaSchemaRDD schemaPeople = ... // The JavaSchemaRDD from the previous example. + +// JavaSchemaRDDs can be saved as parquet files, maintaining the schema information. +schemaPeople.saveAsParquetFile("people.parquet"); + +// Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. +// The result of loading a parquet file is also a JavaSchemaRDD. +JavaSchemaRDD parquetFile = sqlCtx.parquetFile("people.parquet"); + +//Parquet files can also be registered as tables and then used in SQL statements. +parquetFile.registerAsTable("parquetFile"); +JavaSchemaRDD teenagers = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); + + +{% endhighlight %} + +
      + +
      + ## Writing Language-Integrated Relational Queries +**Language-Integrated queries are currently only supported in Scala.** + Spark SQL also supports a domain specific language for writing queries. Once again, using the data from the above examples: {% highlight scala %} val sqlContext = new org.apache.spark.sql.SQLContext(sc) import sqlContext._ -val people: RDD[Person] // An RDD of case class objects, from the first example. +val people: RDD[Person] = ... // An RDD of case class objects, from the first example. // The following is the same as 'SELECT name FROM people WHERE age >= 10 AND age <= 19' val teenagers = people.where('age >= 10).where('age <= 19).select('name) @@ -114,14 +264,17 @@ evaluated by the SQL execution engine. A full list of the functions supported c Spark SQL also supports reading and writing data stored in [Apache Hive](http://hive.apache.org/). However, since Hive has a large number of dependencies, it is not included in the default Spark assembly. -In order to use Hive you must first run '`sbt/sbt hive/assembly`'. This command builds a new assembly -jar that includes Hive. When this jar is present, Spark will use the Hive -assembly instead of the normal Spark assembly. Note that this Hive assembly jar must also be present +In order to use Hive you must first run '`SPARK_HIVE=true sbt/sbt assembly/assembly`' (or use `-Phive` for maven). +This command builds a new assembly jar that includes Hive. Note that this Hive assembly jar must also be present on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries (SerDes) in order to acccess data stored in Hive. Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`. +
      + +
      + When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and adds support for finding tables in in the MetaStore and writing queries using HiveQL. Users who do not have an existing Hive deployment can also experiment with the `LocalHiveContext`, @@ -135,9 +288,34 @@ val hiveContext = new org.apache.spark.sql.hive.HiveContext(sc) // Importing the SQL context gives access to all the public SQL functions and implicit conversions. import hiveContext._ -sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") -sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") +hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") +hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src") // Queries are expressed in HiveQL -sql("SELECT key, value FROM src").collect().foreach(println) -{% endhighlight %} \ No newline at end of file +hql("FROM src SELECT key, value").collect().foreach(println) +{% endhighlight %} + +
      + +
      + +When working with Hive one must construct a `JavaHiveContext`, which inherits from `JavaSQLContext`, and +adds support for finding tables in in the MetaStore and writing queries using HiveQL. In addition to +the `sql` method a `JavaHiveContext` also provides an `hql` methods, which allows queries to be +expressed in HiveQL. + +{% highlight java %} +JavaSparkContext ctx = ...; // An existing JavaSparkContext. +JavaHiveContext hiveCtx = new org.apache.spark.sql.hive.api.java.HiveContext(ctx); + +hiveCtx.hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)"); +hiveCtx.hql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src"); + +// Queries are expressed in HiveQL. +Row[] results = hiveCtx.hql("FROM src SELECT key, value").collect(); + +{% endhighlight %} + +
      + +
      diff --git a/docs/tuning.md b/docs/tuning.md index 093df3187a789..cc069f0e84b9c 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -90,9 +90,10 @@ than the "raw" data inside their fields. This is due to several reasons: * Each distinct Java object has an "object header", which is about 16 bytes and contains information such as a pointer to its class. For an object with very little data in it (say one `Int` field), this can be bigger than the data. -* Java Strings have about 40 bytes of overhead over the raw string data (since they store it in an +* Java `String`s have about 40 bytes of overhead over the raw string data (since they store it in an array of `Char`s and keep extra data such as the length), and store each character - as *two* bytes due to Unicode. Thus a 10-character string can easily consume 60 bytes. + as *two* bytes due to `String`'s internal usage of UTF-16 encoding. Thus a 10-character string can + easily consume 60 bytes. * Common collection classes, such as `HashMap` and `LinkedList`, use linked data structures, where there is a "wrapper" object for each entry (e.g. `Map.Entry`). This object not only has a header, but also pointers (typically 8 bytes each) to the next object in the list. diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index d8840c94ac17c..31209a662bbe1 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -70,7 +70,7 @@ def parse_args(): "slaves across multiple (an additional $0.01/Gb for bandwidth" + "between zones applies)") parser.add_option("-a", "--ami", help="Amazon Machine Image ID to use") - parser.add_option("-v", "--spark-version", default="0.9.0", + parser.add_option("-v", "--spark-version", default="0.9.1", help="Version of Spark to use: 'X.Y.Z' or a specific git hash") parser.add_option("--spark-git-repo", default="https://github.com/apache/spark", @@ -157,7 +157,7 @@ def is_active(instance): # Return correct versions of Spark and Shark, given the supplied Spark version def get_spark_shark_version(opts): - spark_shark_map = {"0.7.3": "0.7.1", "0.8.0": "0.8.0", "0.8.1": "0.8.1", "0.9.0": "0.9.0"} + spark_shark_map = {"0.7.3": "0.7.1", "0.8.0": "0.8.0", "0.8.1": "0.8.1", "0.9.0": "0.9.0", "0.9.1": "0.9.1"} version = opts.spark_version.replace("v", "") if version not in spark_shark_map: print >> stderr, "Don't know about Spark version: %s" % version diff --git a/examples/pom.xml b/examples/pom.xml index a5569ff5e71f3..0b6212b5d1549 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -110,7 +110,7 @@ org.apache.hbase hbase - 0.94.6 + ${hbase.version} asm diff --git a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java index 6b49244ba459d..bd96274021756 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaHdfsLR.java @@ -138,6 +138,6 @@ public static void main(String[] args) { System.out.print("Final w: "); printWeights(w); - System.exit(0); + sc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java index 617e4a6d045e0..2a4278d3c30e5 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java @@ -126,6 +126,6 @@ public Stats call(Stats stats, Stats stats2) { for (Tuple2 t : output) { System.out.println(t._1() + "\t" + t._2()); } - System.exit(0); + jsc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java index eb70fb547564c..e31f676f5fd4c 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java @@ -17,7 +17,10 @@ package org.apache.spark.examples; + import scala.Tuple2; + +import com.google.common.collect.Iterables; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -26,8 +29,9 @@ import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.api.java.function.PairFunction; -import java.util.List; import java.util.ArrayList; +import java.util.List; +import java.util.Iterator; import java.util.regex.Pattern; /** @@ -66,7 +70,7 @@ public static void main(String[] args) throws Exception { JavaRDD lines = ctx.textFile(args[1], 1); // Loads all URLs from input file and initialize their neighbors. - JavaPairRDD> links = lines.mapToPair(new PairFunction() { + JavaPairRDD> links = lines.mapToPair(new PairFunction() { @Override public Tuple2 call(String s) { String[] parts = SPACES.split(s); @@ -75,9 +79,9 @@ public Tuple2 call(String s) { }).distinct().groupByKey().cache(); // Loads all URLs with other URL(s) link to from input file and initialize ranks of them to one. - JavaPairRDD ranks = links.mapValues(new Function, Double>() { + JavaPairRDD ranks = links.mapValues(new Function, Double>() { @Override - public Double call(List rs) { + public Double call(Iterable rs) { return 1.0; } }); @@ -86,12 +90,13 @@ public Double call(List rs) { for (int current = 0; current < Integer.parseInt(args[2]); current++) { // Calculates URL contributions to the rank of other URLs. JavaPairRDD contribs = links.join(ranks).values() - .flatMapToPair(new PairFlatMapFunction, Double>, String, Double>() { + .flatMapToPair(new PairFlatMapFunction, Double>, String, Double>() { @Override - public Iterable> call(Tuple2, Double> s) { + public Iterable> call(Tuple2, Double> s) { + int urlCount = Iterables.size(s._1); List> results = new ArrayList>(); - for (String n : s._1()) { - results.add(new Tuple2(n, s._2() / s._1().size())); + for (String n : s._1) { + results.add(new Tuple2(n, s._2() / urlCount)); } return results; } @@ -112,6 +117,6 @@ public Double call(Double sum) { System.out.println(tuple._1() + " has rank: " + tuple._2() + "."); } - System.exit(0); + ctx.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/JavaTC.java b/examples/src/main/java/org/apache/spark/examples/JavaTC.java index 6cfe25c80ecc6..1d776940f06c6 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaTC.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaTC.java @@ -96,6 +96,6 @@ public Tuple2 call(Tuple2 e) { } while (nextCount != oldCount); System.out.println("TC has " + tc.count() + " edges."); - System.exit(0); + sc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java index 3ae1d8f7ca938..87c1b80981961 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaWordCount.java @@ -48,14 +48,14 @@ public Iterable call(String s) { return Arrays.asList(SPACE.split(s)); } }); - + JavaPairRDD ones = words.mapToPair(new PairFunction() { @Override public Tuple2 call(String s) { return new Tuple2(s, 1); } }); - + JavaPairRDD counts = ones.reduceByKey(new Function2() { @Override public Integer call(Integer i1, Integer i2) { @@ -67,6 +67,6 @@ public Integer call(Integer i1, Integer i2) { for (Tuple2 tuple : output) { System.out.println(tuple._1() + ": " + tuple._2()); } - System.exit(0); + ctx.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java new file mode 100644 index 0000000000000..e8e63d2745692 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.sql; + +import java.io.Serializable; +import java.util.List; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.VoidFunction; + +import org.apache.spark.sql.api.java.JavaSQLContext; +import org.apache.spark.sql.api.java.JavaSchemaRDD; +import org.apache.spark.sql.api.java.Row; + +public class JavaSparkSQL { + public static class Person implements Serializable { + private String name; + private int age; + + String getName() { + return name; + } + + void setName(String name) { + this.name = name; + } + + int getAge() { + return age; + } + + void setAge(int age) { + this.age = age; + } + } + + public static void main(String[] args) throws Exception { + JavaSparkContext ctx = new JavaSparkContext("local", "JavaSparkSQL", + System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaSparkSQL.class)); + JavaSQLContext sqlCtx = new JavaSQLContext(ctx); + + // Load a text file and convert each line to a Java Bean. + JavaRDD people = ctx.textFile("examples/src/main/resources/people.txt").map( + new Function() { + public Person call(String line) throws Exception { + String[] parts = line.split(","); + + Person person = new Person(); + person.setName(parts[0]); + person.setAge(Integer.parseInt(parts[1].trim())); + + return person; + } + }); + + // Apply a schema to an RDD of Java Beans and register it as a table. + JavaSchemaRDD schemaPeople = sqlCtx.applySchema(people, Person.class); + schemaPeople.registerAsTable("people"); + + // SQL can be run over RDDs that have been registered as tables. + JavaSchemaRDD teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + + // The results of SQL queries are SchemaRDDs and support all the normal RDD operations. + // The columns of a row in the result can be accessed by ordinal. + List teenagerNames = teenagers.map(new Function() { + public String call(Row row) { + return "Name: " + row.getString(0); + } + }).collect(); + + // JavaSchemaRDDs can be saved as parquet files, maintaining the schema information. + schemaPeople.saveAsParquetFile("people.parquet"); + + // Read in the parquet file created above. Parquet files are self-describing so the schema is preserved. + // The result of loading a parquet file is also a JavaSchemaRDD. + JavaSchemaRDD parquetFile = sqlCtx.parquetFile("people.parquet"); + + //Parquet files can also be registered as tables and then used in SQL statements. + parquetFile.registerAsTable("parquetFile"); + JavaSchemaRDD teenagers2 = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); + } +} diff --git a/examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java b/examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java index 64a3a04fb7296..c516199d61c72 100644 --- a/examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java +++ b/examples/src/main/java/org/apache/spark/mllib/examples/JavaALS.java @@ -85,6 +85,6 @@ public static void main(String[] args) { outputDir + "/productFeatures"); System.out.println("Final user/product features written to " + outputDir); - System.exit(0); + sc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java b/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java index 7b0ec36424e97..7461609ab9e8f 100644 --- a/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java +++ b/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java @@ -79,6 +79,6 @@ public static void main(String[] args) { double cost = model.computeCost(points.rdd()); System.out.println("Cost: " + cost); - System.exit(0); + sc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java b/examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java index 667c72f379e71..e3ab87cc722f3 100644 --- a/examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java +++ b/examples/src/main/java/org/apache/spark/mllib/examples/JavaLR.java @@ -17,6 +17,7 @@ package org.apache.spark.mllib.examples; +import java.util.regex.Pattern; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; @@ -24,11 +25,9 @@ import org.apache.spark.mllib.classification.LogisticRegressionWithSGD; import org.apache.spark.mllib.classification.LogisticRegressionModel; +import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import java.util.Arrays; -import java.util.regex.Pattern; - /** * Logistic regression based classification using ML Lib. */ @@ -47,14 +46,10 @@ public LabeledPoint call(String line) { for (int i = 0; i < tok.length; ++i) { x[i] = Double.parseDouble(tok[i]); } - return new LabeledPoint(y, x); + return new LabeledPoint(y, Vectors.dense(x)); } } - public static void printWeights(double[] a) { - System.out.println(Arrays.toString(a)); - } - public static void main(String[] args) { if (args.length != 4) { System.err.println("Usage: JavaLR "); @@ -80,9 +75,8 @@ public static void main(String[] args) { LogisticRegressionModel model = LogisticRegressionWithSGD.train(points.rdd(), iterations, stepSize); - System.out.print("Final w: "); - printWeights(model.weights()); + System.out.print("Final w: " + model.weights()); - System.exit(0); + sc.stop(); } } diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index 4d2f45df85fc6..c8c916bb45e00 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -56,6 +56,6 @@ object BroadcastTest { println("Iteration %d took %.0f milliseconds".format(i, (System.nanoTime - startTime) / 1E6)) } - System.exit(0) + sc.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index ee283ce6abac2..1f8d7cb5995b8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -58,7 +58,7 @@ import org.apache.spark.SparkContext._ prod_id, quantity) VALUES ('charlie', 1385983649000, 'iphone', 2); */ - + /** * This example demonstrates how to read and write to cassandra column family created using CQL3 * using Spark. diff --git a/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala index fdb976dfc6aba..be7d39549a28d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ExceptionHandlingTest.scala @@ -34,6 +34,6 @@ object ExceptionHandlingTest { } } - System.exit(0) + sc.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala index 36534e59353cd..29114c6dabcdb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala @@ -28,7 +28,7 @@ object GroupByTest { "Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]") System.exit(1) } - + var numMappers = if (args.length > 1) args(1).toInt else 2 var numKVPairs = if (args.length > 2) args(2).toInt else 1000 var valSize = if (args.length > 3) args(3).toInt else 1000 @@ -52,7 +52,6 @@ object GroupByTest { println(pairs1.groupByKey(numReducers).count) - System.exit(0) + sc.stop() } } - diff --git a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala index 65d67356be2f6..700121d16dd60 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala @@ -30,7 +30,7 @@ object HBaseTest { val conf = HBaseConfiguration.create() - // Other options for configuring scan behavior are available. More information available at + // Other options for configuring scan behavior are available. More information available at // http://hbase.apache.org/apidocs/org/apache/hadoop/hbase/mapreduce/TableInputFormat.html conf.set(TableInputFormat.INPUT_TABLE, args(1)) @@ -41,12 +41,12 @@ object HBaseTest { admin.createTable(tableDesc) } - val hBaseRDD = sc.newAPIHadoopRDD(conf, classOf[TableInputFormat], + val hBaseRDD = sc.newAPIHadoopRDD(conf, classOf[TableInputFormat], classOf[org.apache.hadoop.hbase.io.ImmutableBytesWritable], classOf[org.apache.hadoop.hbase.client.Result]) hBaseRDD.count() - System.exit(0) + sc.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala index c3597d94a224e..dd6d5205133be 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HdfsTest.scala @@ -32,6 +32,6 @@ object HdfsTest { val end = System.currentTimeMillis() println("Iteration " + iter + " took " + (end-start) + " ms") } - System.exit(0) + sc.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala index 0095cb8425456..37ad4bd0999bd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -120,7 +120,7 @@ object LocalALS { } } printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS) - + val R = generateR() // Initialize m and u randomly diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala index 4aef04fc060b6..97321ab8f41db 100644 --- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -51,6 +51,6 @@ object MultiBroadcastTest { // Collect the small RDD so we can print the observed sizes locally. observedSizes.collect().foreach(i => println(i)) - System.exit(0) + sc.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index 1fdb324b89f3a..d05eedd31caa0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -27,7 +27,7 @@ object SimpleSkewedGroupByTest { System.err.println("Usage: SimpleSkewedGroupByTest " + "[numMappers] [numKVPairs] [valSize] [numReducers] [ratio]") System.exit(1) - } + } var numMappers = if (args.length > 1) args(1).toInt else 2 var numKVPairs = if (args.length > 2) args(2).toInt else 1000 @@ -58,14 +58,13 @@ object SimpleSkewedGroupByTest { }.cache // Enforce that everything has been calculated and in cache pairs1.count - + println("RESULT: " + pairs1.groupByKey(numReducers).count) // Print how many keys each reducer got (for debugging) // println("RESULT: " + pairs1.groupByKey(numReducers) // .map{case (k,v) => (k, v.size)} // .collectAsMap) - System.exit(0) + sc.stop() } } - diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala index 966478fe4a258..fd9f043247d18 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala @@ -27,7 +27,7 @@ object SkewedGroupByTest { System.err.println( "Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers]") System.exit(1) - } + } var numMappers = if (args.length > 1) args(1).toInt else 2 var numKVPairs = if (args.length > 2) args(2).toInt else 1000 @@ -53,10 +53,9 @@ object SkewedGroupByTest { }.cache() // Enforce that everything has been calculated and in cache pairs1.count() - + println(pairs1.groupByKey(numReducers).count()) - System.exit(0) + sc.stop() } } - diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala index f59ab7e7cc24a..68f151a2c47fe 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala @@ -112,7 +112,7 @@ object SparkALS { val sc = new SparkContext(host, "SparkALS", System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass)) - + val R = generateR() // Initialize m and u randomly @@ -137,6 +137,6 @@ object SparkALS { println() } - System.exit(0) + sc.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index e698b9bf376e1..d8de8745c15d9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -52,7 +52,7 @@ object SparkHdfsLR { val inputPath = args(1) val conf = SparkHadoopUtil.get.newConfiguration() val sc = new SparkContext(args(0), "SparkHdfsLR", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass), Map(), + System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass), Map(), InputFormatInfo.computePreferredLocations( Seq(new InputFormatInfo(conf, classOf[org.apache.hadoop.mapred.TextInputFormat], inputPath)) )) @@ -73,6 +73,6 @@ object SparkHdfsLR { } println("Final w: " + w) - System.exit(0) + sc.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala index 9fe24652358f3..1a8b21618e23a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala @@ -28,16 +28,16 @@ import org.apache.spark.SparkContext._ object SparkKMeans { val R = 1000 // Scaling factor val rand = new Random(42) - + def parseVector(line: String): Vector = { new Vector(line.split(' ').map(_.toDouble)) } - + def closestPoint(p: Vector, centers: Array[Vector]): Int = { var index = 0 var bestIndex = 0 var closest = Double.PositiveInfinity - + for (i <- 0 until centers.length) { val tempDist = p.squaredDist(centers(i)) if (tempDist < closest) { @@ -45,7 +45,7 @@ object SparkKMeans { bestIndex = i } } - + bestIndex } @@ -60,22 +60,22 @@ object SparkKMeans { val data = lines.map(parseVector _).cache() val K = args(2).toInt val convergeDist = args(3).toDouble - + val kPoints = data.takeSample(withReplacement = false, K, 42).toArray var tempDist = 1.0 while(tempDist > convergeDist) { val closest = data.map (p => (closestPoint(p, kPoints), (p, 1))) - + val pointStats = closest.reduceByKey{case ((x1, y1), (x2, y2)) => (x1 + x2, y1 + y2)} - + val newPoints = pointStats.map {pair => (pair._1, pair._2._1 / pair._2._2)}.collectAsMap() - + tempDist = 0.0 for (i <- 0 until K) { tempDist += kPoints(i).squaredDist(newPoints(i)) } - + for (newP <- newPoints) { kPoints(newP._1) = newP._2 } @@ -84,6 +84,6 @@ object SparkKMeans { println("Final centers:") kPoints.foreach(println) - System.exit(0) + sc.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index c54a55bdb4a11..3a2699d4d996b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -66,6 +66,6 @@ object SparkLR { } println("Final w: " + w) - System.exit(0) + sc.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index d203f4d20e15f..45b6e10f3ea9e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -57,7 +57,6 @@ object SparkPageRank { val output = ranks.collect() output.foreach(tup => println(tup._1 + " has rank: " + tup._2 + ".")) - System.exit(0) + ctx.stop() } } - diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala index e5a09ecec006f..d3babc3ed12c8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPi.scala @@ -18,8 +18,8 @@ package org.apache.spark.examples import scala.math.random + import org.apache.spark._ -import SparkContext._ /** Computes an approximation to pi */ object SparkPi { diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index 24e8afa26bc5f..eb47cf027cb10 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -70,6 +70,6 @@ object SparkTC { } while (nextCount != oldCount) println("TC has " + tc.count() + " edges.") - System.exit(0) + spark.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala new file mode 100644 index 0000000000000..5698d4746495d --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples + +import java.util.Random +import scala.math.exp +import org.apache.spark.util.Vector +import org.apache.spark._ +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.scheduler.InputFormatInfo +import org.apache.spark.storage.StorageLevel + +/** + * Logistic regression based classification. + * This example uses Tachyon to persist rdds during computation. + */ +object SparkTachyonHdfsLR { + val D = 10 // Numer of dimensions + val rand = new Random(42) + + case class DataPoint(x: Vector, y: Double) + + def parsePoint(line: String): DataPoint = { + val tok = new java.util.StringTokenizer(line, " ") + var y = tok.nextToken.toDouble + var x = new Array[Double](D) + var i = 0 + while (i < D) { + x(i) = tok.nextToken.toDouble; i += 1 + } + DataPoint(new Vector(x), y) + } + + def main(args: Array[String]) { + if (args.length < 3) { + System.err.println("Usage: SparkTachyonHdfsLR ") + System.exit(1) + } + val inputPath = args(1) + val conf = SparkHadoopUtil.get.newConfiguration() + val sc = new SparkContext(args(0), "SparkTachyonHdfsLR", + System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass), Map(), + InputFormatInfo.computePreferredLocations( + Seq(new InputFormatInfo(conf, classOf[org.apache.hadoop.mapred.TextInputFormat], inputPath)) + )) + val lines = sc.textFile(inputPath) + val points = lines.map(parsePoint _).persist(StorageLevel.OFF_HEAP) + val ITERATIONS = args(2).toInt + + // Initialize w to a random value + var w = Vector(D, _ => 2 * rand.nextDouble - 1) + println("Initial w: " + w) + + for (i <- 1 to ITERATIONS) { + println("On iteration " + i) + val gradient = points.map { p => + (1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y * p.x + }.reduce(_ + _) + w -= gradient + } + + println("Final w: " + w) + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparkPCA.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala similarity index 54% rename from examples/src/main/scala/org/apache/spark/examples/mllib/SparkPCA.scala rename to examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala index d4e08c5e12d81..2b207fd8d3e16 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparkPCA.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonPi.scala @@ -15,37 +15,38 @@ * limitations under the License. */ -package org.apache.spark.examples.mllib - -import org.apache.spark.SparkContext -import org.apache.spark.mllib.linalg.PCA -import org.apache.spark.mllib.linalg.MatrixEntry -import org.apache.spark.mllib.linalg.SparseMatrix -import org.apache.spark.mllib.util._ +package org.apache.spark.examples +import scala.math.random + +import org.apache.spark._ +import org.apache.spark.storage.StorageLevel /** - * Compute PCA of an example matrix. + * Computes an approximation to pi + * This example uses Tachyon to persist rdds during computation. */ -object SparkPCA { +object SparkTachyonPi { def main(args: Array[String]) { - if (args.length != 3) { - System.err.println("Usage: SparkPCA m n") + if (args.length == 0) { + System.err.println("Usage: SparkTachyonPi []") System.exit(1) } - val sc = new SparkContext(args(0), "PCA", + val spark = new SparkContext(args(0), "SparkTachyonPi", System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass)) - val m = args(2).toInt - val n = args(3).toInt - - // Make example matrix - val data = Array.tabulate(m, n) { (a, b) => - (a + 2).toDouble * (b + 1) / (1 + a + b) } + val slices = if (args.length > 1) args(1).toInt else 2 + val n = 100000 * slices - // recover top principal component - val coeffs = new PCA().setK(1).compute(sc.makeRDD(data)) + val rdd = spark.parallelize(1 to n, slices) + rdd.persist(StorageLevel.OFF_HEAP) + val count = rdd.map { i => + val x = random * 2 - 1 + val y = random * 2 - 1 + if (x * x + y * y < 1) 1 else 0 + }.reduce(_ + _) + println("Pi is roughly " + 4.0 * count / n) - println("top principal component = " + coeffs.mkString(", ")) + spark.stop() } } diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala index 27afa6b642758..dee3cb6c0abae 100644 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala @@ -79,7 +79,7 @@ object WikipediaPageRankStandalone { val time = (System.currentTimeMillis - startTime) / 1000.0 println("Completed %d iterations in %f seconds: %f seconds per iteration" .format(numIterations, time, time / numIterations)) - System.exit(0) + sc.stop() } def parseArticle(line: String): (String, Array[String]) = { @@ -115,12 +115,16 @@ object WikipediaPageRankStandalone { var ranks = links.mapValues { edges => defaultRank } for (i <- 1 to numIterations) { val contribs = links.groupWith(ranks).flatMap { - case (id, (linksWrapper, rankWrapper)) => - if (linksWrapper.length > 0) { - if (rankWrapper.length > 0) { - linksWrapper(0).map(dest => (dest, rankWrapper(0) / linksWrapper(0).size)) + case (id, (linksWrapperIterable, rankWrapperIterable)) => + val linksWrapper = linksWrapperIterable.iterator + val rankWrapper = rankWrapperIterable.iterator + if (linksWrapper.hasNext) { + val linksWrapperHead = linksWrapper.next + if (rankWrapper.hasNext) { + val rankWrapperHead = rankWrapper.next + linksWrapperHead.map(dest => (dest, rankWrapperHead / linksWrapperHead.size)) } else { - linksWrapper(0).map(dest => (dest, defaultRank / linksWrapper(0).size)) + linksWrapperHead.map(dest => (dest, defaultRank / linksWrapperHead.size)) } } else { Array[(String, Double)]() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SparkSVD.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SparkSVD.scala deleted file mode 100644 index 2933cec497b37..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SparkSVD.scala +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.mllib - -import org.apache.spark.SparkContext -import org.apache.spark.mllib.linalg.SVD -import org.apache.spark.mllib.linalg.MatrixEntry -import org.apache.spark.mllib.linalg.SparseMatrix - -/** - * Compute SVD of an example matrix - * Input file should be comma separated, 1 indexed of the form - * i,j,value - * Where i is the column, j the row, and value is the matrix entry - * - * For example input file, see: - * mllib/data/als/test.data (example is 4 x 4) - */ -object SparkSVD { - def main(args: Array[String]) { - if (args.length != 4) { - System.err.println("Usage: SparkSVD m n") - System.exit(1) - } - val sc = new SparkContext(args(0), "SVD", - System.getenv("SPARK_HOME"), SparkContext.jarOfClass(this.getClass)) - - // Load and parse the data file - val data = sc.textFile(args(1)).map { line => - val parts = line.split(',') - MatrixEntry(parts(0).toInt - 1, parts(1).toInt - 1, parts(2).toDouble) - } - val m = args(2).toInt - val n = args(3).toInt - - // recover largest singular vector - val decomposed = new SVD().setK(1).compute(SparseMatrix(data, m, n)) - val u = decomposed.U.data - val s = decomposed.S.data - val v = decomposed.V.data - - println("singular values = " + s.collect().mkString) - } -} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala new file mode 100644 index 0000000000000..61b9655cd3759 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.linalg.distributed.RowMatrix +import org.apache.spark.mllib.linalg.Vectors + +/** + * Compute the principal components of a tall-and-skinny matrix, whose rows are observations. + * + * The input matrix must be stored in row-oriented dense format, one line per row with its entries + * separated by space. For example, + * {{{ + * 0.5 1.0 + * 2.0 3.0 + * 4.0 5.0 + * }}} + * represents a 3-by-2 matrix, whose first row is (0.5, 1.0). + */ +object TallSkinnyPCA { + def main(args: Array[String]) { + if (args.length != 2) { + System.err.println("Usage: TallSkinnyPCA ") + System.exit(1) + } + + val conf = new SparkConf() + .setMaster(args(0)) + .setAppName("TallSkinnyPCA") + .setSparkHome(System.getenv("SPARK_HOME")) + .setJars(SparkContext.jarOfClass(this.getClass)) + val sc = new SparkContext(conf) + + // Load and parse the data file. + val rows = sc.textFile(args(1)).map { line => + val values = line.split(' ').map(_.toDouble) + Vectors.dense(values) + } + val mat = new RowMatrix(rows) + + // Compute principal components. + val pc = mat.computePrincipalComponents(mat.numCols().toInt) + + println("Principal components are:\n" + pc) + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala new file mode 100644 index 0000000000000..9aeebf58eabfb --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.linalg.distributed.RowMatrix +import org.apache.spark.mllib.linalg.Vectors + +/** + * Compute the singular value decomposition (SVD) of a tall-and-skinny matrix. + * + * The input matrix must be stored in row-oriented dense format, one line per row with its entries + * separated by space. For example, + * {{{ + * 0.5 1.0 + * 2.0 3.0 + * 4.0 5.0 + * }}} + * represents a 3-by-2 matrix, whose first row is (0.5, 1.0). + */ +object TallSkinnySVD { + def main(args: Array[String]) { + if (args.length != 2) { + System.err.println("Usage: TallSkinnySVD ") + System.exit(1) + } + + val conf = new SparkConf() + .setMaster(args(0)) + .setAppName("TallSkinnySVD") + .setSparkHome(System.getenv("SPARK_HOME")) + .setJars(SparkContext.jarOfClass(this.getClass)) + val sc = new SparkContext(conf) + + // Load and parse the data file. + val rows = sc.textFile(args(1)).map { line => + val values = line.split(' ').map(_.toDouble) + Vectors.dense(values) + } + val mat = new RowMatrix(rows) + + // Compute SVD. + val svd = mat.computeSVD(mat.numCols().toInt) + + println("Singular values are " + svd.s) + + sc.stop() + } +} diff --git a/examples/src/main/scala/org/apache/spark/sql/examples/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/sql/examples/HiveFromSpark.scala index abcc1f04d4279..62329bde84481 100644 --- a/examples/src/main/scala/org/apache/spark/sql/examples/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/sql/examples/HiveFromSpark.scala @@ -33,20 +33,20 @@ object HiveFromSpark { val hiveContext = new LocalHiveContext(sc) import hiveContext._ - sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") - sql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src") + hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + hql("LOAD DATA LOCAL INPATH 'src/main/resources/kv1.txt' INTO TABLE src") // Queries are expressed in HiveQL println("Result of 'SELECT *': ") - sql("SELECT * FROM src").collect.foreach(println) + hql("SELECT * FROM src").collect.foreach(println) // Aggregation queries are also supported. - val count = sql("SELECT COUNT(*) FROM src").collect().head.getInt(0) + val count = hql("SELECT COUNT(*) FROM src").collect().head.getInt(0) println(s"COUNT(*): $count") // The results of SQL queries are themselves RDDs and support all normal RDD functions. The // items in the RDD are of type Row, which allows you to access each column by ordinal. - val rddFromSql = sql("SELECT key, value FROM src WHERE key < 10 ORDER BY key") + val rddFromSql = hql("SELECT key, value FROM src WHERE key < 10 ORDER BY key") println("Result of RDD.map:") val rddAsStrings = rddFromSql.map { @@ -59,6 +59,6 @@ object HiveFromSpark { // Queries can then join RDD data with data stored in Hive. println("Result of SELECT *:") - sql("SELECT * FROM records r JOIN src s ON r.key = s.key").collect().foreach(println) + hql("SELECT * FROM records r JOIN src s ON r.key = s.key").collect().foreach(println) } } diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/HdfsWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/HdfsWordCount.scala index 954bcc9b6ef5d..1c0ce3111e290 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/HdfsWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/HdfsWordCount.scala @@ -53,4 +53,3 @@ object HdfsWordCount { ssc.awaitTermination() } } - diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala index 6bccd1d88401a..cca0be2cbb9c9 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/KafkaWordCount.scala @@ -61,7 +61,7 @@ object KafkaWordCount { val wordCounts = words.map(x => (x, 1L)) .reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2) wordCounts.print() - + ssc.start() ssc.awaitTermination() } @@ -83,7 +83,7 @@ object KafkaWordCountProducer { val props = new Properties() props.put("metadata.broker.list", brokers) props.put("serializer.class", "kafka.serializer.StringEncoder") - + val config = new ProducerConfig(props) val producer = new Producer[String, String](config) @@ -102,4 +102,3 @@ object KafkaWordCountProducer { } } - diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala index 0a68ac84c2424..656222e0c1b31 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala @@ -26,7 +26,7 @@ import org.apache.spark.streaming.StreamingContext._ import org.apache.spark.streaming.mqtt._ /** - * A simple Mqtt publisher for demonstration purposes, repeatedly publishes + * A simple Mqtt publisher for demonstration purposes, repeatedly publishes * Space separated String Message "hello mqtt demo for spark streaming" */ object MQTTPublisher { diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala index 4d4968ba6ae3e..612ecf7b7821a 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/QueueStream.scala @@ -24,7 +24,7 @@ import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.StreamingContext._ object QueueStream { - + def main(args: Array[String]) { if (args.length < 1) { System.err.println("Usage: QueueStream ") @@ -37,23 +37,22 @@ object QueueStream { val ssc = new StreamingContext(args(0), "QueueStream", Seconds(1), System.getenv("SPARK_HOME"), StreamingContext.jarOfClass(this.getClass)) - // Create the queue through which RDDs can be pushed to + // Create the queue through which RDDs can be pushed to // a QueueInputDStream val rddQueue = new SynchronizedQueue[RDD[Int]]() - + // Create the QueueInputDStream and use it do some processing val inputStream = ssc.queueStream(rddQueue) val mappedStream = inputStream.map(x => (x % 10, 1)) val reducedStream = mappedStream.reduceByKey(_ + _) - reducedStream.print() + reducedStream.print() ssc.start() - + // Create and push some RDDs into for (i <- 1 to 30) { rddQueue += ssc.sparkContext.makeRDD(1 to 1000, 10) Thread.sleep(1000) } ssc.stop() - System.exit(0) } } diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/StatefulNetworkWordCount.scala index c2d84a8e0861e..14f65a2f8d46c 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/StatefulNetworkWordCount.scala @@ -58,7 +58,7 @@ object StatefulNetworkWordCount { ssc.checkpoint(".") // Create a NetworkInputDStream on target ip:port and count the - // words in input stream of \n delimited test (eg. generated by 'nc') + // words in input stream of \n delimited test (eg. generated by 'nc') val lines = ssc.socketTextStream(args(1), args(2).toInt) val words = lines.flatMap(_.split(" ")) val wordDstream = words.map(x => (x, 1)) diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala index 35f8f885f8f0e..445d2028582af 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ZeroMQWordCount.scala @@ -60,7 +60,7 @@ object SimpleZeroMQPublisher { * To work with zeroMQ, some native libraries have to be installed. * Install zeroMQ (release 2.1) core libraries. [ZeroMQ Install guide] * (http://www.zeromq.org/intro:get-the-software) - * + * * Usage: ZeroMQWordCount * In local mode, should be 'local[n]' with n > 1 * and describe where zeroMq publisher is running. diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala index 2896e42019fe2..c9c85f0a88f13 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala @@ -49,7 +49,7 @@ import org.apache.spark.streaming.receiver.NetworkReceiver * @param storageLevel RDD storage level. */ -private[streaming] +private[streaming] class MQTTInputDStream( @transient ssc_ : StreamingContext, brokerUrl: String, @@ -79,7 +79,7 @@ class MQTTReceiver( // Initializing Mqtt Client specifying brokerUrl, clientID and MqttClientPersistance val client: MqttClient = new MqttClient(brokerUrl, MqttClient.generateClientId(), peristance) - // Connect to MqttBroker + // Connect to MqttBroker client.connect() // Subscribe to Mqtt topic @@ -88,7 +88,7 @@ class MQTTReceiver( // Callback automatically triggers as and when new message arrives on specified topic val callback: MqttCallback = new MqttCallback() { - // Handles Mqtt message + // Handles Mqtt message override def messageArrived(arg0: String, arg1: MqttMessage) { store(new String(arg1.getPayload())) } diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala index 59957c05c9f76..372b4c269a634 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala @@ -33,7 +33,7 @@ import org.apache.spark.streaming.receiver.NetworkReceiver * @constructor create a new Twitter stream using the supplied Twitter4J authentication credentials. * An optional set of string filters can be used to restrict the set of tweets. The Twitter API is * such that this may return a sampled subset of all tweets during each interval. -* +* * If no Authorization object is provided, initializes OAuth authorization using the system * properties twitter4j.oauth.consumerKey, .consumerSecret, .accessToken and .accessTokenSecret. */ @@ -44,13 +44,13 @@ class TwitterInputDStream( filters: Seq[String], storageLevel: StorageLevel ) extends NetworkInputDStream[Status](ssc_) { - + private def createOAuthAuthorization(): Authorization = { new OAuthAuthorization(new ConfigurationBuilder().build()) } private val authorization = twitterAuth.getOrElse(createOAuthAuthorization()) - + override def getReceiver(): NetworkReceiver[Status] = { new TwitterReceiver(authorization, filters, storageLevel) } diff --git a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java index f67251217ed4a..7eb8b45fc3cf0 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java @@ -23,6 +23,7 @@ import scala.Tuple2; +import com.google.common.collections.Iterables; import com.google.common.base.Optional; import com.google.common.io.Files; import org.apache.hadoop.io.IntWritable; @@ -85,15 +86,15 @@ public void foreach() { public void groupBy() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 1, 2, 3, 5, 8, 13)); Function isOdd = x -> x % 2 == 0; - JavaPairRDD> oddsAndEvens = rdd.groupBy(isOdd); + JavaPairRDD> oddsAndEvens = rdd.groupBy(isOdd); Assert.assertEquals(2, oddsAndEvens.count()); - Assert.assertEquals(2, oddsAndEvens.lookup(true).get(0).size()); // Evens - Assert.assertEquals(5, oddsAndEvens.lookup(false).get(0).size()); // Odds + Assert.assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0))); // Evens + Assert.assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds oddsAndEvens = rdd.groupBy(isOdd, 1); Assert.assertEquals(2, oddsAndEvens.count()); - Assert.assertEquals(2, oddsAndEvens.lookup(true).get(0).size()); // Evens - Assert.assertEquals(5, oddsAndEvens.lookup(false).get(0).size()); // Odds + Assert.assertEquals(2, Iterables.size(oddsAndEvens.lookup(true).get(0))); // Evens + Assert.assertEquals(5, Iterables.size(oddsAndEvens.lookup(false).get(0))); // Odds } @Test diff --git a/graphx/pom.xml b/graphx/pom.xml index 5a5022916d234..b4c67ddcd8ca9 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -54,7 +54,7 @@ org.jblas jblas - 1.2.3 + ${jblas.version} org.eclipse.jetty diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index f2296a865e1b3..6d04bf790e3a5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -45,7 +45,8 @@ class EdgeRDD[@specialized ED: ClassTag]( partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD))) override def compute(part: Partition, context: TaskContext): Iterator[Edge[ED]] = { - firstParent[(PartitionID, EdgePartition[ED])].iterator(part, context).next._2.iterator + val p = firstParent[(PartitionID, EdgePartition[ED])].iterator(part, context) + p.next._2.iterator.map(_.copy()) } override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect() diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index 377d9d6bd5e72..5635287694ee2 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -172,7 +172,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali "EdgeDirection.Either instead.") } } - + /** * Join the vertices with an RDD and then apply a function from the * the vertex and RDD entry to a new vertex value. The input table diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala index 57fa5eefd5e09..2e05f5d4e4969 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartition.scala @@ -56,6 +56,9 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) * Construct a new edge partition by applying the function f to all * edges in this partition. * + * Be careful not to keep references to the objects passed to `f`. + * To improve GC performance the same object is re-used for each call. + * * @param f a function from an edge to a new attribute * @tparam ED2 the type of the new attribute * @return a new edge partition with the result of the function `f` @@ -84,12 +87,12 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) * order of the edges returned by `EdgePartition.iterator` and * should return attributes equal to the number of edges. * - * @param f a function from an edge to a new attribute + * @param iter an iterator for the new attribute values * @tparam ED2 the type of the new attribute - * @return a new edge partition with the result of the function `f` - * applied to each edge + * @return a new edge partition with the attribute values replaced */ def map[ED2: ClassTag](iter: Iterator[ED2]): EdgePartition[ED2] = { + // Faster than iter.toArray, because the expected size is known. val newData = new Array[ED2](data.size) var i = 0 while (iter.hasNext) { @@ -188,6 +191,9 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) /** * Get an iterator over the edges in this partition. * + * Be careful not to keep references to the objects from this iterator. + * To improve GC performance the same object is re-used in `next()`. + * * @return an iterator over edges in the partition */ def iterator = new Iterator[Edge[ED]] { @@ -216,6 +222,9 @@ class EdgePartition[@specialized(Char, Int, Boolean, Byte, Long, Float, Double) /** * Get an iterator over the cluster of edges in this partition with source vertex id `srcId`. The * cluster must start at position `index`. + * + * Be careful not to keep references to the objects from this iterator. To improve GC performance + * the same object is re-used in `next()`. */ private def clusterIterator(srcId: VertexId, index: Int) = new Iterator[Edge[ED]] { private[this] val edge = new Edge[ED] diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala index 886c250d7cffd..220a89d73d711 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeTripletIterator.scala @@ -37,20 +37,15 @@ class EdgeTripletIterator[VD: ClassTag, ED: ClassTag]( // Current position in the array. private var pos = 0 - // A triplet object that this iterator.next() call returns. We reuse this object to avoid - // allocating too many temporary Java objects. - private val triplet = new EdgeTriplet[VD, ED] - private val vmap = new PrimitiveKeyOpenHashMap[VertexId, VD](vidToIndex, vertexArray) override def hasNext: Boolean = pos < edgePartition.size override def next() = { + val triplet = new EdgeTriplet[VD, ED] triplet.srcId = edgePartition.srcIds(pos) - // assert(vmap.containsKey(e.src.id)) triplet.srcAttr = vmap(triplet.srcId) triplet.dstId = edgePartition.dstIds(pos) - // assert(vmap.containsKey(e.dst.id)) triplet.dstAttr = vmap(triplet.dstId) triplet.attr = edgePartition.data(pos) pos += 1 diff --git a/graphx/src/main/scala/org/apache/spark/graphx/package.scala b/graphx/src/main/scala/org/apache/spark/graphx/package.scala index 425a5164cad24..ff17edeaf8f16 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/package.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/package.scala @@ -19,7 +19,10 @@ package org.apache.spark import org.apache.spark.util.collection.OpenHashSet -/** GraphX is a graph processing framework built on top of Spark. */ +/** + * ALPHA COMPONENT + * GraphX is a graph processing framework built on top of Spark. + */ package object graphx { /** * A 64-bit vertex identifier that uniquely identifies a vertex within a graph. It does not need diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala index 6386306c048fc..a467ca1ae715a 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphOpsSuite.scala @@ -55,7 +55,7 @@ class GraphOpsSuite extends FunSuite with LocalSparkContext { } } } - + test ("filter") { withSpark { sc => val n = 5 diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala new file mode 100644 index 0000000000000..9cbb2d2acdc2d --- /dev/null +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx.impl + +import scala.reflect.ClassTag +import scala.util.Random + +import org.scalatest.FunSuite + +import org.apache.spark.graphx._ + +class EdgeTripletIteratorSuite extends FunSuite { + test("iterator.toList") { + val builder = new EdgePartitionBuilder[Int] + builder.add(1, 2, 0) + builder.add(1, 3, 0) + builder.add(1, 4, 0) + val vidmap = new VertexIdToIndexMap + vidmap.add(1) + vidmap.add(2) + vidmap.add(3) + vidmap.add(4) + val vs = Array.fill(vidmap.capacity)(0) + val iter = new EdgeTripletIterator[Int, Int](vidmap, vs, builder.toEdgePartition) + val result = iter.toList.map(et => (et.srcId, et.dstId)) + assert(result === Seq((1, 2), (1, 3), (1, 4))) + } +} diff --git a/mllib/pom.xml b/mllib/pom.xml index fec1cc94b2642..e7ce00efc4af6 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -58,7 +58,7 @@ org.jblas jblas - 1.2.3 + ${jblas.version} org.scalanlp diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 3449c698da60b..a6c049e517ee0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.api.python import java.nio.{ByteBuffer, ByteOrder} +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ @@ -28,8 +29,10 @@ import org.apache.spark.mllib.regression._ import org.apache.spark.rdd.RDD /** + * :: DeveloperApi :: * The Java stubs necessary for the Python mllib bindings. */ +@DeveloperApi class PythonMLLibAPI extends Serializable { private def deserializeDoubleVector(bytes: Array[Byte]): Array[Double] = { val packetLength = bytes.length @@ -110,16 +113,16 @@ class PythonMLLibAPI extends Serializable { private def trainRegressionModel( trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel, - dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]): - java.util.LinkedList[java.lang.Object] = { + dataBytesJRDD: JavaRDD[Array[Byte]], + initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = { val data = dataBytesJRDD.rdd.map(xBytes => { val x = deserializeDoubleVector(xBytes) - LabeledPoint(x(0), x.slice(1, x.length)) + LabeledPoint(x(0), Vectors.dense(x.slice(1, x.length))) }) val initialWeights = deserializeDoubleVector(initialWeightsBA) val model = trainFunc(data, initialWeights) val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(serializeDoubleVector(model.weights)) + ret.add(serializeDoubleVector(model.weights.toArray)) ret.add(model.intercept: java.lang.Double) ret } @@ -127,75 +130,127 @@ class PythonMLLibAPI extends Serializable { /** * Java stub for Python mllib LinearRegressionWithSGD.train() */ - def trainLinearRegressionModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], - numIterations: Int, stepSize: Double, miniBatchFraction: Double, + def trainLinearRegressionModelWithSGD( + dataBytesJRDD: JavaRDD[Array[Byte]], + numIterations: Int, + stepSize: Double, + miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { - trainRegressionModel((data, initialWeights) => - LinearRegressionWithSGD.train(data, numIterations, stepSize, - miniBatchFraction, initialWeights), - dataBytesJRDD, initialWeightsBA) + trainRegressionModel( + (data, initialWeights) => + LinearRegressionWithSGD.train( + data, + numIterations, + stepSize, + miniBatchFraction, + Vectors.dense(initialWeights)), + dataBytesJRDD, + initialWeightsBA) } /** * Java stub for Python mllib LassoWithSGD.train() */ - def trainLassoModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, - stepSize: Double, regParam: Double, miniBatchFraction: Double, + def trainLassoModelWithSGD( + dataBytesJRDD: JavaRDD[Array[Byte]], + numIterations: Int, + stepSize: Double, + regParam: Double, + miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { - trainRegressionModel((data, initialWeights) => - LassoWithSGD.train(data, numIterations, stepSize, regParam, - miniBatchFraction, initialWeights), - dataBytesJRDD, initialWeightsBA) + trainRegressionModel( + (data, initialWeights) => + LassoWithSGD.train( + data, + numIterations, + stepSize, + regParam, + miniBatchFraction, + Vectors.dense(initialWeights)), + dataBytesJRDD, + initialWeightsBA) } /** * Java stub for Python mllib RidgeRegressionWithSGD.train() */ - def trainRidgeModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, - stepSize: Double, regParam: Double, miniBatchFraction: Double, + def trainRidgeModelWithSGD( + dataBytesJRDD: JavaRDD[Array[Byte]], + numIterations: Int, + stepSize: Double, + regParam: Double, + miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { - trainRegressionModel((data, initialWeights) => - RidgeRegressionWithSGD.train(data, numIterations, stepSize, regParam, - miniBatchFraction, initialWeights), - dataBytesJRDD, initialWeightsBA) + trainRegressionModel( + (data, initialWeights) => + RidgeRegressionWithSGD.train( + data, + numIterations, + stepSize, + regParam, + miniBatchFraction, + Vectors.dense(initialWeights)), + dataBytesJRDD, + initialWeightsBA) } /** * Java stub for Python mllib SVMWithSGD.train() */ - def trainSVMModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], numIterations: Int, - stepSize: Double, regParam: Double, miniBatchFraction: Double, + def trainSVMModelWithSGD( + dataBytesJRDD: JavaRDD[Array[Byte]], + numIterations: Int, + stepSize: Double, + regParam: Double, + miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { - trainRegressionModel((data, initialWeights) => - SVMWithSGD.train(data, numIterations, stepSize, regParam, - miniBatchFraction, initialWeights), - dataBytesJRDD, initialWeightsBA) + trainRegressionModel( + (data, initialWeights) => + SVMWithSGD.train( + data, + numIterations, + stepSize, + regParam, + miniBatchFraction, + Vectors.dense(initialWeights)), + dataBytesJRDD, + initialWeightsBA) } /** * Java stub for Python mllib LogisticRegressionWithSGD.train() */ - def trainLogisticRegressionModelWithSGD(dataBytesJRDD: JavaRDD[Array[Byte]], - numIterations: Int, stepSize: Double, miniBatchFraction: Double, + def trainLogisticRegressionModelWithSGD( + dataBytesJRDD: JavaRDD[Array[Byte]], + numIterations: Int, + stepSize: Double, + miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { - trainRegressionModel((data, initialWeights) => - LogisticRegressionWithSGD.train(data, numIterations, stepSize, - miniBatchFraction, initialWeights), - dataBytesJRDD, initialWeightsBA) + trainRegressionModel( + (data, initialWeights) => + LogisticRegressionWithSGD.train( + data, + numIterations, + stepSize, + miniBatchFraction, + Vectors.dense(initialWeights)), + dataBytesJRDD, + initialWeightsBA) } /** * Java stub for NaiveBayes.train() */ - def trainNaiveBayes(dataBytesJRDD: JavaRDD[Array[Byte]], lambda: Double) - : java.util.List[java.lang.Object] = - { + def trainNaiveBayes( + dataBytesJRDD: JavaRDD[Array[Byte]], + lambda: Double): java.util.List[java.lang.Object] = { val data = dataBytesJRDD.rdd.map(xBytes => { val x = deserializeDoubleVector(xBytes) - LabeledPoint(x(0), x.slice(1, x.length)) + LabeledPoint(x(0), Vectors.dense(x.slice(1, x.length))) }) val model = NaiveBayes.train(data, lambda) val ret = new java.util.LinkedList[java.lang.Object]() + ret.add(serializeDoubleVector(model.labels)) ret.add(serializeDoubleVector(model.pi)) ret.add(serializeDoubleMatrix(model.theta)) ret @@ -204,9 +259,12 @@ class PythonMLLibAPI extends Serializable { /** * Java stub for Python mllib KMeans.train() */ - def trainKMeansModel(dataBytesJRDD: JavaRDD[Array[Byte]], k: Int, - maxIterations: Int, runs: Int, initializationMode: String): - java.util.List[java.lang.Object] = { + def trainKMeansModel( + dataBytesJRDD: JavaRDD[Array[Byte]], + k: Int, + maxIterations: Int, + runs: Int, + initializationMode: String): java.util.List[java.lang.Object] = { val data = dataBytesJRDD.rdd.map(xBytes => Vectors.dense(deserializeDoubleVector(xBytes))) val model = KMeans.train(data, k, maxIterations, runs, initializationMode) val ret = new java.util.LinkedList[java.lang.Object]() @@ -259,8 +317,12 @@ class PythonMLLibAPI extends Serializable { * needs to be taken in the Python code to ensure it gets freed on exit; see * the Py4J documentation. */ - def trainALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int, - iterations: Int, lambda: Double, blocks: Int): MatrixFactorizationModel = { + def trainALSModel( + ratingsBytesJRDD: JavaRDD[Array[Byte]], + rank: Int, + iterations: Int, + lambda: Double, + blocks: Int): MatrixFactorizationModel = { val ratings = ratingsBytesJRDD.rdd.map(unpackRating) ALS.train(ratings, rank, iterations, lambda, blocks) } @@ -271,8 +333,13 @@ class PythonMLLibAPI extends Serializable { * Extra care needs to be taken in the Python code to ensure it gets freed on * exit; see the Py4J documentation. */ - def trainImplicitALSModel(ratingsBytesJRDD: JavaRDD[Array[Byte]], rank: Int, - iterations: Int, lambda: Double, blocks: Int, alpha: Double): MatrixFactorizationModel = { + def trainImplicitALSModel( + ratingsBytesJRDD: JavaRDD[Array[Byte]], + rank: Int, + iterations: Int, + lambda: Double, + blocks: Int, + alpha: Double): MatrixFactorizationModel = { val ratings = ratingsBytesJRDD.rdd.map(unpackRating) ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index 391f5b9b7a7de..bd10e2e9e10e2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -17,22 +17,27 @@ package org.apache.spark.mllib.classification +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD +/** + * Represents a classification model that predicts to which of a set of categories an example + * belongs. The categories are represented by double values: 0.0, 1.0, 2.0, etc. + */ trait ClassificationModel extends Serializable { /** * Predict values for the given data set using the model trained. * * @param testData RDD representing data points to be predicted - * @return RDD[Int] where each entry contains the corresponding prediction + * @return an RDD[Double] where each entry contains the corresponding prediction */ - def predict(testData: RDD[Array[Double]]): RDD[Double] + def predict(testData: RDD[Vector]): RDD[Double] /** * Predict values for a single data point using the model trained. * * @param testData array representing a single data point - * @return Int prediction from the trained model + * @return predicted category from the trained model */ - def predict(testData: Array[Double]): Double + def predict(testData: Vector): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index a481f522761e2..4f9eaacf67fe4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -17,16 +17,12 @@ package org.apache.spark.mllib.classification -import scala.math.round - import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.util.DataValidators - -import org.jblas.DoubleMatrix +import org.apache.spark.mllib.util.{DataValidators, MLUtils} +import org.apache.spark.rdd.RDD /** * Classification model trained using Logistic Regression. @@ -35,15 +31,38 @@ import org.jblas.DoubleMatrix * @param intercept Intercept computed for this model. */ class LogisticRegressionModel( - override val weights: Array[Double], + override val weights: Vector, override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) - with ClassificationModel with Serializable { + extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { + + private var threshold: Option[Double] = Some(0.5) + + /** + * Sets the threshold that separates positive predictions from negative predictions. An example + * with prediction score greater than or equal to this threshold is identified as an positive, + * and negative otherwise. The default value is 0.5. + */ + def setThreshold(threshold: Double): this.type = { + this.threshold = Some(threshold) + this + } - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, + /** + * Clears the threshold so that `predict` will output raw prediction scores. + */ + def clearThreshold(): this.type = { + threshold = None + this + } + + override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double) = { - val margin = dataMatrix.mmul(weightMatrix).get(0) + intercept - round(1.0/ (1.0 + math.exp(margin * -1))) + val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept + val score = 1.0/ (1.0 + math.exp(-margin)) + threshold match { + case Some(t) => if (score < t) 0.0 else 1.0 + case None => score + } } } @@ -52,28 +71,27 @@ class LogisticRegressionModel( * NOTE: Labels used in Logistic Regression should be {0, 1} */ class LogisticRegressionWithSGD private ( - var stepSize: Double, - var numIterations: Int, - var regParam: Double, - var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[LogisticRegressionModel] - with Serializable { - - val gradient = new LogisticGradient() - val updater = new SimpleUpdater() + private var stepSize: Double, + private var numIterations: Int, + private var regParam: Double, + private var miniBatchFraction: Double) + extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable { + + private val gradient = new LogisticGradient() + private val updater = new SimpleUpdater() override val optimizer = new GradientDescent(gradient, updater) - .setStepSize(stepSize) - .setNumIterations(numIterations) - .setRegParam(regParam) - .setMiniBatchFraction(miniBatchFraction) - override val validators = List(DataValidators.classificationLabels) + .setStepSize(stepSize) + .setNumIterations(numIterations) + .setRegParam(regParam) + .setMiniBatchFraction(miniBatchFraction) + override protected val validators = List(DataValidators.binaryLabelValidator) /** * Construct a LogisticRegression object with default parameters */ def this() = this(1.0, 100, 0.0, 1.0) - def createModel(weights: Array[Double], intercept: Double) = { + override protected def createModel(weights: Vector, intercept: Double) = { new LogisticRegressionModel(weights, intercept) } } @@ -105,11 +123,9 @@ object LogisticRegressionWithSGD { numIterations: Int, stepSize: Double, miniBatchFraction: Double, - initialWeights: Array[Double]) - : LogisticRegressionModel = - { - new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run( - input, initialWeights) + initialWeights: Vector): LogisticRegressionModel = { + new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction) + .run(input, initialWeights) } /** @@ -128,11 +144,9 @@ object LogisticRegressionWithSGD { input: RDD[LabeledPoint], numIterations: Int, stepSize: Double, - miniBatchFraction: Double) - : LogisticRegressionModel = - { - new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run( - input) + miniBatchFraction: Double): LogisticRegressionModel = { + new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction) + .run(input) } /** @@ -150,9 +164,7 @@ object LogisticRegressionWithSGD { def train( input: RDD[LabeledPoint], numIterations: Int, - stepSize: Double) - : LogisticRegressionModel = - { + stepSize: Double): LogisticRegressionModel = { train(input, numIterations, stepSize, 1.0) } @@ -168,9 +180,7 @@ object LogisticRegressionWithSGD { */ def train( input: RDD[LabeledPoint], - numIterations: Int) - : LogisticRegressionModel = - { + numIterations: Int): LogisticRegressionModel = { train(input, numIterations, 1.0, 1.0) } @@ -183,7 +193,7 @@ object LogisticRegressionWithSGD { val sc = new SparkContext(args(0), "LogisticRegression") val data = MLUtils.loadLabeledData(sc, args(1)) val model = LogisticRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble) - println("Weights: " + model.weights.mkString("[", ", ", "]")) + println("Weights: " + model.weights) println("Intercept: " + model.intercept) sc.stop() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 6539b2f339465..18658850a2f64 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -17,34 +17,51 @@ package org.apache.spark.mllib.classification -import scala.collection.mutable +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum} -import org.jblas.DoubleMatrix - -import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.annotation.Experimental +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD /** + * :: Experimental :: * Model for Naive Bayes Classifiers. * - * @param pi Log of class priors, whose dimension is C. - * @param theta Log of class conditional probabilities, whose dimension is CxD. + * @param labels list of labels + * @param pi log of class priors, whose dimension is C, number of labels + * @param theta log of class conditional probabilities, whose dimension is C-by-D, + * where D is number of features */ -class NaiveBayesModel(val pi: Array[Double], val theta: Array[Array[Double]]) - extends ClassificationModel with Serializable { - - // Create a column vector that can be used for predictions - private val _pi = new DoubleMatrix(pi.length, 1, pi: _*) - private val _theta = new DoubleMatrix(theta) +@Experimental +class NaiveBayesModel( + val labels: Array[Double], + val pi: Array[Double], + val theta: Array[Array[Double]]) extends ClassificationModel with Serializable { + + private val brzPi = new BDV[Double](pi) + private val brzTheta = new BDM[Double](theta.length, theta(0).length) + + { + // Need to put an extra pair of braces to prevent Scala treating `i` as a member. + var i = 0 + while (i < theta.length) { + var j = 0 + while (j < theta(i).length) { + brzTheta(i, j) = theta(i)(j) + j += 1 + } + i += 1 + } + } - def predict(testData: RDD[Array[Double]]): RDD[Double] = testData.map(predict) + override def predict(testData: RDD[Vector]): RDD[Double] = testData.map(predict) - def predict(testData: Array[Double]): Double = { - val dataMatrix = new DoubleMatrix(testData.length, 1, testData: _*) - val result = _pi.add(_theta.mmul(dataMatrix)) - result.argmax() + override def predict(testData: Vector): Double = { + labels(brzArgmax(brzPi + brzTheta * testData.toBreeze)) } } @@ -56,9 +73,8 @@ class NaiveBayesModel(val pi: Array[Double], val theta: Array[Array[Double]]) * document classification. By making every vector a 0-1 vector, it can also be used as * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). */ -class NaiveBayes private (var lambda: Double) - extends Serializable with Logging -{ +class NaiveBayes private (private var lambda: Double) extends Serializable with Logging { + def this() = this(1.0) /** Set the smoothing parameter. Default: 1.0. */ @@ -70,45 +86,42 @@ class NaiveBayes private (var lambda: Double) /** * Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries. * - * @param data RDD of (label, array of features) pairs. + * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. */ def run(data: RDD[LabeledPoint]) = { - // Aggregates all sample points to driver side to get sample count and summed feature vector - // for each label. The shape of `zeroCombiner` & `aggregated` is: - // - // label: Int -> (count: Int, featuresSum: DoubleMatrix) - val zeroCombiner = mutable.Map.empty[Int, (Int, DoubleMatrix)] - val aggregated = data.aggregate(zeroCombiner)({ (combiner, point) => - point match { - case LabeledPoint(label, features) => - val (count, featuresSum) = combiner.getOrElse(label.toInt, (0, DoubleMatrix.zeros(1))) - val fs = new DoubleMatrix(features.length, 1, features: _*) - combiner += label.toInt -> (count + 1, featuresSum.addi(fs)) - } - }, { (lhs, rhs) => - for ((label, (c, fs)) <- rhs) { - val (count, featuresSum) = lhs.getOrElse(label, (0, DoubleMatrix.zeros(1))) - lhs(label) = (count + c, featuresSum.addi(fs)) + // Aggregates term frequencies per label. + // TODO: Calling combineByKey and collect creates two stages, we can implement something + // TODO: similar to reduceByKeyLocally to save one stage. + val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])]( + createCombiner = (v: Vector) => (1L, v.toBreeze.toDenseVector), + mergeValue = (c: (Long, BDV[Double]), v: Vector) => (c._1 + 1L, c._2 += v.toBreeze), + mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) => + (c1._1 + c2._1, c1._2 += c2._2) + ).collect() + val numLabels = aggregated.length + var numDocuments = 0L + aggregated.foreach { case (_, (n, _)) => + numDocuments += n + } + val numFeatures = aggregated.head match { case (_, (_, v)) => v.size } + val labels = new Array[Double](numLabels) + val pi = new Array[Double](numLabels) + val theta = Array.fill(numLabels)(new Array[Double](numFeatures)) + val piLogDenom = math.log(numDocuments + numLabels * lambda) + var i = 0 + aggregated.foreach { case (label, (n, sumTermFreqs)) => + labels(i) = label + val thetaLogDenom = math.log(brzSum(sumTermFreqs) + numFeatures * lambda) + pi(i) = math.log(n + lambda) - piLogDenom + var j = 0 + while (j < numFeatures) { + theta(i)(j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom + j += 1 } - lhs - }) - - // Kinds of label - val C = aggregated.size - // Total sample count - val N = aggregated.values.map(_._1).sum - - val pi = new Array[Double](C) - val theta = new Array[Array[Double]](C) - val piLogDenom = math.log(N + C * lambda) - - for ((label, (count, fs)) <- aggregated) { - val thetaLogDenom = math.log(fs.sum() + fs.length * lambda) - pi(label) = math.log(count + lambda) - piLogDenom - theta(label) = fs.toArray.map(f => math.log(f + lambda) - thetaLogDenom) + i += 1 } - new NaiveBayesModel(pi, theta) + new NaiveBayesModel(labels, pi, theta) } } @@ -158,8 +171,9 @@ object NaiveBayes { } else { NaiveBayes.train(data, args(2).toDouble) } - println("Pi: " + model.pi.mkString("[", ", ", "]")) - println("Theta:\n" + model.theta.map(_.mkString("[", ", ", "]")).mkString("[", "\n ", "]")) + + println("Pi\n: " + model.pi) + println("Theta:\n" + model.theta) sc.stop() } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 6dff29dfb45cc..956654b1fe90a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -18,13 +18,11 @@ package org.apache.spark.mllib.classification import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.util.DataValidators - -import org.jblas.DoubleMatrix +import org.apache.spark.mllib.util.{DataValidators, MLUtils} +import org.apache.spark.rdd.RDD /** * Model for Support Vector Machines (SVMs). @@ -33,15 +31,39 @@ import org.jblas.DoubleMatrix * @param intercept Intercept computed for this model. */ class SVMModel( - override val weights: Array[Double], + override val weights: Vector, override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) - with ClassificationModel with Serializable { + extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable { - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, + private var threshold: Option[Double] = Some(0.0) + + /** + * Sets the threshold that separates positive predictions from negative predictions. An example + * with prediction score greater than or equal to this threshold is identified as an positive, + * and negative otherwise. The default value is 0.0. + */ + def setThreshold(threshold: Double): this.type = { + this.threshold = Some(threshold) + this + } + + /** + * Clears the threshold so that `predict` will output raw prediction scores. + */ + def clearThreshold(): this.type = { + threshold = None + this + } + + override protected def predictPoint( + dataMatrix: Vector, + weightMatrix: Vector, intercept: Double) = { - val margin = dataMatrix.dot(weightMatrix) + intercept - if (margin < 0) 0.0 else 1.0 + val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept + threshold match { + case Some(t) => if (margin < 0) 0.0 else 1.0 + case None => margin + } } } @@ -50,28 +72,27 @@ class SVMModel( * NOTE: Labels used in SVM should be {0, 1}. */ class SVMWithSGD private ( - var stepSize: Double, - var numIterations: Int, - var regParam: Double, - var miniBatchFraction: Double) + private var stepSize: Double, + private var numIterations: Int, + private var regParam: Double, + private var miniBatchFraction: Double) extends GeneralizedLinearAlgorithm[SVMModel] with Serializable { - val gradient = new HingeGradient() - val updater = new SquaredL2Updater() + private val gradient = new HingeGradient() + private val updater = new SquaredL2Updater() override val optimizer = new GradientDescent(gradient, updater) .setStepSize(stepSize) .setNumIterations(numIterations) .setRegParam(regParam) .setMiniBatchFraction(miniBatchFraction) - - override val validators = List(DataValidators.classificationLabels) + override protected val validators = List(DataValidators.binaryLabelValidator) /** * Construct a SVM object with default parameters */ def this() = this(1.0, 100, 1.0, 1.0) - def createModel(weights: Array[Double], intercept: Double) = { + override protected def createModel(weights: Vector, intercept: Double) = { new SVMModel(weights, intercept) } } @@ -103,11 +124,9 @@ object SVMWithSGD { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeights: Array[Double]) - : SVMModel = - { - new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input, - initialWeights) + initialWeights: Vector): SVMModel = { + new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction) + .run(input, initialWeights) } /** @@ -127,9 +146,7 @@ object SVMWithSGD { numIterations: Int, stepSize: Double, regParam: Double, - miniBatchFraction: Double) - : SVMModel = - { + miniBatchFraction: Double): SVMModel = { new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) } @@ -149,9 +166,7 @@ object SVMWithSGD { input: RDD[LabeledPoint], numIterations: Int, stepSize: Double, - regParam: Double) - : SVMModel = - { + regParam: Double): SVMModel = { train(input, numIterations, stepSize, regParam, 1.0) } @@ -165,11 +180,7 @@ object SVMWithSGD { * @param numIterations Number of iterations of gradient descent to run. * @return a SVMModel which has the weights and offset from training. */ - def train( - input: RDD[LabeledPoint], - numIterations: Int) - : SVMModel = - { + def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = { train(input, numIterations, 1.0, 1.0, 1.0) } @@ -181,7 +192,8 @@ object SVMWithSGD { val sc = new SparkContext(args(0), "SVM") val data = MLUtils.loadLabeledData(sc, args(1)) val model = SVMWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble) - println("Weights: " + model.weights.mkString("[", ", ", "]")) + + println("Weights: " + model.weights) println("Intercept: " + model.intercept) sc.stop() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index b412738e3f00a..90cf8525df523 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -21,6 +21,7 @@ import scala.collection.mutable.ArrayBuffer import breeze.linalg.{DenseVector => BDV, Vector => BV, norm => breezeNorm} +import org.apache.spark.annotation.Experimental import org.apache.spark.{Logging, SparkContext} import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -37,13 +38,17 @@ import org.apache.spark.util.random.XORShiftRandom * to it should be cached by the user. */ class KMeans private ( - var k: Int, - var maxIterations: Int, - var runs: Int, - var initializationMode: String, - var initializationSteps: Int, - var epsilon: Double) - extends Serializable with Logging { + private var k: Int, + private var maxIterations: Int, + private var runs: Int, + private var initializationMode: String, + private var initializationSteps: Int, + private var epsilon: Double) extends Serializable with Logging { + + /** + * Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1, + * initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4}. + */ def this() = this(2, 20, 1, KMeans.K_MEANS_PARALLEL, 5, 1e-4) /** Set the number of clusters to create (k). Default: 2. */ @@ -72,6 +77,7 @@ class KMeans private ( } /** + * :: Experimental :: * Set the number of runs of the algorithm to execute in parallel. We initialize the algorithm * this many times with random starting conditions (configured by the initialization mode), then * return the best clustering found over any run. Default: 1. @@ -317,8 +323,8 @@ object KMeans { data: RDD[Vector], k: Int, maxIterations: Int, - runs: Int = 1, - initializationMode: String = K_MEANS_PARALLEL): KMeansModel = { + runs: Int, + initializationMode: String): KMeansModel = { new KMeans().setK(k) .setMaxIterations(maxIterations) .setRuns(runs) @@ -326,6 +332,27 @@ object KMeans { .run(data) } + /** + * Trains a k-means model using specified parameters and the default values for unspecified. + */ + def train( + data: RDD[Vector], + k: Int, + maxIterations: Int): KMeansModel = { + train(data, k, maxIterations, 1, K_MEANS_PARALLEL) + } + + /** + * Trains a k-means model using specified parameters and the default values for unspecified. + */ + def train( + data: RDD[Vector], + k: Int, + maxIterations: Int, + runs: Int): KMeansModel = { + train(data, k, maxIterations, runs, K_MEANS_PARALLEL) + } + /** * Returns the index of the closest center to the given point, as well as the squared distance. */ @@ -370,6 +397,7 @@ object KMeans { MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm) } + @Experimental def main(args: Array[String]) { if (args.length < 4) { println("Usage: KMeans []") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala new file mode 100644 index 0000000000000..7858ec602483f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation + +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.rdd.RDDFunctions._ + +/** + * Computes the area under the curve (AUC) using the trapezoidal rule. + */ +private[evaluation] object AreaUnderCurve { + + /** + * Uses the trapezoidal rule to compute the area under the line connecting the two input points. + * @param points two 2D points stored in Seq + */ + private def trapezoid(points: Seq[(Double, Double)]): Double = { + require(points.length == 2) + val x = points.head + val y = points.last + (y._1 - x._1) * (y._2 + x._2) / 2.0 + } + + /** + * Returns the area under the given curve. + * + * @param curve a RDD of ordered 2D points stored in pairs representing a curve + */ + def of(curve: RDD[(Double, Double)]): Double = { + curve.sliding(2).aggregate(0.0)( + seqOp = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points), + combOp = _ + _ + ) + } + + /** + * Returns the area under the given curve. + * + * @param curve an iterator over ordered 2D points stored in pairs representing a curve + */ + def of(curve: Iterable[(Double, Double)]): Double = { + curve.toIterator.sliding(2).withPartial(false).aggregate(0.0)( + seqop = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points), + combop = _ + _ + ) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala new file mode 100644 index 0000000000000..562663ad36b40 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricComputers.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation.binary + +/** + * Trait for a binary classification evaluation metric computer. + */ +private[evaluation] trait BinaryClassificationMetricComputer extends Serializable { + def apply(c: BinaryConfusionMatrix): Double +} + +/** Precision. */ +private[evaluation] object Precision extends BinaryClassificationMetricComputer { + override def apply(c: BinaryConfusionMatrix): Double = + c.numTruePositives.toDouble / (c.numTruePositives + c.numFalsePositives) +} + +/** False positive rate. */ +private[evaluation] object FalsePositiveRate extends BinaryClassificationMetricComputer { + override def apply(c: BinaryConfusionMatrix): Double = + c.numFalsePositives.toDouble / c.numNegatives +} + +/** Recall. */ +private[evaluation] object Recall extends BinaryClassificationMetricComputer { + override def apply(c: BinaryConfusionMatrix): Double = + c.numTruePositives.toDouble / c.numPositives +} + +/** + * F-Measure. + * @param beta the beta constant in F-Measure + * @see http://en.wikipedia.org/wiki/F1_score + */ +private[evaluation] case class FMeasure(beta: Double) extends BinaryClassificationMetricComputer { + private val beta2 = beta * beta + override def apply(c: BinaryConfusionMatrix): Double = { + val precision = Precision(c) + val recall = Recall(c) + (1.0 + beta2) * (precision * recall) / (beta2 * precision + recall) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala new file mode 100644 index 0000000000000..ed7b0fc943367 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetrics.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation.binary + +import org.apache.spark.rdd.{UnionRDD, RDD} +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.evaluation.AreaUnderCurve +import org.apache.spark.Logging + +/** + * Implementation of [[org.apache.spark.mllib.evaluation.binary.BinaryConfusionMatrix]]. + * + * @param count label counter for labels with scores greater than or equal to the current score + * @param totalCount label counter for all labels + */ +private case class BinaryConfusionMatrixImpl( + count: LabelCounter, + totalCount: LabelCounter) extends BinaryConfusionMatrix with Serializable { + + /** number of true positives */ + override def numTruePositives: Long = count.numPositives + + /** number of false positives */ + override def numFalsePositives: Long = count.numNegatives + + /** number of false negatives */ + override def numFalseNegatives: Long = totalCount.numPositives - count.numPositives + + /** number of true negatives */ + override def numTrueNegatives: Long = totalCount.numNegatives - count.numNegatives + + /** number of positives */ + override def numPositives: Long = totalCount.numPositives + + /** number of negatives */ + override def numNegatives: Long = totalCount.numNegatives +} + +/** + * Evaluator for binary classification. + * + * @param scoreAndLabels an RDD of (score, label) pairs. + */ +class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)]) + extends Serializable with Logging { + + private lazy val ( + cumulativeCounts: RDD[(Double, LabelCounter)], + confusions: RDD[(Double, BinaryConfusionMatrix)]) = { + // Create a bin for each distinct score value, count positives and negatives within each bin, + // and then sort by score values in descending order. + val counts = scoreAndLabels.combineByKey( + createCombiner = (label: Double) => new LabelCounter(0L, 0L) += label, + mergeValue = (c: LabelCounter, label: Double) => c += label, + mergeCombiners = (c1: LabelCounter, c2: LabelCounter) => c1 += c2 + ).sortByKey(ascending = false) + val agg = counts.values.mapPartitions({ iter => + val agg = new LabelCounter() + iter.foreach(agg += _) + Iterator(agg) + }, preservesPartitioning = true).collect() + val partitionwiseCumulativeCounts = + agg.scanLeft(new LabelCounter())((agg: LabelCounter, c: LabelCounter) => agg.clone() += c) + val totalCount = partitionwiseCumulativeCounts.last + logInfo(s"Total counts: $totalCount") + val cumulativeCounts = counts.mapPartitionsWithIndex( + (index: Int, iter: Iterator[(Double, LabelCounter)]) => { + val cumCount = partitionwiseCumulativeCounts(index) + iter.map { case (score, c) => + cumCount += c + (score, cumCount.clone()) + } + }, preservesPartitioning = true) + cumulativeCounts.persist() + val confusions = cumulativeCounts.map { case (score, cumCount) => + (score, BinaryConfusionMatrixImpl(cumCount, totalCount).asInstanceOf[BinaryConfusionMatrix]) + } + (cumulativeCounts, confusions) + } + + /** Unpersist intermediate RDDs used in the computation. */ + def unpersist() { + cumulativeCounts.unpersist() + } + + /** Returns thresholds in descending order. */ + def thresholds(): RDD[Double] = cumulativeCounts.map(_._1) + + /** + * Returns the receiver operating characteristic (ROC) curve, + * which is an RDD of (false positive rate, true positive rate) + * with (0.0, 0.0) prepended and (1.0, 1.0) appended to it. + * @see http://en.wikipedia.org/wiki/Receiver_operating_characteristic + */ + def roc(): RDD[(Double, Double)] = { + val rocCurve = createCurve(FalsePositiveRate, Recall) + val sc = confusions.context + val first = sc.makeRDD(Seq((0.0, 0.0)), 1) + val last = sc.makeRDD(Seq((1.0, 1.0)), 1) + new UnionRDD[(Double, Double)](sc, Seq(first, rocCurve, last)) + } + + /** + * Computes the area under the receiver operating characteristic (ROC) curve. + */ + def areaUnderROC(): Double = AreaUnderCurve.of(roc()) + + /** + * Returns the precision-recall curve, which is an RDD of (recall, precision), + * NOT (precision, recall), with (0.0, 1.0) prepended to it. + * @see http://en.wikipedia.org/wiki/Precision_and_recall + */ + def pr(): RDD[(Double, Double)] = { + val prCurve = createCurve(Recall, Precision) + val sc = confusions.context + val first = sc.makeRDD(Seq((0.0, 1.0)), 1) + first.union(prCurve) + } + + /** + * Computes the area under the precision-recall curve. + */ + def areaUnderPR(): Double = AreaUnderCurve.of(pr()) + + /** + * Returns the (threshold, F-Measure) curve. + * @param beta the beta factor in F-Measure computation. + * @return an RDD of (threshold, F-Measure) pairs. + * @see http://en.wikipedia.org/wiki/F1_score + */ + def fMeasureByThreshold(beta: Double): RDD[(Double, Double)] = createCurve(FMeasure(beta)) + + /** Returns the (threshold, F-Measure) curve with beta = 1.0. */ + def fMeasureByThreshold(): RDD[(Double, Double)] = fMeasureByThreshold(1.0) + + /** Returns the (threshold, precision) curve. */ + def precisionByThreshold(): RDD[(Double, Double)] = createCurve(Precision) + + /** Returns the (threshold, recall) curve. */ + def recallByThreshold(): RDD[(Double, Double)] = createCurve(Recall) + + /** Creates a curve of (threshold, metric). */ + private def createCurve(y: BinaryClassificationMetricComputer): RDD[(Double, Double)] = { + confusions.map { case (s, c) => + (s, y(c)) + } + } + + /** Creates a curve of (metricX, metricY). */ + private def createCurve( + x: BinaryClassificationMetricComputer, + y: BinaryClassificationMetricComputer): RDD[(Double, Double)] = { + confusions.map { case (_, c) => + (x(c), y(c)) + } + } +} + +/** + * A counter for positives and negatives. + * + * @param numPositives number of positive labels + * @param numNegatives number of negative labels + */ +private class LabelCounter( + var numPositives: Long = 0L, + var numNegatives: Long = 0L) extends Serializable { + + /** Processes a label. */ + def +=(label: Double): LabelCounter = { + // Though we assume 1.0 for positive and 0.0 for negative, the following check will handle + // -1.0 for negative as well. + if (label > 0.5) numPositives += 1L else numNegatives += 1L + this + } + + /** Merges another counter. */ + def +=(other: LabelCounter): LabelCounter = { + numPositives += other.numPositives + numNegatives += other.numNegatives + this + } + + override def clone: LabelCounter = { + new LabelCounter(numPositives, numNegatives) + } + + override def toString: String = s"{numPos: $numPositives, numNeg: $numNegatives}" +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala new file mode 100644 index 0000000000000..75a75b216002a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/binary/BinaryConfusionMatrix.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation.binary + +/** + * Trait for a binary confusion matrix. + */ +private[evaluation] trait BinaryConfusionMatrix { + /** number of true positives */ + def numTruePositives: Long + + /** number of false positives */ + def numFalsePositives: Long + + /** number of false negatives */ + def numFalseNegatives: Long + + /** number of true negatives */ + def numTrueNegatives: Long + + /** number of positives */ + def numPositives: Long = numTruePositives + numFalseNegatives + + /** number of negatives */ + def numNegatives: Long = numFalsePositives + numTrueNegatives +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala new file mode 100644 index 0000000000000..b11ba5d30fbd3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg + +import breeze.linalg.{Matrix => BM, DenseMatrix => BDM} + +/** + * Trait for a local matrix. + */ +trait Matrix extends Serializable { + + /** Number of rows. */ + def numRows: Int + + /** Number of columns. */ + def numCols: Int + + /** Converts to a dense array in column major. */ + def toArray: Array[Double] + + /** Converts to a breeze matrix. */ + private[mllib] def toBreeze: BM[Double] + + /** Gets the (i, j)-th element. */ + private[mllib] def apply(i: Int, j: Int): Double = toBreeze(i, j) + + override def toString: String = toBreeze.toString() +} + +/** + * Column-majored dense matrix. + * The entry values are stored in a single array of doubles with columns listed in sequence. + * For example, the following matrix + * {{{ + * 1.0 2.0 + * 3.0 4.0 + * 5.0 6.0 + * }}} + * is stored as `[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]`. + * + * @param numRows number of rows + * @param numCols number of columns + * @param values matrix entries in column major + */ +class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) extends Matrix { + + require(values.length == numRows * numCols) + + override def toArray: Array[Double] = values + + private[mllib] override def toBreeze: BM[Double] = new BDM[Double](numRows, numCols, values) +} + +/** + * Factory methods for [[org.apache.spark.mllib.linalg.Matrix]]. + */ +object Matrices { + + /** + * Creates a column-majored dense matrix. + * + * @param numRows number of rows + * @param numCols number of columns + * @param values matrix entries in column major + */ + def dense(numRows: Int, numCols: Int, values: Array[Double]): Matrix = { + new DenseMatrix(numRows, numCols, values) + } + + /** + * Creates a Matrix instance from a breeze matrix. + * @param breeze a breeze matrix + * @return a Matrix instance + */ + private[mllib] def fromBreeze(breeze: BM[Double]): Matrix = { + breeze match { + case dm: BDM[Double] => + require(dm.majorStride == dm.rows, + "Do not support stride size different from the number of rows.") + new DenseMatrix(dm.rows, dm.cols, dm.data) + case _ => + throw new UnsupportedOperationException( + s"Do not support conversion from type ${breeze.getClass.getName}.") + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/PCA.scala deleted file mode 100644 index fe5b3f6c7e463..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/PCA.scala +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.linalg - -import org.apache.spark.rdd.RDD - - -import org.jblas.DoubleMatrix - - -/** - * Class used to obtain principal components - */ -class PCA { - private var k = 1 - - /** - * Set the number of top-k principle components to return - */ - def setK(k: Int): PCA = { - this.k = k - this - } - - /** - * Compute PCA using the current set parameters - */ - def compute(matrix: TallSkinnyDenseMatrix): Array[Array[Double]] = { - computePCA(matrix) - } - - /** - * Compute PCA using the parameters currently set - * See computePCA() for more details - */ - def compute(matrix: RDD[Array[Double]]): Array[Array[Double]] = { - computePCA(matrix) - } - - /** - * Computes the top k principal component coefficients for the m-by-n data matrix X. - * Rows of X correspond to observations and columns correspond to variables. - * The coefficient matrix is n-by-k. Each column of coeff contains coefficients - * for one principal component, and the columns are in descending - * order of component variance. - * This function centers the data and uses the - * singular value decomposition (SVD) algorithm. - * - * @param matrix dense matrix to perform PCA on - * @return An nxk matrix with principal components in columns. Columns are inner arrays - */ - private def computePCA(matrix: TallSkinnyDenseMatrix): Array[Array[Double]] = { - val m = matrix.m - val n = matrix.n - - if (m <= 0 || n <= 0) { - throw new IllegalArgumentException("Expecting a well-formed matrix: m=$m n=$n") - } - - computePCA(matrix.rows.map(_.data)) - } - - /** - * Computes the top k principal component coefficients for the m-by-n data matrix X. - * Rows of X correspond to observations and columns correspond to variables. - * The coefficient matrix is n-by-k. Each column of coeff contains coefficients - * for one principal component, and the columns are in descending - * order of component variance. - * This function centers the data and uses the - * singular value decomposition (SVD) algorithm. - * - * @param matrix dense matrix to perform pca on - * @return An nxk matrix of principal components - */ - private def computePCA(matrix: RDD[Array[Double]]): Array[Array[Double]] = { - val n = matrix.first.size - - // compute column sums and normalize matrix - val colSumsTemp = matrix.map((_, 1)).fold((Array.ofDim[Double](n), 0)) { - (a, b) => - val am = new DoubleMatrix(a._1) - val bm = new DoubleMatrix(b._1) - am.addi(bm) - (a._1, a._2 + b._2) - } - - val m = colSumsTemp._2 - val colSums = colSumsTemp._1.map(x => x / m) - - val data = matrix.map { - x => - val row = Array.ofDim[Double](n) - var i = 0 - while (i < n) { - row(i) = x(i) - colSums(i) - i += 1 - } - row - } - - val (u, s, v) = new SVD().setK(k).compute(data) - v - } -} - diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SVD.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SVD.scala deleted file mode 100644 index 3e7cc648d1d37..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SVD.scala +++ /dev/null @@ -1,397 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.linalg - -import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.RDD - -import org.jblas.{DoubleMatrix, Singular, MatrixFunctions} - -/** - * Class used to obtain singular value decompositions - */ -class SVD { - private var k = 1 - private var computeU = true - - // All singular values smaller than rCond * sigma(0) - // are treated as zero, where sigma(0) is the largest singular value. - private var rCond = 1e-9 - - /** - * Set the number of top-k singular vectors to return - */ - def setK(k: Int): SVD = { - this.k = k - this - } - - /** - * Sets the reciprocal condition number (rCond). All singular values - * smaller than rCond * sigma(0) are treated as zero, - * where sigma(0) is the largest singular value. - */ - def setReciprocalConditionNumber(smallS: Double): SVD = { - this.rCond = smallS - this - } - - /** - * Should U be computed? - */ - def setComputeU(compU: Boolean): SVD = { - this.computeU = compU - this - } - - /** - * Compute SVD using the current set parameters - */ - def compute(matrix: TallSkinnyDenseMatrix): TallSkinnyMatrixSVD = { - denseSVD(matrix) - } - - /** - * Compute SVD using the current set parameters - * Returns (U, S, V) such that A = USV^T - * U is a row-by-row dense matrix - * S is a simple double array of singular values - * V is a 2d array matrix - * See [[denseSVD]] for more documentation - */ - def compute(matrix: RDD[Array[Double]]): - (RDD[Array[Double]], Array[Double], Array[Array[Double]]) = { - denseSVD(matrix) - } - - /** - * See full paramter definition of sparseSVD for more description. - * - * @param matrix sparse matrix to factorize - * @return Three sparse matrices: U, S, V such that A = USV^T - */ - def compute(matrix: SparseMatrix): MatrixSVD = { - sparseSVD(matrix) - } - - /** - * Singular Value Decomposition for Tall and Skinny matrices. - * Given an m x n matrix A, this will compute matrices U, S, V such that - * A = U * S * V' - * - * There is no restriction on m, but we require n^2 doubles to fit in memory. - * Further, n should be less than m. - * - * The decomposition is computed by first computing A'A = V S^2 V', - * computing svd locally on that (since n x n is small), - * from which we recover S and V. - * Then we compute U via easy matrix multiplication - * as U = A * V * S^-1 - * - * Only the k largest singular values and associated vectors are found. - * If there are k such values, then the dimensions of the return will be: - * - * S is k x k and diagonal, holding the singular values on diagonal - * U is m x k and satisfies U'U = eye(k) - * V is n x k and satisfies V'V = eye(k) - * - * @param matrix dense matrix to factorize - * @return See [[TallSkinnyMatrixSVD]] for the output matrices and arrays - */ - private def denseSVD(matrix: TallSkinnyDenseMatrix): TallSkinnyMatrixSVD = { - val m = matrix.m - val n = matrix.n - - if (m < n || m <= 0 || n <= 0) { - throw new IllegalArgumentException("Expecting a tall and skinny matrix m=$m n=$n") - } - - if (k < 1 || k > n) { - throw new IllegalArgumentException("Request up to n singular values n=$n k=$k") - } - - val rowIndices = matrix.rows.map(_.i) - - // compute SVD - val (u, sigma, v) = denseSVD(matrix.rows.map(_.data)) - - if (computeU) { - // prep u for returning - val retU = TallSkinnyDenseMatrix( - u.zip(rowIndices).map { - case (row, i) => MatrixRow(i, row) - }, - m, - k) - - TallSkinnyMatrixSVD(retU, sigma, v) - } else { - TallSkinnyMatrixSVD(null, sigma, v) - } - } - - /** - * Singular Value Decomposition for Tall and Skinny matrices. - * Given an m x n matrix A, this will compute matrices U, S, V such that - * A = U * S * V' - * - * There is no restriction on m, but we require n^2 doubles to fit in memory. - * Further, n should be less than m. - * - * The decomposition is computed by first computing A'A = V S^2 V', - * computing svd locally on that (since n x n is small), - * from which we recover S and V. - * Then we compute U via easy matrix multiplication - * as U = A * V * S^-1 - * - * Only the k largest singular values and associated vectors are found. - * If there are k such values, then the dimensions of the return will be: - * - * S is k x k and diagonal, holding the singular values on diagonal - * U is m x k and satisfies U'U = eye(k) - * V is n x k and satisfies V'V = eye(k) - * - * The return values are as lean as possible: an RDD of rows for U, - * a simple array for sigma, and a dense 2d matrix array for V - * - * @param matrix dense matrix to factorize - * @return Three matrices: U, S, V such that A = USV^T - */ - private def denseSVD(matrix: RDD[Array[Double]]): - (RDD[Array[Double]], Array[Double], Array[Array[Double]]) = { - val n = matrix.first.size - - if (k < 1 || k > n) { - throw new IllegalArgumentException( - "Request up to n singular values k=$k n=$n") - } - - // Compute A^T A - val fullata = matrix.mapPartitions { - iter => - val localATA = Array.ofDim[Double](n, n) - while (iter.hasNext) { - val row = iter.next() - var i = 0 - while (i < n) { - var j = 0 - while (j < n) { - localATA(i)(j) += row(i) * row(j) - j += 1 - } - i += 1 - } - } - Iterator(localATA) - }.fold(Array.ofDim[Double](n, n)) { - (a, b) => - var i = 0 - while (i < n) { - var j = 0 - while (j < n) { - a(i)(j) += b(i)(j) - j += 1 - } - i += 1 - } - a - } - - // Construct jblas A^T A locally - val ata = new DoubleMatrix(fullata) - - // Since A^T A is small, we can compute its SVD directly - val svd = Singular.sparseSVD(ata) - val V = svd(0) - val sigmas = MatrixFunctions.sqrt(svd(1)).toArray.filter(x => x / svd(1).get(0) > rCond) - - val sk = Math.min(k, sigmas.size) - val sigma = sigmas.take(sk) - - // prepare V for returning - val retV = Array.tabulate(n, sk)((i, j) => V.get(i, j)) - - if (computeU) { - // Compute U as U = A V S^-1 - // Compute VS^-1 - val vsinv = new DoubleMatrix(Array.tabulate(n, sk)((i, j) => V.get(i, j) / sigma(j))) - val retU = matrix.map { - x => - val v = new DoubleMatrix(Array(x)) - v.mmul(vsinv).data - } - (retU, sigma, retV) - } else { - (null, sigma, retV) - } - } - - /** - * Singular Value Decomposition for Tall and Skinny sparse matrices. - * Given an m x n matrix A, this will compute matrices U, S, V such that - * A = U * S * V' - * - * There is no restriction on m, but we require O(n^2) doubles to fit in memory. - * Further, n should be less than m. - * - * The decomposition is computed by first computing A'A = V S^2 V', - * computing svd locally on that (since n x n is small), - * from which we recover S and V. - * Then we compute U via easy matrix multiplication - * as U = A * V * S^-1 - * - * Only the k largest singular values and associated vectors are found. - * If there are k such values, then the dimensions of the return will be: - * - * S is k x k and diagonal, holding the singular values on diagonal - * U is m x k and satisfies U'U = eye(k) - * V is n x k and satisfies V'V = eye(k) - * - * All input and output is expected in sparse matrix format, 0-indexed - * as tuples of the form ((i,j),value) all in RDDs using the - * SparseMatrix class - * - * @param matrix sparse matrix to factorize - * @return Three sparse matrices: U, S, V such that A = USV^T - */ - private def sparseSVD(matrix: SparseMatrix): MatrixSVD = { - val data = matrix.data - val m = matrix.m - val n = matrix.n - - if (m < n || m <= 0 || n <= 0) { - throw new IllegalArgumentException("Expecting a tall and skinny matrix") - } - - if (k < 1 || k > n) { - throw new IllegalArgumentException("Must request up to n singular values") - } - - // Compute A^T A, assuming rows are sparse enough to fit in memory - val rows = data.map(entry => - (entry.i, (entry.j, entry.mval))).groupByKey() - val emits = rows.flatMap { - case (rowind, cols) => - cols.flatMap { - case (colind1, mval1) => - cols.map { - case (colind2, mval2) => - ((colind1, colind2), mval1 * mval2) - } - } - }.reduceByKey(_ + _) - - // Construct jblas A^T A locally - val ata = DoubleMatrix.zeros(n, n) - for (entry <- emits.collect()) { - ata.put(entry._1._1, entry._1._2, entry._2) - } - - // Since A^T A is small, we can compute its SVD directly - val svd = Singular.sparseSVD(ata) - val V = svd(0) - // This will be updated to rcond - val sigmas = MatrixFunctions.sqrt(svd(1)).toArray.filter(x => x > 1e-9) - - if (sigmas.size < k) { - throw new Exception("Not enough singular values to return k=" + k + " s=" + sigmas.size) - } - - val sigma = sigmas.take(k) - - val sc = data.sparkContext - - // prepare V for returning - val retVdata = sc.makeRDD( - Array.tabulate(V.rows, sigma.length) { - (i, j) => - MatrixEntry(i, j, V.get(i, j)) - }.flatten) - val retV = SparseMatrix(retVdata, V.rows, sigma.length) - - val retSdata = sc.makeRDD(Array.tabulate(sigma.length) { - x => MatrixEntry(x, x, sigma(x)) - }) - - val retS = SparseMatrix(retSdata, sigma.length, sigma.length) - - // Compute U as U = A V S^-1 - // turn V S^-1 into an RDD as a sparse matrix - val vsirdd = sc.makeRDD(Array.tabulate(V.rows, sigma.length) { - (i, j) => ((i, j), V.get(i, j) / sigma(j)) - }.flatten) - - if (computeU) { - // Multiply A by VS^-1 - val aCols = data.map(entry => (entry.j, (entry.i, entry.mval))) - val bRows = vsirdd.map(entry => (entry._1._1, (entry._1._2, entry._2))) - val retUdata = aCols.join(bRows).map { - case (key, ((rowInd, rowVal), (colInd, colVal))) => - ((rowInd, colInd), rowVal * colVal) - }.reduceByKey(_ + _).map { - case ((row, col), mval) => MatrixEntry(row, col, mval) - } - - val retU = SparseMatrix(retUdata, m, sigma.length) - MatrixSVD(retU, retS, retV) - } else { - MatrixSVD(null, retS, retV) - } - } -} - -/** - * Top-level methods for calling sparse Singular Value Decomposition - * NOTE: All matrices are 0-indexed - */ -object SVD { - def main(args: Array[String]) { - if (args.length < 8) { - println("Usage: SVD " + - " ") - System.exit(1) - } - - val (master, inputFile, m, n, k, output_u, output_s, output_v) = - (args(0), args(1), args(2).toInt, args(3).toInt, - args(4).toInt, args(5), args(6), args(7)) - - val sc = new SparkContext(master, "SVD") - - val rawData = sc.textFile(inputFile) - val data = rawData.map { - line => - val parts = line.split(',') - MatrixEntry(parts(0).toInt, parts(1).toInt, parts(2).toDouble) - } - - val decomposed = new SVD().setK(k).compute(SparseMatrix(data, m, n)) - val u = decomposed.U.data - val s = decomposed.S.data - val v = decomposed.V.data - - println("Computed " + s.collect().length + " singular values and vectors") - u.saveAsTextFile(output_u) - s.saveAsTextFile(output_s) - v.saveAsTextFile(output_v) - System.exit(0) - } -} - - diff --git a/core/src/main/scala/org/apache/spark/ui/Page.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala similarity index 81% rename from core/src/main/scala/org/apache/spark/ui/Page.scala rename to mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala index b2a069a37552d..46b105457430c 100644 --- a/core/src/main/scala/org/apache/spark/ui/Page.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/SingularValueDecomposition.scala @@ -15,8 +15,7 @@ * limitations under the License. */ -package org.apache.spark.ui +package org.apache.spark.mllib.linalg -private[spark] object Page extends Enumeration { - val Stages, Storage, Environment, Executors = Value -} +/** Represents singular value decomposition (SVD) factors. */ +case class SingularValueDecomposition[UType, VType](U: UType, s: Vector, V: VType) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/TallSkinnyMatrixSVD.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/TallSkinnyMatrixSVD.scala deleted file mode 100644 index b3a450e92394e..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/TallSkinnyMatrixSVD.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.linalg - -/** - * Class that represents the singular value decomposition of a matrix - * - * @param U such that A = USV^T is a TallSkinnyDenseMatrix - * @param S such that A = USV^T is a simple double array - * @param V such that A = USV^T, V is a 2d array matrix that holds - * singular vectors in columns. Columns are inner arrays - * i.e. V(i)(j) is standard math notation V_{ij} - */ -case class TallSkinnyMatrixSVD(val U: TallSkinnyDenseMatrix, - val S: Array[Double], - val V: Array[Array[Double]]) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 01c1501548f87..99a849f1c66b1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -54,15 +54,23 @@ trait Vector extends Serializable { * Converts the instance to a breeze vector. */ private[mllib] def toBreeze: BV[Double] + + /** + * Gets the value of the ith element. + * @param i index + */ + private[mllib] def apply(i: Int): Double = toBreeze(i) } /** * Factory methods for [[org.apache.spark.mllib.linalg.Vector]]. + * We don't use the name `Vector` because Scala imports + * [[scala.collection.immutable.Vector]] by default. */ object Vectors { /** - * Creates a dense vector. + * Creates a dense vector from its values. */ @varargs def dense(firstValue: Double, otherValues: Double*): Vector = @@ -145,25 +153,28 @@ class DenseVector(val values: Array[Double]) extends Vector { override def toArray: Array[Double] = values private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values) + + override def apply(i: Int) = values(i) } /** * A sparse vector represented by an index array and an value array. * - * @param n size of the vector. + * @param size size of the vector. * @param indices index array, assume to be strictly increasing. * @param values value array, must have the same length as the index array. */ -class SparseVector(val n: Int, val indices: Array[Int], val values: Array[Double]) extends Vector { - - override def size: Int = n +class SparseVector( + override val size: Int, + val indices: Array[Int], + val values: Array[Double]) extends Vector { override def toString: String = { - "(" + n + "," + indices.zip(values).mkString("[", "," ,"]") + ")" + "(" + size + "," + indices.zip(values).mkString("[", "," ,"]") + ")" } override def toArray: Array[Double] = { - val data = new Array[Double](n) + val data = new Array[Double](size) var i = 0 val nnz = indices.length while (i < nnz) { @@ -173,5 +184,5 @@ class SparseVector(val n: Int, val indices: Array[Int], val values: Array[Double data } - private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, n) + private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala new file mode 100644 index 0000000000000..56b8fdcda66eb --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg.distributed + +import breeze.linalg.{DenseMatrix => BDM} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.linalg.Vectors + +/** + * Represents an entry in an distributed matrix. + * @param i row index + * @param j column index + * @param value value of the entry + */ +case class MatrixEntry(i: Long, j: Long, value: Double) + +/** + * :: Experimental :: + * Represents a matrix in coordinate format. + * + * @param entries matrix entries + * @param nRows number of rows. A non-positive value means unknown, and then the number of rows will + * be determined by the max row index plus one. + * @param nCols number of columns. A non-positive value means unknown, and then the number of + * columns will be determined by the max column index plus one. + */ +@Experimental +class CoordinateMatrix( + val entries: RDD[MatrixEntry], + private var nRows: Long, + private var nCols: Long) extends DistributedMatrix { + + /** Alternative constructor leaving matrix dimensions to be determined automatically. */ + def this(entries: RDD[MatrixEntry]) = this(entries, 0L, 0L) + + /** Gets or computes the number of columns. */ + override def numCols(): Long = { + if (nCols <= 0L) { + computeSize() + } + nCols + } + + /** Gets or computes the number of rows. */ + override def numRows(): Long = { + if (nRows <= 0L) { + computeSize() + } + nRows + } + + /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. */ + def toIndexedRowMatrix(): IndexedRowMatrix = { + val nl = numCols() + if (nl > Int.MaxValue) { + sys.error(s"Cannot convert to a row-oriented format because the number of columns $nl is " + + "too large.") + } + val n = nl.toInt + val indexedRows = entries.map(entry => (entry.i, (entry.j.toInt, entry.value))) + .groupByKey() + .map { case (i, vectorEntries) => + IndexedRow(i, Vectors.sparse(n, vectorEntries.toSeq)) + } + new IndexedRowMatrix(indexedRows, numRows(), n) + } + + /** + * Converts to RowMatrix, dropping row indices after grouping by row index. + * The number of columns must be within the integer range. + */ + def toRowMatrix(): RowMatrix = { + toIndexedRowMatrix().toRowMatrix() + } + + /** Determines the size by computing the max row/column index. */ + private def computeSize() { + // Reduce will throw an exception if `entries` is empty. + val (m1, n1) = entries.map(entry => (entry.i, entry.j)).reduce { case ((i1, j1), (i2, j2)) => + (math.max(i1, i2), math.max(j1, j2)) + } + // There may be empty columns at the very right and empty rows at the very bottom. + nRows = math.max(nRows, m1 + 1L) + nCols = math.max(nCols, n1 + 1L) + } + + /** Collects data and assembles a local matrix. */ + private[mllib] override def toBreeze(): BDM[Double] = { + val m = numRows().toInt + val n = numCols().toInt + val mat = BDM.zeros[Double](m, n) + entries.collect().foreach { case MatrixEntry(i, j, value) => + mat(i.toInt, j.toInt) = value + } + mat + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala new file mode 100644 index 0000000000000..a0e26ce3bc465 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/DistributedMatrix.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg.distributed + +import breeze.linalg.{DenseMatrix => BDM} + +/** + * Represents a distributively stored matrix backed by one or more RDDs. + */ +trait DistributedMatrix extends Serializable { + + /** Gets or computes the number of rows. */ + def numRows(): Long + + /** Gets or computes the number of columns. */ + def numCols(): Long + + /** Collects data and assembles a local dense breeze matrix (for test only). */ + private[mllib] def toBreeze(): BDM[Double] +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala new file mode 100644 index 0000000000000..132b3af72d9ce --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg.distributed + +import breeze.linalg.{DenseMatrix => BDM} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.linalg.SingularValueDecomposition + +/** + * :: Experimental :: + * Represents a row of [[org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix]]. + */ +@Experimental +case class IndexedRow(index: Long, vector: Vector) + +/** + * :: Experimental :: + * Represents a row-oriented [[org.apache.spark.mllib.linalg.distributed.DistributedMatrix]] with + * indexed rows. + * + * @param rows indexed rows of this matrix + * @param nRows number of rows. A non-positive value means unknown, and then the number of rows will + * be determined by the max row index plus one. + * @param nCols number of columns. A non-positive value means unknown, and then the number of + * columns will be determined by the size of the first row. + */ +@Experimental +class IndexedRowMatrix( + val rows: RDD[IndexedRow], + private var nRows: Long, + private var nCols: Int) extends DistributedMatrix { + + /** Alternative constructor leaving matrix dimensions to be determined automatically. */ + def this(rows: RDD[IndexedRow]) = this(rows, 0L, 0) + + override def numCols(): Long = { + if (nCols <= 0) { + // Calling `first` will throw an exception if `rows` is empty. + nCols = rows.first().vector.size + } + nCols + } + + override def numRows(): Long = { + if (nRows <= 0L) { + // Reduce will throw an exception if `rows` is empty. + nRows = rows.map(_.index).reduce(math.max) + 1L + } + nRows + } + + /** + * Drops row indices and converts this matrix to a + * [[org.apache.spark.mllib.linalg.distributed.RowMatrix]]. + */ + def toRowMatrix(): RowMatrix = { + new RowMatrix(rows.map(_.vector), 0L, nCols) + } + + /** + * Computes the singular value decomposition of this matrix. + * Denote this matrix by A (m x n), this will compute matrices U, S, V such that A = U * S * V'. + * + * There is no restriction on m, but we require `n^2` doubles to fit in memory. + * Further, n should be less than m. + + * The decomposition is computed by first computing A'A = V S^2 V', + * computing svd locally on that (since n x n is small), from which we recover S and V. + * Then we compute U via easy matrix multiplication as U = A * (V * S^-1). + * Note that this approach requires `O(n^3)` time on the master node. + * + * At most k largest non-zero singular values and associated vectors are returned. + * If there are k such values, then the dimensions of the return will be: + * + * U is an [[org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix]] of size m x k that + * satisfies U'U = eye(k), + * s is a Vector of size k, holding the singular values in descending order, + * and V is a local Matrix of size n x k that satisfies V'V = eye(k). + * + * @param k number of singular values to keep. We might return less than k if there are + * numerically zero singular values. See rCond. + * @param computeU whether to compute U + * @param rCond the reciprocal condition number. All singular values smaller than rCond * sigma(0) + * are treated as zero, where sigma(0) is the largest singular value. + * @return SingularValueDecomposition(U, s, V) + */ + def computeSVD( + k: Int, + computeU: Boolean = false, + rCond: Double = 1e-9): SingularValueDecomposition[IndexedRowMatrix, Matrix] = { + val indices = rows.map(_.index) + val svd = toRowMatrix().computeSVD(k, computeU, rCond) + val U = if (computeU) { + val indexedRows = indices.zip(svd.U.rows).map { case (i, v) => + IndexedRow(i, v) + } + new IndexedRowMatrix(indexedRows, nRows, nCols) + } else { + null + } + SingularValueDecomposition(U, svd.s, svd.V) + } + + /** + * Multiply this matrix by a local matrix on the right. + * + * @param B a local matrix whose number of rows must match the number of columns of this matrix + * @return an IndexedRowMatrix representing the product, which preserves partitioning + */ + def multiply(B: Matrix): IndexedRowMatrix = { + val mat = toRowMatrix().multiply(B) + val indexedRows = rows.map(_.index).zip(mat.rows).map { case (i, v) => + IndexedRow(i, v) + } + new IndexedRowMatrix(indexedRows, nRows, nCols) + } + + /** + * Computes the Gramian matrix `A^T A`. + */ + def computeGramianMatrix(): Matrix = { + toRowMatrix().computeGramianMatrix() + } + + private[mllib] override def toBreeze(): BDM[Double] = { + val m = numRows().toInt + val n = numCols().toInt + val mat = BDM.zeros[Double](m, n) + rows.collect().foreach { case IndexedRow(rowIndex, vector) => + val i = rowIndex.toInt + vector.toBreeze.activeIterator.foreach { case (j, v) => + mat(i, j) = v + } + } + mat + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala new file mode 100644 index 0000000000000..0c0afcd9ec0d7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -0,0 +1,496 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg.distributed + +import java.util + +import breeze.linalg.{Vector => BV, DenseMatrix => BDM, DenseVector => BDV, svd => brzSvd} +import breeze.numerics.{sqrt => brzSqrt} +import com.github.fommil.netlib.BLAS.{getInstance => blas} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg._ +import org.apache.spark.rdd.RDD +import org.apache.spark.Logging +import org.apache.spark.mllib.stat.MultivariateStatisticalSummary + +/** + * Column statistics aggregator implementing + * [[org.apache.spark.mllib.stat.MultivariateStatisticalSummary]] + * together with add() and merge() function. + * A numerically stable algorithm is implemented to compute sample mean and variance: + *[[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]]. + * Zero elements (including explicit zero values) are skipped when calling add() and merge(), + * to have time complexity O(nnz) instead of O(n) for each column. + */ +private class ColumnStatisticsAggregator(private val n: Int) + extends MultivariateStatisticalSummary with Serializable { + + private val currMean: BDV[Double] = BDV.zeros[Double](n) + private val currM2n: BDV[Double] = BDV.zeros[Double](n) + private var totalCnt = 0.0 + private val nnz: BDV[Double] = BDV.zeros[Double](n) + private val currMax: BDV[Double] = BDV.fill(n)(Double.MinValue) + private val currMin: BDV[Double] = BDV.fill(n)(Double.MaxValue) + + override def mean: Vector = { + val realMean = BDV.zeros[Double](n) + var i = 0 + while (i < n) { + realMean(i) = currMean(i) * nnz(i) / totalCnt + i += 1 + } + Vectors.fromBreeze(realMean) + } + + override def variance: Vector = { + val realVariance = BDV.zeros[Double](n) + + val denominator = totalCnt - 1.0 + + // Sample variance is computed, if the denominator is less than 0, the variance is just 0. + if (denominator > 0.0) { + val deltaMean = currMean + var i = 0 + while (i < currM2n.size) { + realVariance(i) = + currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt + realVariance(i) /= denominator + i += 1 + } + } + + Vectors.fromBreeze(realVariance) + } + + override def count: Long = totalCnt.toLong + + override def numNonzeros: Vector = Vectors.fromBreeze(nnz) + + override def max: Vector = { + var i = 0 + while (i < n) { + if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0 + i += 1 + } + Vectors.fromBreeze(currMax) + } + + override def min: Vector = { + var i = 0 + while (i < n) { + if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0 + i += 1 + } + Vectors.fromBreeze(currMin) + } + + /** + * Aggregates a row. + */ + def add(currData: BV[Double]): this.type = { + currData.activeIterator.foreach { + case (_, 0.0) => // Skip explicit zero elements. + case (i, value) => + if (currMax(i) < value) { + currMax(i) = value + } + if (currMin(i) > value) { + currMin(i) = value + } + + val tmpPrevMean = currMean(i) + currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0) + currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean) + + nnz(i) += 1.0 + } + + totalCnt += 1.0 + this + } + + /** + * Merges another aggregator. + */ + def merge(other: ColumnStatisticsAggregator): this.type = { + require(n == other.n, s"Dimensions mismatch. Expecting $n but got ${other.n}.") + + totalCnt += other.totalCnt + val deltaMean = currMean - other.currMean + + var i = 0 + while (i < n) { + // merge mean together + if (other.currMean(i) != 0.0) { + currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) / + (nnz(i) + other.nnz(i)) + } + // merge m2n together + if (nnz(i) + other.nnz(i) != 0.0) { + currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) / + (nnz(i) + other.nnz(i)) + } + if (currMax(i) < other.currMax(i)) { + currMax(i) = other.currMax(i) + } + if (currMin(i) > other.currMin(i)) { + currMin(i) = other.currMin(i) + } + i += 1 + } + + nnz += other.nnz + this + } +} + +/** + * :: Experimental :: + * Represents a row-oriented distributed Matrix with no meaningful row indices. + * + * @param rows rows stored as an RDD[Vector] + * @param nRows number of rows. A non-positive value means unknown, and then the number of rows will + * be determined by the number of records in the RDD `rows`. + * @param nCols number of columns. A non-positive value means unknown, and then the number of + * columns will be determined by the size of the first row. + */ +@Experimental +class RowMatrix( + val rows: RDD[Vector], + private var nRows: Long, + private var nCols: Int) extends DistributedMatrix with Logging { + + /** Alternative constructor leaving matrix dimensions to be determined automatically. */ + def this(rows: RDD[Vector]) = this(rows, 0L, 0) + + /** Gets or computes the number of columns. */ + override def numCols(): Long = { + if (nCols <= 0) { + // Calling `first` will throw an exception if `rows` is empty. + nCols = rows.first().size + } + nCols + } + + /** Gets or computes the number of rows. */ + override def numRows(): Long = { + if (nRows <= 0L) { + nRows = rows.count() + if (nRows == 0L) { + sys.error("Cannot determine the number of rows because it is not specified in the " + + "constructor and the rows RDD is empty.") + } + } + nRows + } + + /** + * Computes the Gramian matrix `A^T A`. + */ + def computeGramianMatrix(): Matrix = { + val n = numCols().toInt + val nt: Int = n * (n + 1) / 2 + + // Compute the upper triangular part of the gram matrix. + val GU = rows.aggregate(new BDV[Double](new Array[Double](nt)))( + seqOp = (U, v) => { + RowMatrix.dspr(1.0, v, U.data) + U + }, + combOp = (U1, U2) => U1 += U2 + ) + + RowMatrix.triuToFull(n, GU.data) + } + + /** + * Computes the singular value decomposition of this matrix. + * Denote this matrix by A (m x n), this will compute matrices U, S, V such that A = U * S * V'. + * + * There is no restriction on m, but we require `n^2` doubles to fit in memory. + * Further, n should be less than m. + + * The decomposition is computed by first computing A'A = V S^2 V', + * computing svd locally on that (since n x n is small), from which we recover S and V. + * Then we compute U via easy matrix multiplication as U = A * (V * S^-1). + * Note that this approach requires `O(n^3)` time on the master node. + * + * At most k largest non-zero singular values and associated vectors are returned. + * If there are k such values, then the dimensions of the return will be: + * + * U is a RowMatrix of size m x k that satisfies U'U = eye(k), + * s is a Vector of size k, holding the singular values in descending order, + * and V is a Matrix of size n x k that satisfies V'V = eye(k). + * + * @param k number of singular values to keep. We might return less than k if there are + * numerically zero singular values. See rCond. + * @param computeU whether to compute U + * @param rCond the reciprocal condition number. All singular values smaller than rCond * sigma(0) + * are treated as zero, where sigma(0) is the largest singular value. + * @return SingularValueDecomposition(U, s, V) + */ + def computeSVD( + k: Int, + computeU: Boolean = false, + rCond: Double = 1e-9): SingularValueDecomposition[RowMatrix, Matrix] = { + val n = numCols().toInt + require(k > 0 && k <= n, s"Request up to n singular values k=$k n=$n.") + + val G = computeGramianMatrix() + + // TODO: Use sparse SVD instead. + val (u: BDM[Double], sigmaSquares: BDV[Double], v: BDM[Double]) = + brzSvd(G.toBreeze.asInstanceOf[BDM[Double]]) + val sigmas: BDV[Double] = brzSqrt(sigmaSquares) + + // Determine effective rank. + val sigma0 = sigmas(0) + val threshold = rCond * sigma0 + var i = 0 + while (i < k && sigmas(i) >= threshold) { + i += 1 + } + val sk = i + + if (sk < k) { + logWarning(s"Requested $k singular values but only found $sk nonzeros.") + } + + val s = Vectors.dense(util.Arrays.copyOfRange(sigmas.data, 0, sk)) + val V = Matrices.dense(n, sk, util.Arrays.copyOfRange(u.data, 0, n * sk)) + + if (computeU) { + // N = Vk * Sk^{-1} + val N = new BDM[Double](n, sk, util.Arrays.copyOfRange(u.data, 0, n * sk)) + var i = 0 + var j = 0 + while (j < sk) { + i = 0 + val sigma = sigmas(j) + while (i < n) { + N(i, j) /= sigma + i += 1 + } + j += 1 + } + val U = this.multiply(Matrices.fromBreeze(N)) + SingularValueDecomposition(U, s, V) + } else { + SingularValueDecomposition(null, s, V) + } + } + + /** + * Computes the covariance matrix, treating each row as an observation. + * @return a local dense matrix of size n x n + */ + def computeCovariance(): Matrix = { + val n = numCols().toInt + + if (n > 10000) { + val mem = n * n * java.lang.Double.SIZE / java.lang.Byte.SIZE + logWarning(s"The number of columns $n is greater than 10000! " + + s"We need at least $mem bytes of memory.") + } + + val (m, mean) = rows.aggregate[(Long, BDV[Double])]((0L, BDV.zeros[Double](n)))( + seqOp = (s: (Long, BDV[Double]), v: Vector) => (s._1 + 1L, s._2 += v.toBreeze), + combOp = (s1: (Long, BDV[Double]), s2: (Long, BDV[Double])) => (s1._1 + s2._1, s1._2 += s2._2) + ) + + updateNumRows(m) + + mean :/= m.toDouble + + // We use the formula Cov(X, Y) = E[X * Y] - E[X] E[Y], which is not accurate if E[X * Y] is + // large but Cov(X, Y) is small, but it is good for sparse computation. + // TODO: find a fast and stable way for sparse data. + + val G = computeGramianMatrix().toBreeze.asInstanceOf[BDM[Double]] + + var i = 0 + var j = 0 + val m1 = m - 1.0 + var alpha = 0.0 + while (i < n) { + alpha = m / m1 * mean(i) + j = 0 + while (j < n) { + G(i, j) = G(i, j) / m1 - alpha * mean(j) + j += 1 + } + i += 1 + } + + Matrices.fromBreeze(G) + } + + /** + * Computes the top k principal components. + * Rows correspond to observations and columns correspond to variables. + * The principal components are stored a local matrix of size n-by-k. + * Each column corresponds for one principal component, + * and the columns are in descending order of component variance. + * + * @param k number of top principal components. + * @return a matrix of size n-by-k, whose columns are principal components + */ + def computePrincipalComponents(k: Int): Matrix = { + val n = numCols().toInt + require(k > 0 && k <= n, s"k = $k out of range (0, n = $n]") + + val Cov = computeCovariance().toBreeze.asInstanceOf[BDM[Double]] + + val (u: BDM[Double], _, _) = brzSvd(Cov) + + if (k == n) { + Matrices.dense(n, k, u.data) + } else { + Matrices.dense(n, k, util.Arrays.copyOfRange(u.data, 0, n * k)) + } + } + + /** + * Computes column-wise summary statistics. + */ + def computeColumnSummaryStatistics(): MultivariateStatisticalSummary = { + val zeroValue = new ColumnStatisticsAggregator(numCols().toInt) + val summary = rows.map(_.toBreeze).aggregate[ColumnStatisticsAggregator](zeroValue)( + (aggregator, data) => aggregator.add(data), + (aggregator1, aggregator2) => aggregator1.merge(aggregator2) + ) + updateNumRows(summary.count) + summary + } + + /** + * Multiply this matrix by a local matrix on the right. + * + * @param B a local matrix whose number of rows must match the number of columns of this matrix + * @return a [[org.apache.spark.mllib.linalg.distributed.RowMatrix]] representing the product, + * which preserves partitioning + */ + def multiply(B: Matrix): RowMatrix = { + val n = numCols().toInt + require(n == B.numRows, s"Dimension mismatch: $n vs ${B.numRows}") + + require(B.isInstanceOf[DenseMatrix], + s"Only support dense matrix at this time but found ${B.getClass.getName}.") + + val Bb = rows.context.broadcast(B) + val AB = rows.mapPartitions({ iter => + val Bi = Bb.value.toBreeze.asInstanceOf[BDM[Double]] + iter.map(v => Vectors.fromBreeze(Bi.t * v.toBreeze)) + }, preservesPartitioning = true) + + new RowMatrix(AB, nRows, B.numCols) + } + + private[mllib] override def toBreeze(): BDM[Double] = { + val m = numRows().toInt + val n = numCols().toInt + val mat = BDM.zeros[Double](m, n) + var i = 0 + rows.collect().foreach { v => + v.toBreeze.activeIterator.foreach { case (j, v) => + mat(i, j) = v + } + i += 1 + } + mat + } + + /** Updates or verfires the number of rows. */ + private def updateNumRows(m: Long) { + if (nRows <= 0) { + nRows == m + } else { + require(nRows == m, + s"The number of rows $m is different from what specified or previously computed: ${nRows}.") + } + } +} + +object RowMatrix { + + /** + * Adds alpha * x * x.t to a matrix in-place. This is the same as BLAS's DSPR. + * + * @param U the upper triangular part of the matrix packed in an array (column major) + */ + private def dspr(alpha: Double, v: Vector, U: Array[Double]): Unit = { + // TODO: Find a better home (breeze?) for this method. + val n = v.size + v match { + case dv: DenseVector => + blas.dspr("U", n, 1.0, dv.values, 1, U) + case sv: SparseVector => + val indices = sv.indices + val values = sv.values + val nnz = indices.length + var colStartIdx = 0 + var prevCol = 0 + var col = 0 + var j = 0 + var i = 0 + var av = 0.0 + while (j < nnz) { + col = indices(j) + // Skip empty columns. + colStartIdx += (col - prevCol) * (col + prevCol + 1) / 2 + col = indices(j) + av = alpha * values(j) + i = 0 + while (i <= j) { + U(colStartIdx + indices(i)) += av * values(i) + i += 1 + } + j += 1 + prevCol = col + } + } + } + + /** + * Fills a full square matrix from its upper triangular part. + */ + private def triuToFull(n: Int, U: Array[Double]): Matrix = { + val G = new BDM[Double](n, n) + + var row = 0 + var col = 0 + var idx = 0 + var value = 0.0 + while (col < n) { + row = 0 + while (row < col) { + value = U(idx) + G(row, col) = value + G(col, row) = value + idx += 1 + row += 1 + } + G(col, col) = U(idx) + idx += 1 + col +=1 + } + + Matrices.dense(n, n, G.data) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala index 82124703da6cd..679842f831c2a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala @@ -17,39 +17,55 @@ package org.apache.spark.mllib.optimization -import org.jblas.DoubleMatrix +import breeze.linalg.{axpy => brzAxpy} + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.{Vectors, Vector} /** + * :: DeveloperApi :: * Class used to compute the gradient for a loss function, given a single data point. */ +@DeveloperApi abstract class Gradient extends Serializable { /** * Compute the gradient and loss given the features of a single data point. * - * @param data - Feature values for one data point. Column matrix of size dx1 - * where d is the number of features. - * @param label - Label for this data item. - * @param weights - Column matrix containing weights for every feature. + * @param data features for one data point + * @param label label for this data point + * @param weights weights/coefficients corresponding to features * - * @return A tuple of 2 elements. The first element is a column matrix containing the computed - * gradient and the second element is the loss computed at this data point. + * @return (gradient: Vector, loss: Double) + */ + def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) + + /** + * Compute the gradient and loss given the features of a single data point, + * add the gradient to a provided vector to avoid creating new objects, and return loss. * + * @param data features for one data point + * @param label label for this data point + * @param weights weights/coefficients corresponding to features + * @param cumGradient the computed gradient will be added to this vector + * + * @return loss */ - def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): - (DoubleMatrix, Double) + def compute(data: Vector, label: Double, weights: Vector, cumGradient: Vector): Double } /** + * :: DeveloperApi :: * Compute gradient and loss for a logistic loss function, as used in binary classification. * See also the documentation for the precise formulation. */ +@DeveloperApi class LogisticGradient extends Gradient { - override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): - (DoubleMatrix, Double) = { - val margin: Double = -1.0 * data.dot(weights) + override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { + val brzData = data.toBreeze + val brzWeights = weights.toBreeze + val margin: Double = -1.0 * brzWeights.dot(brzData) val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label - - val gradient = data.mul(gradientMultiplier) + val gradient = brzData * gradientMultiplier val loss = if (label > 0) { math.log(1 + math.exp(margin)) @@ -57,47 +73,105 @@ class LogisticGradient extends Gradient { math.log(1 + math.exp(margin)) - margin } - (gradient, loss) + (Vectors.fromBreeze(gradient), loss) + } + + override def compute( + data: Vector, + label: Double, + weights: Vector, + cumGradient: Vector): Double = { + val brzData = data.toBreeze + val brzWeights = weights.toBreeze + val margin: Double = -1.0 * brzWeights.dot(brzData) + val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label + + brzAxpy(gradientMultiplier, brzData, cumGradient.toBreeze) + + if (label > 0) { + math.log(1 + math.exp(margin)) + } else { + math.log(1 + math.exp(margin)) - margin + } } } /** + * :: DeveloperApi :: * Compute gradient and loss for a Least-squared loss function, as used in linear regression. * This is correct for the averaged least squares loss function (mean squared error) * L = 1/n ||A weights-y||^2 * See also the documentation for the precise formulation. */ +@DeveloperApi class LeastSquaresGradient extends Gradient { - override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): - (DoubleMatrix, Double) = { - val diff: Double = data.dot(weights) - label - + override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { + val brzData = data.toBreeze + val brzWeights = weights.toBreeze + val diff = brzWeights.dot(brzData) - label val loss = diff * diff - val gradient = data.mul(2.0 * diff) + val gradient = brzData * (2.0 * diff) - (gradient, loss) + (Vectors.fromBreeze(gradient), loss) + } + + override def compute( + data: Vector, + label: Double, + weights: Vector, + cumGradient: Vector): Double = { + val brzData = data.toBreeze + val brzWeights = weights.toBreeze + val diff = brzWeights.dot(brzData) - label + + brzAxpy(2.0 * diff, brzData, cumGradient.toBreeze) + + diff * diff } } /** + * :: DeveloperApi :: * Compute gradient and loss for a Hinge loss function, as used in SVM binary classification. * See also the documentation for the precise formulation. * NOTE: This assumes that the labels are {0,1} */ +@DeveloperApi class HingeGradient extends Gradient { - override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): - (DoubleMatrix, Double) = { + override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = { + val brzData = data.toBreeze + val brzWeights = weights.toBreeze + val dotProduct = brzWeights.dot(brzData) + + // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x))) + // Therefore the gradient is -(2y - 1)*x + val labelScaled = 2 * label - 1.0 + + if (1.0 > labelScaled * dotProduct) { + (Vectors.fromBreeze(brzData * (-labelScaled)), 1.0 - labelScaled * dotProduct) + } else { + (Vectors.dense(new Array[Double](weights.size)), 0.0) + } + } - val dotProduct = data.dot(weights) + override def compute( + data: Vector, + label: Double, + weights: Vector, + cumGradient: Vector): Double = { + val brzData = data.toBreeze + val brzWeights = weights.toBreeze + val dotProduct = brzWeights.dot(brzData) // Our loss function with {0, 1} labels is max(0, 1 - (2y – 1) (f_w(x))) // Therefore the gradient is -(2y - 1)*x val labelScaled = 2 * label - 1.0 if (1.0 > labelScaled * dotProduct) { - (data.mul(-labelScaled), 1.0 - labelScaled * dotProduct) + brzAxpy(-labelScaled, brzData, cumGradient.toBreeze) + 1.0 - labelScaled * dotProduct } else { - (DoubleMatrix.zeros(1, weights.length), 0.0) + 0.0 } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index b967b22e818d3..f60417f21d4b9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -17,19 +17,23 @@ package org.apache.spark.mllib.optimization -import org.apache.spark.Logging -import org.apache.spark.rdd.RDD +import scala.collection.mutable.ArrayBuffer -import org.jblas.DoubleMatrix +import breeze.linalg.{DenseVector => BDV} -import scala.collection.mutable.ArrayBuffer +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.{Vectors, Vector} /** + * :: DeveloperApi :: * Class used to solve an optimization problem using Gradient Descent. * @param gradient Gradient function to be used. * @param updater Updater to be used to update weights after every iteration. */ -class GradientDescent(var gradient: Gradient, var updater: Updater) +@DeveloperApi +class GradientDescent(private var gradient: Gradient, private var updater: Updater) extends Optimizer with Logging { private var stepSize: Double = 1.0 @@ -91,24 +95,26 @@ class GradientDescent(var gradient: Gradient, var updater: Updater) this } - def optimize(data: RDD[(Double, Array[Double])], initialWeights: Array[Double]) - : Array[Double] = { - - val (weights, stochasticLossHistory) = GradientDescent.runMiniBatchSGD( - data, - gradient, - updater, - stepSize, - numIterations, - regParam, - miniBatchFraction, - initialWeights) + def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = { + val (weights, _) = GradientDescent.runMiniBatchSGD( + data, + gradient, + updater, + stepSize, + numIterations, + regParam, + miniBatchFraction, + initialWeights) weights } } -// Top-level method to run gradient descent. +/** + * :: DeveloperApi :: + * Top-level method to run gradient descent. + */ +@DeveloperApi object GradientDescent extends Logging { /** * Run stochastic gradient descent (SGD) in parallel using mini batches. @@ -133,14 +139,14 @@ object GradientDescent extends Logging { * stochastic loss computed for every iteration. */ def runMiniBatchSGD( - data: RDD[(Double, Array[Double])], + data: RDD[(Double, Vector)], gradient: Gradient, updater: Updater, stepSize: Double, numIterations: Int, regParam: Double, miniBatchFraction: Double, - initialWeights: Array[Double]) : (Array[Double], Array[Double]) = { + initialWeights: Vector): (Vector, Array[Double]) = { val stochasticLossHistory = new ArrayBuffer[Double](numIterations) @@ -148,24 +154,27 @@ object GradientDescent extends Logging { val miniBatchSize = nexamples * miniBatchFraction // Initialize weights as a column vector - var weights = new DoubleMatrix(initialWeights.length, 1, initialWeights:_*) + var weights = Vectors.dense(initialWeights.toArray) /** * For the first iteration, the regVal will be initialized as sum of sqrt of * weights if it's L2 update; for L1 update; the same logic is followed. */ var regVal = updater.compute( - weights, new DoubleMatrix(initialWeights.length, 1), 0, 1, regParam)._2 + weights, Vectors.dense(new Array[Double](weights.size)), 0, 1, regParam)._2 for (i <- 1 to numIterations) { // Sample a subset (fraction miniBatchFraction) of the total data // compute and sum up the subgradients on this subset (this is one map-reduce) - val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i).map { - case (y, features) => - val featuresCol = new DoubleMatrix(features.length, 1, features:_*) - val (grad, loss) = gradient.compute(featuresCol, y, weights) - (grad, loss) - }.reduce((a, b) => (a._1.addi(b._1), a._2 + b._2)) + val (gradientSum, lossSum) = data.sample(false, miniBatchFraction, 42 + i) + .aggregate((BDV.zeros[Double](weights.size), 0.0))( + seqOp = (c, v) => (c, v) match { case ((grad, loss), (label, features)) => + val l = gradient.compute(features, label, weights, Vectors.fromBreeze(grad)) + (grad, loss + l) + }, + combOp = (c1, c2) => (c1, c2) match { case ((grad1, loss1), (grad2, loss2)) => + (grad1 += grad2, loss1 + loss2) + }) /** * NOTE(Xinghao): lossSum is computed using the weights from the previous iteration @@ -173,7 +182,7 @@ object GradientDescent extends Logging { */ stochasticLossHistory.append(lossSum / miniBatchSize + regVal) val update = updater.compute( - weights, gradientSum.div(miniBatchSize), stepSize, i, regParam) + weights, Vectors.fromBreeze(gradientSum / miniBatchSize), stepSize, i, regParam) weights = update._1 regVal = update._2 } @@ -181,6 +190,6 @@ object GradientDescent extends Logging { logInfo("GradientDescent.runMiniBatchSGD finished. Last 10 stochastic losses %s".format( stochasticLossHistory.takeRight(10).mkString(", "))) - (weights.toArray, stochasticLossHistory.toArray) + (weights, stochasticLossHistory.toArray) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala index 94d30b56f212b..7f6d94571b5ef 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala @@ -19,11 +19,18 @@ package org.apache.spark.mllib.optimization import org.apache.spark.rdd.RDD -trait Optimizer { +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.Vector + +/** + * :: DeveloperApi :: + * Trait for optimization problem solvers. + */ +@DeveloperApi +trait Optimizer extends Serializable { /** - * Solve the provided convex optimization problem. + * Solve the provided convex optimization problem. */ - def optimize(data: RDD[(Double, Array[Double])], initialWeights: Array[Double]): Array[Double] - + def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala index bf8f731459e99..3ed3a5b9b3843 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala @@ -18,9 +18,14 @@ package org.apache.spark.mllib.optimization import scala.math._ -import org.jblas.DoubleMatrix + +import breeze.linalg.{norm => brzNorm, axpy => brzAxpy, Vector => BV} + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.linalg.{Vectors, Vector} /** + * :: DeveloperApi :: * Class used to perform steps (weight update) using Gradient Descent methods. * * For general minimization problems, or for regularized problems of the form @@ -32,6 +37,7 @@ import org.jblas.DoubleMatrix * The updater is responsible to also perform the update coming from the * regularization term R(w) (if any regularization is used). */ +@DeveloperApi abstract class Updater extends Serializable { /** * Compute an updated value for weights given the gradient, stepSize, iteration number and @@ -47,24 +53,37 @@ abstract class Updater extends Serializable { * @return A tuple of 2 elements. The first element is a column matrix containing updated weights, * and the second element is the regularization value computed using updated weights. */ - def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, stepSize: Double, iter: Int, - regParam: Double): (DoubleMatrix, Double) + def compute( + weightsOld: Vector, + gradient: Vector, + stepSize: Double, + iter: Int, + regParam: Double): (Vector, Double) } /** + * :: DeveloperApi :: * A simple updater for gradient descent *without* any regularization. * Uses a step-size decreasing with the square root of the number of iterations. */ +@DeveloperApi class SimpleUpdater extends Updater { - override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, - stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = { + override def compute( + weightsOld: Vector, + gradient: Vector, + stepSize: Double, + iter: Int, + regParam: Double): (Vector, Double) = { val thisIterStepSize = stepSize / math.sqrt(iter) - val step = gradient.mul(thisIterStepSize) - (weightsOld.sub(step), 0) + val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector + brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights) + + (Vectors.fromBreeze(brzWeights), 0) } } /** + * :: DeveloperApi :: * Updater for L1 regularized problems. * R(w) = ||w||_1 * Uses a step-size decreasing with the square root of the number of iterations. @@ -82,39 +101,56 @@ class SimpleUpdater extends Updater { * * Equivalently, set weight component to signum(w) * max(0.0, abs(w) - shrinkageVal) */ +@DeveloperApi class L1Updater extends Updater { - override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, - stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = { + override def compute( + weightsOld: Vector, + gradient: Vector, + stepSize: Double, + iter: Int, + regParam: Double): (Vector, Double) = { val thisIterStepSize = stepSize / math.sqrt(iter) - val step = gradient.mul(thisIterStepSize) // Take gradient step - val newWeights = weightsOld.sub(step) + val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector + brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights) // Apply proximal operator (soft thresholding) val shrinkageVal = regParam * thisIterStepSize - (0 until newWeights.length).foreach { i => - val wi = newWeights.get(i) - newWeights.put(i, signum(wi) * max(0.0, abs(wi) - shrinkageVal)) + var i = 0 + while (i < brzWeights.length) { + val wi = brzWeights(i) + brzWeights(i) = signum(wi) * max(0.0, abs(wi) - shrinkageVal) + i += 1 } - (newWeights, newWeights.norm1 * regParam) + + (Vectors.fromBreeze(brzWeights), brzNorm(brzWeights, 1.0) * regParam) } } /** + * :: DeveloperApi :: * Updater for L2 regularized problems. * R(w) = 1/2 ||w||^2 * Uses a step-size decreasing with the square root of the number of iterations. */ +@DeveloperApi class SquaredL2Updater extends Updater { - override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, - stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = { - val thisIterStepSize = stepSize / math.sqrt(iter) - val step = gradient.mul(thisIterStepSize) + override def compute( + weightsOld: Vector, + gradient: Vector, + stepSize: Double, + iter: Int, + regParam: Double): (Vector, Double) = { // add up both updates from the gradient of the loss (= step) as well as // the gradient of the regularizer (= regParam * weightsOld) // w' = w - thisIterStepSize * (gradient + regParam * w) // w' = (1 - thisIterStepSize * regParam) * w - thisIterStepSize * gradient - val newWeights = weightsOld.mul(1.0 - thisIterStepSize * regParam).sub(step) - (newWeights, 0.5 * pow(newWeights.norm2, 2.0) * regParam) + val thisIterStepSize = stepSize / math.sqrt(iter) + val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector + brzWeights :*= (1.0 - thisIterStepSize * regParam) + brzAxpy(-thisIterStepSize, gradient.toBreeze, brzWeights) + val norm = brzNorm(brzWeights, 2.0) + + (Vectors.fromBreeze(brzWeights), 0.5 * regParam * norm * norm) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala new file mode 100644 index 0000000000000..873de871fd884 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.rdd + +import scala.reflect.ClassTag + +import org.apache.spark.rdd.RDD + +/** + * Machine learning specific RDD functions. + */ +private[mllib] +class RDDFunctions[T: ClassTag](self: RDD[T]) { + + /** + * Returns a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding + * window over them. The ordering is first based on the partition index and then the ordering of + * items within each partition. This is similar to sliding in Scala collections, except that it + * becomes an empty RDD if the window size is greater than the total number of items. It needs to + * trigger a Spark job if the parent RDD has more than one partitions and the window size is + * greater than 1. + */ + def sliding(windowSize: Int): RDD[Seq[T]] = { + require(windowSize > 0, s"Sliding window size must be positive, but got $windowSize.") + if (windowSize == 1) { + self.map(Seq(_)) + } else { + new SlidingRDD[T](self, windowSize) + } + } +} + +private[mllib] +object RDDFunctions { + + /** Implicit conversion from an RDD to RDDFunctions. */ + implicit def fromRDD[T: ClassTag](rdd: RDD[T]) = new RDDFunctions[T](rdd) +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala new file mode 100644 index 0000000000000..dd80782c0f001 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.rdd + +import scala.collection.mutable +import scala.reflect.ClassTag + +import org.apache.spark.{TaskContext, Partition} +import org.apache.spark.rdd.RDD + +private[mllib] +class SlidingRDDPartition[T](val idx: Int, val prev: Partition, val tail: Seq[T]) + extends Partition with Serializable { + override val index: Int = idx +} + +/** + * Represents a RDD from grouping items of its parent RDD in fixed size blocks by passing a sliding + * window over them. The ordering is first based on the partition index and then the ordering of + * items within each partition. This is similar to sliding in Scala collections, except that it + * becomes an empty RDD if the window size is greater than the total number of items. It needs to + * trigger a Spark job if the parent RDD has more than one partitions. To make this operation + * efficient, the number of items per partition should be larger than the window size and the + * window size should be small, e.g., 2. + * + * @param parent the parent RDD + * @param windowSize the window size, must be greater than 1 + * + * @see [[org.apache.spark.mllib.rdd.RDDFunctions#sliding]] + */ +private[mllib] +class SlidingRDD[T: ClassTag](@transient val parent: RDD[T], val windowSize: Int) + extends RDD[Seq[T]](parent) { + + require(windowSize > 1, s"Window size must be greater than 1, but got $windowSize.") + + override def compute(split: Partition, context: TaskContext): Iterator[Seq[T]] = { + val part = split.asInstanceOf[SlidingRDDPartition[T]] + (firstParent[T].iterator(part.prev, context) ++ part.tail) + .sliding(windowSize) + .withPartial(false) + } + + override def getPreferredLocations(split: Partition): Seq[String] = + firstParent[T].preferredLocations(split.asInstanceOf[SlidingRDDPartition[T]].prev) + + override def getPartitions: Array[Partition] = { + val parentPartitions = parent.partitions + val n = parentPartitions.size + if (n == 0) { + Array.empty + } else if (n == 1) { + Array(new SlidingRDDPartition[T](0, parentPartitions(0), Seq.empty)) + } else { + val n1 = n - 1 + val w1 = windowSize - 1 + // Get the first w1 items of each partition, starting from the second partition. + val nextHeads = + parent.context.runJob(parent, (iter: Iterator[T]) => iter.take(w1).toArray, 1 until n, true) + val partitions = mutable.ArrayBuffer[SlidingRDDPartition[T]]() + var i = 0 + var partitionIndex = 0 + while (i < n1) { + var j = i + val tail = mutable.ListBuffer[T]() + // Keep appending to the current tail until appended a head of size w1. + while (j < n1 && nextHeads(j).size < w1) { + tail ++= nextHeads(j) + j += 1 + } + if (j < n1) { + tail ++= nextHeads(j) + j += 1 + } + partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(i), tail) + partitionIndex += 1 + // Skip appended heads. + i = j + } + // If the head of last partition has size w1, we also need to add this partition. + if (nextHeads.last.size == w1) { + partitions += new SlidingRDDPartition[T](partitionIndex, parentPartitions(n1), Seq.empty) + } + partitions.toArray + } + } + + // TODO: Override methods such as aggregate, which only requires one Spark job. +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 0cc9f48769f83..5cc47de8ffdfc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -22,6 +22,10 @@ import scala.math.{abs, sqrt} import scala.util.Random import scala.util.Sorting +import com.esotericsoftware.kryo.Kryo +import org.jblas.{DoubleMatrix, SimpleBlas, Solve} + +import org.apache.spark.annotation.Experimental import org.apache.spark.broadcast.Broadcast import org.apache.spark.{Logging, HashPartitioner, Partitioner, SparkContext, SparkConf} import org.apache.spark.storage.StorageLevel @@ -29,10 +33,6 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.KryoRegistrator import org.apache.spark.SparkContext._ -import com.esotericsoftware.kryo.Kryo -import org.jblas.{DoubleMatrix, SimpleBlas, Solve} - - /** * Out-link information for a user or product block. This includes the original user/product IDs * of the elements within this block, and the list of destination blocks that each user or @@ -90,14 +90,19 @@ case class Rating(val user: Int, val product: Int, val rating: Double) * preferences rather than explicit ratings given to items. */ class ALS private ( - var numBlocks: Int, - var rank: Int, - var iterations: Int, - var lambda: Double, - var implicitPrefs: Boolean, - var alpha: Double, - var seed: Long = System.nanoTime() + private var numBlocks: Int, + private var rank: Int, + private var iterations: Int, + private var lambda: Double, + private var implicitPrefs: Boolean, + private var alpha: Double, + private var seed: Long = System.nanoTime() ) extends Serializable with Logging { + + /** + * Constructs an ALS instance with default parameters: {numBlocks: -1, rank: 10, iterations: 10, + * lambda: 0.01, implicitPrefs: false, alpha: 1.0}. + */ def this() = this(-1, 10, 10, 0.01, false, 1.0) /** @@ -127,11 +132,17 @@ class ALS private ( this } + /** Sets whether to use implicit preference. Default: false. */ def setImplicitPrefs(implicitPrefs: Boolean): ALS = { this.implicitPrefs = implicitPrefs this } + /** + * :: Experimental :: + * Sets the constant used in computing confidence in implicit ALS. Default: 1.0. + */ + @Experimental def setAlpha(alpha: Double): ALS = { this.alpha = alpha this @@ -421,12 +432,12 @@ class ALS private ( * Compute the new feature vectors for a block of the users matrix given the list of factors * it received from each product and its InLinkBlock. */ - private def updateBlock(messages: Seq[(Int, Array[Array[Double]])], inLinkBlock: InLinkBlock, + private def updateBlock(messages: Iterable[(Int, Array[Array[Double]])], inLinkBlock: InLinkBlock, rank: Int, lambda: Double, alpha: Double, YtY: Option[Broadcast[DoubleMatrix]]) : Array[Array[Double]] = { // Sort the incoming block factor messages by block ID and make them an array - val blockFactors = messages.sortBy(_._1).map(_._2).toArray // Array[Array[Double]] + val blockFactors = messages.toSeq.sortBy(_._1).map(_._2).toArray // Array[Array[Double]] val numBlocks = blockFactors.length val numUsers = inLinkBlock.elementIds.length diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala index 443fc5de5bf04..471546cd82c7d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala @@ -17,13 +17,14 @@ package org.apache.spark.mllib.recommendation +import org.jblas._ + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ import org.apache.spark.mllib.api.python.PythonMLLibAPI -import org.jblas._ -import org.apache.spark.api.java.JavaRDD - /** * Model representing the result of matrix factorization. @@ -68,6 +69,7 @@ class MatrixFactorizationModel( } /** + * :: DeveloperApi :: * Predict the rating of many users for many products. * This is a Java stub for python predictAll() * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 3e1ed91bf6729..d969e7aa60061 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -17,35 +17,33 @@ package org.apache.spark.mllib.regression +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} + +import org.apache.spark.annotation.Experimental import org.apache.spark.{Logging, SparkException} import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ - -import org.jblas.DoubleMatrix +import org.apache.spark.mllib.linalg.{Vectors, Vector} /** - * GeneralizedLinearModel (GLM) represents a model trained using + * GeneralizedLinearModel (GLM) represents a model trained using * GeneralizedLinearAlgorithm. GLMs consist of a weight vector and * an intercept. * * @param weights Weights computed for every feature. * @param intercept Intercept computed for this model. */ -abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept: Double) +abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double) extends Serializable { - // Create a column vector that can be used for predictions - private val weightsMatrix = new DoubleMatrix(weights.length, 1, weights:_*) - /** * Predict the result given a data point and the weights learned. - * + * * @param dataMatrix Row vector containing the features for this data point * @param weightMatrix Column vector containing the weights of the model * @param intercept Intercept of the model. */ - def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double): Double + protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double /** * Predict values for the given data set using the model trained. @@ -53,16 +51,13 @@ abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept: * @param testData RDD representing data points to be predicted * @return RDD[Double] where each entry contains the corresponding prediction */ - def predict(testData: RDD[Array[Double]]): RDD[Double] = { + def predict(testData: RDD[Vector]): RDD[Double] = { // A small optimization to avoid serializing the entire model. Only the weightsMatrix // and intercept is needed. - val localWeights = weightsMatrix + val localWeights = weights val localIntercept = intercept - testData.map { x => - val dataMatrix = new DoubleMatrix(1, x.length, x:_*) - predictPoint(dataMatrix, localWeights, localIntercept) - } + testData.map(v => predictPoint(v, localWeights, localIntercept)) } /** @@ -71,14 +66,13 @@ abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept: * @param testData array representing a single data point * @return Double prediction from the trained model */ - def predict(testData: Array[Double]): Double = { - val dataMat = new DoubleMatrix(1, testData.length, testData:_*) - predictPoint(dataMat, weightsMatrix, intercept) + def predict(testData: Vector): Double = { + predictPoint(testData, weights, intercept) } } /** - * GeneralizedLinearAlgorithm implements methods to train a Genearalized Linear Model (GLM). + * GeneralizedLinearAlgorithm implements methods to train a Generalized Linear Model (GLM). * This class should be extended with an Optimizer to create a new GLM. */ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] @@ -86,8 +80,10 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] protected val validators: Seq[RDD[LabeledPoint] => Boolean] = List() - val optimizer: Optimizer + /** The optimizer to solve the problem. */ + def optimizer: Optimizer + /** Whether to add intercept (default: true). */ protected var addIntercept: Boolean = true protected var validateData: Boolean = true @@ -95,7 +91,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] /** * Create a model given the weights and intercept */ - protected def createModel(weights: Array[Double], intercept: Double): M + protected def createModel(weights: Vector, intercept: Double): M /** * Set if the algorithm should add an intercept. Default true. @@ -106,8 +102,10 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] } /** + * :: Experimental :: * Set if the algorithm should validate data before training. Default true. */ + @Experimental def setValidateData(validateData: Boolean): this.type = { this.validateData = validateData this @@ -117,17 +115,27 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] * Run the algorithm with the configured parameters on an input * RDD of LabeledPoint entries. */ - def run(input: RDD[LabeledPoint]) : M = { - val nfeatures: Int = input.first().features.length - val initialWeights = new Array[Double](nfeatures) + def run(input: RDD[LabeledPoint]): M = { + val numFeatures: Int = input.first().features.size + val initialWeights = Vectors.dense(new Array[Double](numFeatures)) run(input, initialWeights) } + /** Prepends one to the input vector. */ + private def prependOne(vector: Vector): Vector = { + val vector1 = vector.toBreeze match { + case dv: BDV[Double] => BDV.vertcat(BDV.ones[Double](1), dv) + case sv: BSV[Double] => BSV.vertcat(new BSV[Double](Array(0), Array(1.0), 1), sv) + case v: Any => throw new IllegalArgumentException("Do not support vector type " + v.getClass) + } + Vectors.fromBreeze(vector1) + } + /** * Run the algorithm with the configured parameters on an input RDD * of LabeledPoint entries starting from the initial weights provided. */ - def run(input: RDD[LabeledPoint], initialWeights: Array[Double]) : M = { + def run(input: RDD[LabeledPoint], initialWeights: Vector): M = { // Check the data properties before running the optimizer if (validateData && !validators.forall(func => func(input))) { @@ -136,27 +144,26 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] // Prepend an extra variable consisting of all 1.0's for the intercept. val data = if (addIntercept) { - input.map(labeledPoint => (labeledPoint.label, 1.0 +: labeledPoint.features)) + input.map(labeledPoint => (labeledPoint.label, prependOne(labeledPoint.features))) } else { input.map(labeledPoint => (labeledPoint.label, labeledPoint.features)) } val initialWeightsWithIntercept = if (addIntercept) { - 0.0 +: initialWeights + prependOne(initialWeights) } else { initialWeights } val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept) - val (intercept, weights) = if (addIntercept) { - (weightsWithIntercept(0), weightsWithIntercept.tail) - } else { - (0.0, weightsWithIntercept) - } - - logInfo("Final weights " + weights.mkString(",")) - logInfo("Final intercept " + intercept) + val intercept = if (addIntercept) weightsWithIntercept(0) else 0.0 + val weights = + if (addIntercept) { + Vectors.dense(weightsWithIntercept.toArray.slice(1, weightsWithIntercept.size)) + } else { + weightsWithIntercept + } createModel(weights, intercept) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index 1a18292fe3f3b..3deab1ab785b9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -17,14 +17,16 @@ package org.apache.spark.mllib.regression +import org.apache.spark.mllib.linalg.Vector + /** * Class that represents the features and labels of a data point. * * @param label Label for this data point. * @param features List of features for this data point. */ -case class LabeledPoint(label: Double, features: Array[Double]) { +case class LabeledPoint(label: Double, features: Vector) { override def toString: String = { - "LabeledPoint(%s, %s)".format(label, features.mkString("[", ", ", "]")) + "LabeledPoint(%s, %s)".format(label, features) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index be63ce8538fef..5f0812fd2e0eb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -17,12 +17,11 @@ package org.apache.spark.mllib.regression -import org.apache.spark.{Logging, SparkContext} -import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.util.MLUtils - -import org.jblas.DoubleMatrix +import org.apache.spark.rdd.RDD /** * Regression model trained using Lasso. @@ -31,16 +30,16 @@ import org.jblas.DoubleMatrix * @param intercept Intercept computed for this model. */ class LassoModel( - override val weights: Array[Double], + override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - override def predictPoint( - dataMatrix: DoubleMatrix, - weightMatrix: DoubleMatrix, + override protected def predictPoint( + dataMatrix: Vector, + weightMatrix: Vector, intercept: Double): Double = { - dataMatrix.dot(weightMatrix) + intercept + weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } } @@ -53,16 +52,16 @@ class LassoModel( * See also the documentation for the precise formulation. */ class LassoWithSGD private ( - var stepSize: Double, - var numIterations: Int, - var regParam: Double, - var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[LassoModel] - with Serializable { - - val gradient = new LeastSquaresGradient() - val updater = new L1Updater() - @transient val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize) + private var stepSize: Double, + private var numIterations: Int, + private var regParam: Double, + private var miniBatchFraction: Double) + extends GeneralizedLinearAlgorithm[LassoModel] with Serializable { + + private val gradient = new LeastSquaresGradient() + private val updater = new L1Updater() + override val optimizer = new GradientDescent(gradient, updater) + .setStepSize(stepSize) .setNumIterations(numIterations) .setRegParam(regParam) .setMiniBatchFraction(miniBatchFraction) @@ -70,12 +69,9 @@ class LassoWithSGD private ( // We don't want to penalize the intercept, so set this to false. super.setIntercept(false) - var yMean = 0.0 - var xColMean: DoubleMatrix = _ - var xColSd: DoubleMatrix = _ - /** - * Construct a Lasso object with default parameters + * Construct a Lasso object with default parameters: {stepSize: 1.0, numIterations: 100, + * regParam: 1.0, miniBatchFraction: 1.0}. */ def this() = this(1.0, 100, 1.0, 1.0) @@ -85,36 +81,8 @@ class LassoWithSGD private ( this } - override def createModel(weights: Array[Double], intercept: Double) = { - val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*) - val weightsScaled = weightsMat.div(xColSd) - val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0) - - new LassoModel(weightsScaled.data, interceptScaled) - } - - override def run( - input: RDD[LabeledPoint], - initialWeights: Array[Double]) - : LassoModel = - { - val nfeatures: Int = input.first.features.length - val nexamples: Long = input.count() - - // To avoid penalizing the intercept, we center and scale the data. - val stats = MLUtils.computeStats(input, nfeatures, nexamples) - yMean = stats._1 - xColMean = stats._2 - xColSd = stats._3 - - val normalizedData = input.map { point => - val yNormalized = point.label - yMean - val featuresMat = new DoubleMatrix(nfeatures, 1, point.features:_*) - val featuresNormalized = featuresMat.sub(xColMean).divi(xColSd) - LabeledPoint(yNormalized, featuresNormalized.toArray) - } - - super.run(normalizedData, initialWeights) + override protected def createModel(weights: Vector, intercept: Double) = { + new LassoModel(weights, intercept) } } @@ -144,11 +112,9 @@ object LassoWithSGD { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeights: Array[Double]) - : LassoModel = - { - new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input, - initialWeights) + initialWeights: Vector): LassoModel = { + new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction) + .run(input, initialWeights) } /** @@ -168,9 +134,7 @@ object LassoWithSGD { numIterations: Int, stepSize: Double, regParam: Double, - miniBatchFraction: Double) - : LassoModel = - { + miniBatchFraction: Double): LassoModel = { new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) } @@ -190,9 +154,7 @@ object LassoWithSGD { input: RDD[LabeledPoint], numIterations: Int, stepSize: Double, - regParam: Double) - : LassoModel = - { + regParam: Double): LassoModel = { train(input, numIterations, stepSize, regParam, 1.0) } @@ -208,9 +170,7 @@ object LassoWithSGD { */ def train( input: RDD[LabeledPoint], - numIterations: Int) - : LassoModel = - { + numIterations: Int): LassoModel = { train(input, numIterations, 1.0, 1.0, 1.0) } @@ -222,7 +182,8 @@ object LassoWithSGD { val sc = new SparkContext(args(0), "Lasso") val data = MLUtils.loadLabeledData(sc, args(1)) val model = LassoWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble) - println("Weights: " + model.weights.mkString("[", ", ", "]")) + + println("Weights: " + model.weights) println("Intercept: " + model.intercept) sc.stop() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index f5f15d1a33f4d..228fa8db3e721 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -19,11 +19,10 @@ package org.apache.spark.mllib.regression import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.util.MLUtils -import org.jblas.DoubleMatrix - /** * Regression model trained using LinearRegression. * @@ -31,15 +30,15 @@ import org.jblas.DoubleMatrix * @param intercept Intercept computed for this model. */ class LinearRegressionModel( - override val weights: Array[Double], + override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - override def predictPoint( - dataMatrix: DoubleMatrix, - weightMatrix: DoubleMatrix, + override protected def predictPoint( + dataMatrix: Vector, + weightMatrix: Vector, intercept: Double): Double = { - dataMatrix.dot(weightMatrix) + intercept + weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } } @@ -53,23 +52,25 @@ class LinearRegressionModel( * See also the documentation for the precise formulation. */ class LinearRegressionWithSGD private ( - var stepSize: Double, - var numIterations: Int, - var miniBatchFraction: Double) + private var stepSize: Double, + private var numIterations: Int, + private var miniBatchFraction: Double) extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable { - val gradient = new LeastSquaresGradient() - val updater = new SimpleUpdater() - val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize) + private val gradient = new LeastSquaresGradient() + private val updater = new SimpleUpdater() + override val optimizer = new GradientDescent(gradient, updater) + .setStepSize(stepSize) .setNumIterations(numIterations) .setMiniBatchFraction(miniBatchFraction) /** - * Construct a LinearRegression object with default parameters + * Construct a LinearRegression object with default parameters: {stepSize: 1.0, + * numIterations: 100, miniBatchFraction: 1.0}. */ def this() = this(1.0, 100, 1.0) - override def createModel(weights: Array[Double], intercept: Double) = { + override protected def createModel(weights: Vector, intercept: Double) = { new LinearRegressionModel(weights, intercept) } } @@ -98,11 +99,9 @@ object LinearRegressionWithSGD { numIterations: Int, stepSize: Double, miniBatchFraction: Double, - initialWeights: Array[Double]) - : LinearRegressionModel = - { - new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction).run(input, - initialWeights) + initialWeights: Vector): LinearRegressionModel = { + new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction) + .run(input, initialWeights) } /** @@ -120,9 +119,7 @@ object LinearRegressionWithSGD { input: RDD[LabeledPoint], numIterations: Int, stepSize: Double, - miniBatchFraction: Double) - : LinearRegressionModel = - { + miniBatchFraction: Double): LinearRegressionModel = { new LinearRegressionWithSGD(stepSize, numIterations, miniBatchFraction).run(input) } @@ -140,9 +137,7 @@ object LinearRegressionWithSGD { def train( input: RDD[LabeledPoint], numIterations: Int, - stepSize: Double) - : LinearRegressionModel = - { + stepSize: Double): LinearRegressionModel = { train(input, numIterations, stepSize, 1.0) } @@ -158,9 +153,7 @@ object LinearRegressionWithSGD { */ def train( input: RDD[LabeledPoint], - numIterations: Int) - : LinearRegressionModel = - { + numIterations: Int): LinearRegressionModel = { train(input, numIterations, 1.0, 1.0) } @@ -172,7 +165,7 @@ object LinearRegressionWithSGD { val sc = new SparkContext(args(0), "LinearRegression") val data = MLUtils.loadLabeledData(sc, args(1)) val model = LinearRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble) - println("Weights: " + model.weights.mkString("[", ", ", "]")) + println("Weights: " + model.weights) println("Intercept: " + model.intercept) sc.stop() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala index 423afc32d665c..5e4b8a345b1c5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.regression import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.Vector trait RegressionModel extends Serializable { /** @@ -26,7 +27,7 @@ trait RegressionModel extends Serializable { * @param testData RDD representing data points to be predicted * @return RDD[Double] where each entry contains the corresponding prediction */ - def predict(testData: RDD[Array[Double]]): RDD[Double] + def predict(testData: RDD[Vector]): RDD[Double] /** * Predict values for a single data point using the model trained. @@ -34,5 +35,5 @@ trait RegressionModel extends Serializable { * @param testData array representing a single data point * @return Double prediction from the trained model */ - def predict(testData: Array[Double]): Double + def predict(testData: Vector): Double } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index feb100f21888f..e702027c7c170 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -21,8 +21,7 @@ import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.util.MLUtils - -import org.jblas.DoubleMatrix +import org.apache.spark.mllib.linalg.Vector /** * Regression model trained using RidgeRegression. @@ -31,16 +30,16 @@ import org.jblas.DoubleMatrix * @param intercept Intercept computed for this model. */ class RidgeRegressionModel( - override val weights: Array[Double], + override val weights: Vector, override val intercept: Double) extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - override def predictPoint( - dataMatrix: DoubleMatrix, - weightMatrix: DoubleMatrix, + override protected def predictPoint( + dataMatrix: Vector, + weightMatrix: Vector, intercept: Double): Double = { - dataMatrix.dot(weightMatrix) + intercept + weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept } } @@ -53,17 +52,17 @@ class RidgeRegressionModel( * See also the documentation for the precise formulation. */ class RidgeRegressionWithSGD private ( - var stepSize: Double, - var numIterations: Int, - var regParam: Double, - var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[RidgeRegressionModel] - with Serializable { + private var stepSize: Double, + private var numIterations: Int, + private var regParam: Double, + private var miniBatchFraction: Double) + extends GeneralizedLinearAlgorithm[RidgeRegressionModel] with Serializable { - val gradient = new LeastSquaresGradient() - val updater = new SquaredL2Updater() + private val gradient = new LeastSquaresGradient() + private val updater = new SquaredL2Updater() - @transient val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize) + override val optimizer = new GradientDescent(gradient, updater) + .setStepSize(stepSize) .setNumIterations(numIterations) .setRegParam(regParam) .setMiniBatchFraction(miniBatchFraction) @@ -71,12 +70,9 @@ class RidgeRegressionWithSGD private ( // We don't want to penalize the intercept in RidgeRegression, so set this to false. super.setIntercept(false) - var yMean = 0.0 - var xColMean: DoubleMatrix = _ - var xColSd: DoubleMatrix = _ - /** - * Construct a RidgeRegression object with default parameters + * Construct a RidgeRegression object with default parameters: {stepSize: 1.0, numIterations: 100, + * regParam: 1.0, miniBatchFraction: 1.0}. */ def this() = this(1.0, 100, 1.0, 1.0) @@ -86,36 +82,8 @@ class RidgeRegressionWithSGD private ( this } - override def createModel(weights: Array[Double], intercept: Double) = { - val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*) - val weightsScaled = weightsMat.div(xColSd) - val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0) - - new RidgeRegressionModel(weightsScaled.data, interceptScaled) - } - - override def run( - input: RDD[LabeledPoint], - initialWeights: Array[Double]) - : RidgeRegressionModel = - { - val nfeatures: Int = input.first().features.length - val nexamples: Long = input.count() - - // To avoid penalizing the intercept, we center and scale the data. - val stats = MLUtils.computeStats(input, nfeatures, nexamples) - yMean = stats._1 - xColMean = stats._2 - xColSd = stats._3 - - val normalizedData = input.map { point => - val yNormalized = point.label - yMean - val featuresMat = new DoubleMatrix(nfeatures, 1, point.features:_*) - val featuresNormalized = featuresMat.sub(xColMean).divi(xColSd) - LabeledPoint(yNormalized, featuresNormalized.toArray) - } - - super.run(normalizedData, initialWeights) + override protected def createModel(weights: Vector, intercept: Double) = { + new RidgeRegressionModel(weights, intercept) } } @@ -144,9 +112,7 @@ object RidgeRegressionWithSGD { stepSize: Double, regParam: Double, miniBatchFraction: Double, - initialWeights: Array[Double]) - : RidgeRegressionModel = - { + initialWeights: Vector): RidgeRegressionModel = { new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run( input, initialWeights) } @@ -167,9 +133,7 @@ object RidgeRegressionWithSGD { numIterations: Int, stepSize: Double, regParam: Double, - miniBatchFraction: Double) - : RidgeRegressionModel = - { + miniBatchFraction: Double): RidgeRegressionModel = { new RidgeRegressionWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input) } @@ -188,9 +152,7 @@ object RidgeRegressionWithSGD { input: RDD[LabeledPoint], numIterations: Int, stepSize: Double, - regParam: Double) - : RidgeRegressionModel = - { + regParam: Double): RidgeRegressionModel = { train(input, numIterations, stepSize, regParam, 1.0) } @@ -205,23 +167,22 @@ object RidgeRegressionWithSGD { */ def train( input: RDD[LabeledPoint], - numIterations: Int) - : RidgeRegressionModel = - { + numIterations: Int): RidgeRegressionModel = { train(input, numIterations, 1.0, 1.0, 1.0) } def main(args: Array[String]) { if (args.length != 5) { - println("Usage: RidgeRegression " + - " ") + println("Usage: RidgeRegression " + + " ") System.exit(1) } val sc = new SparkContext(args(0), "RidgeRegression") val data = MLUtils.loadLabeledData(sc, args(1)) val model = RidgeRegressionWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble) - println("Weights: " + model.weights.mkString("[", ", ", "]")) + + println("Weights: " + model.weights) println("Intercept: " + model.intercept) sc.stop() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala new file mode 100644 index 0000000000000..f9eb343da2b82 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.stat + +import org.apache.spark.mllib.linalg.Vector + +/** + * Trait for multivariate statistical summary of a data matrix. + */ +trait MultivariateStatisticalSummary { + + /** + * Sample mean vector. + */ + def mean: Vector + + /** + * Sample variance vector. Should return a zero vector if the sample size is 1. + */ + def variance: Vector + + /** + * Sample size. + */ + def count: Long + + /** + * Number of nonzero elements (including explicitly presented zero values) in each column. + */ + def numNonzeros: Vector + + /** + * Maximum value of each column. + */ + def max: Vector + + /** + * Minimum value of each column. + */ + def min: Vector +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala new file mode 100644 index 0000000000000..3019447ce4cd9 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -0,0 +1,1154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import scala.util.control.Breaks._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} +import org.apache.spark.mllib.tree.model._ +import org.apache.spark.rdd.RDD +import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.mllib.linalg.{Vector, Vectors} + +/** + * :: Experimental :: + * A class that implements a decision tree algorithm for classification and regression. It + * supports both continuous and categorical features. + * @param strategy The configuration parameters for the tree algorithm which specify the type + * of algorithm (classification, regression, etc.), feature type (continuous, + * categorical), depth of the tree, quantile calculation strategy, etc. + */ +@Experimental +class DecisionTree (private val strategy: Strategy) extends Serializable with Logging { + + /** + * Method to train a decision tree model over an RDD + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * @return a DecisionTreeModel that can be used for prediction + */ + def train(input: RDD[LabeledPoint]): DecisionTreeModel = { + + // Cache input RDD for speedup during multiple passes. + input.cache() + logDebug("algo = " + strategy.algo) + + // Find the splits and the corresponding bins (interval between the splits) using a sample + // of the input data. + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + logDebug("numSplits = " + bins(0).length) + + // depth of the decision tree + val maxDepth = strategy.maxDepth + // the max number of nodes possible given the depth of the tree + val maxNumNodes = scala.math.pow(2, maxDepth).toInt - 1 + // Initialize an array to hold filters applied to points for each node. + val filters = new Array[List[Filter]](maxNumNodes) + // The filter at the top node is an empty list. + filters(0) = List() + // Initialize an array to hold parent impurity calculations for each node. + val parentImpurities = new Array[Double](maxNumNodes) + // dummy value for top node (updated during first split calculation) + val nodes = new Array[Node](maxNumNodes) + + + /* + * The main idea here is to perform level-wise training of the decision tree nodes thus + * reducing the passes over the data from l to log2(l) where l is the total number of nodes. + * Each data sample is checked for validity w.r.t to each node at a given level -- i.e., + * the sample is only used for the split calculation at the node if the sampled would have + * still survived the filters of the parent nodes. + */ + + // TODO: Convert for loop to while loop + breakable { + for (level <- 0 until maxDepth) { + + logDebug("#####################################") + logDebug("level = " + level) + logDebug("#####################################") + + // Find best split for all nodes at a level. + val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, + level, filters, splits, bins) + + for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { + // Extract info for nodes at the current level. + extractNodeInfo(nodeSplitStats, level, index, nodes) + // Extract info for nodes at the next lower level. + extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, + filters) + logDebug("final best split = " + nodeSplitStats._1) + } + require(scala.math.pow(2, level) == splitsStatsForLevel.length) + // Check whether all the nodes at the current level at leaves. + val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) + logDebug("all leaf = " + allLeaf) + if (allLeaf) break // no more tree construction + } + } + + // Initialize the top or root node of the tree. + val topNode = nodes(0) + // Build the full tree using the node info calculated in the level-wise best split calculations. + topNode.build(nodes) + + new DecisionTreeModel(topNode, strategy.algo) + } + + /** + * Extract the decision tree node information for the given tree level and node index + */ + private def extractNodeInfo( + nodeSplitStats: (Split, InformationGainStats), + level: Int, + index: Int, + nodes: Array[Node]): Unit = { + val split = nodeSplitStats._1 + val stats = nodeSplitStats._2 + val nodeIndex = scala.math.pow(2, level).toInt - 1 + index + val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth - 1) + val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats)) + logDebug("Node = " + node) + nodes(nodeIndex) = node + } + + /** + * Extract the decision tree node information for the children of the node + */ + private def extractInfoForLowerLevels( + level: Int, + index: Int, + maxDepth: Int, + nodeSplitStats: (Split, InformationGainStats), + parentImpurities: Array[Double], + filters: Array[List[Filter]]): Unit = { + // 0 corresponds to the left child node and 1 corresponds to the right child node. + // TODO: Convert to while loop + for (i <- 0 to 1) { + // Calculate the index of the node from the node level and the index at the current level. + val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i + if (level < maxDepth - 1) { + val impurity = if (i == 0) { + nodeSplitStats._2.leftImpurity + } else { + nodeSplitStats._2.rightImpurity + } + logDebug("nodeIndex = " + nodeIndex + ", impurity = " + impurity) + // noting the parent impurities + parentImpurities(nodeIndex) = impurity + // noting the parents filters for the child nodes + val childFilter = new Filter(nodeSplitStats._1, if (i == 0) -1 else 1) + filters(nodeIndex) = childFilter :: filters((nodeIndex - 1) / 2) + for (filter <- filters(nodeIndex)) { + logDebug("Filter = " + filter) + } + } + } + } +} + +object DecisionTree extends Serializable with Logging { + + /** + * Method to train a decision tree model where the instances are represented as an RDD of + * (label, features) pairs. The method supports binary classification and regression. For the + * binary classification, the label for each instance should either be 0 or 1 to denote the two + * classes. The parameters for the algorithm are specified using the strategy parameter. + * + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @param strategy The configuration parameters for the tree algorithm which specify the type + * of algorithm (classification, regression, etc.), feature type (continuous, + * categorical), depth of the tree, quantile calculation strategy, etc. + * @return a DecisionTreeModel that can be used for prediction + */ + def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = { + new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + } + + /** + * Method to train a decision tree model where the instances are represented as an RDD of + * (label, features) pairs. The method supports binary classification and regression. For the + * binary classification, the label for each instance should either be 0 or 1 to denote the two + * classes. + * + * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as + * training data + * @param algo algorithm, classification or regression + * @param impurity impurity criterion used for information gain calculation + * @param maxDepth maxDepth maximum depth of the tree + * @return a DecisionTreeModel that can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + algo: Algo, + impurity: Impurity, + maxDepth: Int): DecisionTreeModel = { + val strategy = new Strategy(algo,impurity,maxDepth) + new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + } + + + /** + * Method to train a decision tree model where the instances are represented as an RDD of + * (label, features) pairs. The decision tree method supports binary classification and + * regression. For the binary classification, the label for each instance should either be 0 or + * 1 to denote the two classes. The method also supports categorical features inputs where the + * number of categories can specified using the categoricalFeaturesInfo option. + * + * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as + * training data for DecisionTree + * @param algo classification or regression + * @param impurity criterion used for information gain calculation + * @param maxDepth maximum depth of the tree + * @param maxBins maximum number of bins used for splitting features + * @param quantileCalculationStrategy algorithm for calculating quantiles + * @param categoricalFeaturesInfo A map storing information about the categorical variables and + * the number of discrete values they take. For example, + * an entry (n -> k) implies the feature n is categorical with k + * categories 0, 1, 2, ... , k-1. It's important to note that + * features are zero-indexed. + * @return a DecisionTreeModel that can be used for prediction + */ + def train( + input: RDD[LabeledPoint], + algo: Algo, + impurity: Impurity, + maxDepth: Int, + maxBins: Int, + quantileCalculationStrategy: QuantileStrategy, + categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = { + val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy, + categoricalFeaturesInfo) + new DecisionTree(strategy).train(input: RDD[LabeledPoint]) + } + + private val InvalidBinIndex = -1 + + /** + * Returns an array of optimal splits for all nodes at a given level + * + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @param parentImpurities Impurities for all parent nodes for the current level + * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing + * parameters for construction the DecisionTree + * @param level Level of the tree + * @param filters Filters for all nodes at a given level + * @param splits possible splits for all features + * @param bins possible bins for all features + * @return array of splits with best splits for all nodes at a given level. + */ + protected[tree] def findBestSplits( + input: RDD[LabeledPoint], + parentImpurities: Array[Double], + strategy: Strategy, + level: Int, + filters: Array[List[Filter]], + splits: Array[Array[Split]], + bins: Array[Array[Bin]]): Array[(Split, InformationGainStats)] = { + + /* + * The high-level description for the best split optimizations are noted here. + * + * *Level-wise training* + * We perform bin calculations for all nodes at the given level to avoid making multiple + * passes over the data. Thus, for a slightly increased computation and storage cost we save + * several iterations over the data especially at higher levels of the decision tree. + * + * *Bin-wise computation* + * We use a bin-wise best split computation strategy instead of a straightforward best split + * computation strategy. Instead of analyzing each sample for contribution to the left/right + * child node impurity of every split, we first categorize each feature of a sample into a + * bin. Each bin is an interval between a low and high split. Since each splits, and thus bin, + * is ordered (read ordering for categorical variables in the findSplitsBins method), + * we exploit this structure to calculate aggregates for bins and then use these aggregates + * to calculate information gain for each split. + * + * *Aggregation over partitions* + * Instead of performing a flatMap/reduceByKey operation, we exploit the fact that we know + * the number of splits in advance. Thus, we store the aggregates (at the appropriate + * indices) in a single array for all bins and rely upon the RDD aggregate method to + * drastically reduce the communication overhead. + */ + + // common calculations for multiple nested methods + val numNodes = scala.math.pow(2, level).toInt + logDebug("numNodes = " + numNodes) + // Find the number of features by looking at the first sample. + val numFeatures = input.first().features.size + logDebug("numFeatures = " + numFeatures) + val numBins = bins(0).length + logDebug("numBins = " + numBins) + + /** Find the filters used before reaching the current code. */ + def findParentFilters(nodeIndex: Int): List[Filter] = { + if (level == 0) { + List[Filter]() + } else { + val nodeFilterIndex = scala.math.pow(2, level).toInt - 1 + nodeIndex + filters(nodeFilterIndex) + } + } + + /** + * Find whether the sample is valid input for the current node, i.e., whether it passes through + * all the filters for the current node. + */ + def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { + // leaf + if ((level > 0) & (parentFilters.length == 0)) { + return false + } + + // Apply each filter and check sample validity. Return false when invalid condition found. + for (filter <- parentFilters) { + val features = labeledPoint.features + val featureIndex = filter.split.feature + val threshold = filter.split.threshold + val comparison = filter.comparison + val categories = filter.split.categories + val isFeatureContinuous = filter.split.featureType == Continuous + val feature = features(featureIndex) + if (isFeatureContinuous) { + comparison match { + case -1 => if (feature > threshold) return false + case 1 => if (feature <= threshold) return false + } + } else { + val containsFeature = categories.contains(feature) + comparison match { + case -1 => if (!containsFeature) return false + case 1 => if (containsFeature) return false + } + + } + } + + // Return true when the sample is valid for all filters. + true + } + + /** + * Find bin for one feature. + */ + def findBin( + featureIndex: Int, + labeledPoint: LabeledPoint, + isFeatureContinuous: Boolean): Int = { + val binForFeatures = bins(featureIndex) + val feature = labeledPoint.features(featureIndex) + + /** + * Binary search helper method for continuous feature. + */ + def binarySearchForBins(): Int = { + var left = 0 + var right = binForFeatures.length - 1 + while (left <= right) { + val mid = left + (right - left) / 2 + val bin = binForFeatures(mid) + val lowThreshold = bin.lowSplit.threshold + val highThreshold = bin.highSplit.threshold + if ((lowThreshold < feature) & (highThreshold >= feature)){ + return mid + } + else if (lowThreshold >= feature) { + right = mid - 1 + } + else { + left = mid + 1 + } + } + -1 + } + + /** + * Sequential search helper method to find bin for categorical feature. + */ + def sequentialBinSearchForCategoricalFeature(): Int = { + val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex) + var binIndex = 0 + while (binIndex < numCategoricalBins) { + val bin = bins(featureIndex)(binIndex) + val category = bin.category + val features = labeledPoint.features + if (category == features(featureIndex)) { + return binIndex + } + binIndex += 1 + } + -1 + } + + if (isFeatureContinuous) { + // Perform binary search for finding bin for continuous features. + val binIndex = binarySearchForBins() + if (binIndex == -1){ + throw new UnknownError("no bin was found for continuous variable.") + } + binIndex + } else { + // Perform sequential search to find bin for categorical features. + val binIndex = sequentialBinSearchForCategoricalFeature() + if (binIndex == -1){ + throw new UnknownError("no bin was found for categorical variable.") + } + binIndex + } + } + + /** + * Finds bins for all nodes (and all features) at a given level. + * For l nodes, k features the storage is as follows: + * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk, + * where b_ij is an integer between 0 and numBins - 1. + * Invalid sample is denoted by noting bin for feature 1 as -1. + */ + def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = { + // Calculate bin index and label per feature per node. + val arr = new Array[Double](1 + (numFeatures * numNodes)) + arr(0) = labeledPoint.label + var nodeIndex = 0 + while (nodeIndex < numNodes) { + val parentFilters = findParentFilters(nodeIndex) + // Find out whether the sample qualifies for the particular node. + val sampleValid = isSampleValid(parentFilters, labeledPoint) + val shift = 1 + numFeatures * nodeIndex + if (!sampleValid) { + // Mark one bin as -1 is sufficient. + arr(shift) = InvalidBinIndex + } else { + var featureIndex = 0 + while (featureIndex < numFeatures) { + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + arr(shift + featureIndex) = findBin(featureIndex, labeledPoint,isFeatureContinuous) + featureIndex += 1 + } + } + nodeIndex += 1 + } + arr + } + + /** + * Performs a sequential aggregation over a partition for classification. For l nodes, + * k features, either the left count or the right count of one of the p bins is + * incremented based upon whether the feature is classified as 0 or 1. + * + * @param agg Array[Double] storing aggregate calculation of size + * 2 * numSplits * numFeatures*numNodes for classification + * @param arr Array[Double] of size 1 + (numFeatures * numNodes) + * @return Array[Double] storing aggregate calculation of size + * 2 * numSplits * numFeatures * numNodes for classification + */ + def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) { + // Iterate over all nodes. + var nodeIndex = 0 + while (nodeIndex < numNodes) { + // Check whether the instance was valid for this nodeIndex. + val validSignalIndex = 1 + numFeatures * nodeIndex + val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex + if (isSampleValidForNode) { + // actual class label + val label = arr(0) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // Find the bin index for this feature. + val arrShift = 1 + numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // Update the left or right count for one bin. + val aggShift = 2 * numBins * numFeatures * nodeIndex + val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2 + label match { + case 0.0 => agg(aggIndex) = agg(aggIndex) + 1 + case 1.0 => agg(aggIndex + 1) = agg(aggIndex + 1) + 1 + } + featureIndex += 1 + } + } + nodeIndex += 1 + } + } + + /** + * Performs a sequential aggregation over a partition for regression. For l nodes, k features, + * the count, sum, sum of squares of one of the p bins is incremented. + * + * @param agg Array[Double] storing aggregate calculation of size + * 3 * numSplits * numFeatures * numNodes for classification + * @param arr Array[Double] of size 1 + (numFeatures * numNodes) + * @return Array[Double] storing aggregate calculation of size + * 3 * numSplits * numFeatures * numNodes for regression + */ + def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) { + // Iterate over all nodes. + var nodeIndex = 0 + while (nodeIndex < numNodes) { + // Check whether the instance was valid for this nodeIndex. + val validSignalIndex = 1 + numFeatures * nodeIndex + val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex + if (isSampleValidForNode) { + // actual class label + val label = arr(0) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // Find the bin index for this feature. + val arrShift = 1 + numFeatures * nodeIndex + val arrIndex = arrShift + featureIndex + // Update count, sum, and sum^2 for one bin. + val aggShift = 3 * numBins * numFeatures * nodeIndex + val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3 + agg(aggIndex) = agg(aggIndex) + 1 + agg(aggIndex + 1) = agg(aggIndex + 1) + label + agg(aggIndex + 2) = agg(aggIndex + 2) + label*label + featureIndex += 1 + } + } + nodeIndex += 1 + } + } + + /** + * Performs a sequential aggregation over a partition. + */ + def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = { + strategy.algo match { + case Classification => classificationBinSeqOp(arr, agg) + case Regression => regressionBinSeqOp(arr, agg) + } + agg + } + + // Calculate bin aggregate length for classification or regression. + val binAggregateLength = strategy.algo match { + case Classification => 2 * numBins * numFeatures * numNodes + case Regression => 3 * numBins * numFeatures * numNodes + } + logDebug("binAggregateLength = " + binAggregateLength) + + /** + * Combines the aggregates from partitions. + * @param agg1 Array containing aggregates from one or more partitions + * @param agg2 Array containing aggregates from one or more partitions + * @return Combined aggregate from agg1 and agg2 + */ + def binCombOp(agg1: Array[Double], agg2: Array[Double]): Array[Double] = { + var index = 0 + val combinedAggregate = new Array[Double](binAggregateLength) + while (index < binAggregateLength) { + combinedAggregate(index) = agg1(index) + agg2(index) + index += 1 + } + combinedAggregate + } + + // Find feature bins for all nodes at a level. + val binMappedRDD = input.map(x => findBinsForLevel(x)) + + // Calculate bin aggregates. + val binAggregates = { + binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp) + } + logDebug("binAggregates.length = " + binAggregates.length) + + /** + * Calculates the information gain for all splits based upon left/right split aggregates. + * @param leftNodeAgg left node aggregates + * @param featureIndex feature index + * @param splitIndex split index + * @param rightNodeAgg right node aggregate + * @param topImpurity impurity of the parent node + * @return information gain and statistics for all splits + */ + def calculateGainForSplit( + leftNodeAgg: Array[Array[Double]], + featureIndex: Int, + splitIndex: Int, + rightNodeAgg: Array[Array[Double]], + topImpurity: Double): InformationGainStats = { + strategy.algo match { + case Classification => + val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex) + val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1) + val leftCount = left0Count + left1Count + + val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex) + val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1) + val rightCount = right0Count + right1Count + + val impurity = { + if (level > 0) { + topImpurity + } else { + // Calculate impurity for root node. + strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count) + } + } + + if (leftCount == 0) { + return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity,1) + } + if (rightCount == 0) { + return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue,0) + } + + val leftImpurity = strategy.impurity.calculate(left0Count, left1Count) + val rightImpurity = strategy.impurity.calculate(right0Count, right1Count) + + val leftWeight = leftCount.toDouble / (leftCount + rightCount) + val rightWeight = rightCount.toDouble / (leftCount + rightCount) + + val gain = { + if (level > 0) { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } else { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } + } + + val predict = (left1Count + right1Count) / (leftCount + rightCount) + + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) + case Regression => + val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex) + val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1) + val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2) + + val rightCount = rightNodeAgg(featureIndex)(3 * splitIndex) + val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1) + val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2) + + val impurity = { + if (level > 0) { + topImpurity + } else { + // Calculate impurity for root node. + val count = leftCount + rightCount + val sum = leftSum + rightSum + val sumSquares = leftSumSquares + rightSumSquares + strategy.impurity.calculate(count, sum, sumSquares) + } + } + + if (leftCount == 0) { + return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, + rightSum / rightCount) + } + if (rightCount == 0) { + return new InformationGainStats(0, topImpurity ,topImpurity, + Double.MinValue, leftSum / leftCount) + } + + val leftImpurity = strategy.impurity.calculate(leftCount, leftSum, leftSumSquares) + val rightImpurity = strategy.impurity.calculate(rightCount, rightSum, rightSumSquares) + + val leftWeight = leftCount.toDouble / (leftCount + rightCount) + val rightWeight = rightCount.toDouble / (leftCount + rightCount) + + val gain = { + if (level > 0) { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } else { + impurity - leftWeight * leftImpurity - rightWeight * rightImpurity + } + } + + val predict = (leftSum + rightSum) / (leftCount + rightCount) + new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict) + } + } + + /** + * Extracts left and right split aggregates. + * @param binData Array[Double] of size 2*numFeatures*numSplits + * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double], + * Array[Double]) where each array is of size(numFeature,2*(numSplits-1)) + */ + def extractLeftRightNodeAggregates( + binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = { + strategy.algo match { + case Classification => + // Initialize left and right split aggregates. + val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) + val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1)) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // shift for this featureIndex + val shift = 2 * featureIndex * numBins + + // left node aggregate for the lowest split + leftNodeAgg(featureIndex)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(1) = binData(shift + 1) + + // right node aggregate for the highest split + rightNodeAgg(featureIndex)(2 * (numBins - 2)) + = binData(shift + (2 * (numBins - 1))) + rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1) + = binData(shift + (2 * (numBins - 1)) + 1) + + // Iterate over all splits. + var splitIndex = 1 + while (splitIndex < numBins - 1) { + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2 * splitIndex) + + leftNodeAgg(featureIndex)(2 * splitIndex - 2) + leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2 * splitIndex + 1) + + leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1) + + // calculating right node aggregate for a split as a sum of right node aggregate of a + // higher split and the right bin aggregate of a bin where the split is a low split + rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) = + binData(shift + (2 *(numBins - 2 - splitIndex))) + + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex)) + rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) = + binData(shift + (2* (numBins - 2 - splitIndex) + 1)) + + rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1) + + splitIndex += 1 + } + featureIndex += 1 + } + (leftNodeAgg, rightNodeAgg) + case Regression => + // Initialize left and right split aggregates. + val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) + val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1)) + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // shift for this featureIndex + val shift = 3 * featureIndex * numBins + // left node aggregate for the lowest split + leftNodeAgg(featureIndex)(0) = binData(shift + 0) + leftNodeAgg(featureIndex)(1) = binData(shift + 1) + leftNodeAgg(featureIndex)(2) = binData(shift + 2) + + // right node aggregate for the highest split + rightNodeAgg(featureIndex)(3 * (numBins - 2)) = + binData(shift + (3 * (numBins - 1))) + rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) = + binData(shift + (3 * (numBins - 1)) + 1) + rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) = + binData(shift + (3 * (numBins - 1)) + 2) + + // Iterate over all splits. + var splitIndex = 1 + while (splitIndex < numBins - 1) { + // calculating left node aggregate for a split as a sum of left node aggregate of a + // lower split and the left bin aggregate of a bin where the split is a high split + leftNodeAgg(featureIndex)(3 * splitIndex) = binData(shift + 3 * splitIndex) + + leftNodeAgg(featureIndex)(3 * splitIndex - 3) + leftNodeAgg(featureIndex)(3 * splitIndex + 1) = binData(shift + 3 * splitIndex + 1) + + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1) + leftNodeAgg(featureIndex)(3 * splitIndex + 2) = binData(shift + 3 * splitIndex + 2) + + leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2) + + // calculating right node aggregate for a split as a sum of right node aggregate of a + // higher split and the right bin aggregate of a bin where the split is a low split + rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) = + binData(shift + (3 * (numBins - 2 - splitIndex))) + + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex)) + rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) = + binData(shift + (3 * (numBins - 2 - splitIndex) + 1)) + + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1) + rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) = + binData(shift + (3 * (numBins - 2 - splitIndex) + 2)) + + rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2) + + splitIndex += 1 + } + featureIndex += 1 + } + (leftNodeAgg, rightNodeAgg) + } + } + + /** + * Calculates information gain for all nodes splits. + */ + def calculateGainsForAllNodeSplits( + leftNodeAgg: Array[Array[Double]], + rightNodeAgg: Array[Array[Double]], + nodeImpurity: Double): Array[Array[InformationGainStats]] = { + val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1) + + for (featureIndex <- 0 until numFeatures) { + for (splitIndex <- 0 until numBins - 1) { + gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex, + splitIndex, rightNodeAgg, nodeImpurity) + } + } + gains + } + + /** + * Find the best split for a node. + * @param binData Array[Double] of size 2 * numSplits * numFeatures + * @param nodeImpurity impurity of the top node + * @return tuple of split and information gain + */ + def binsToBestSplit( + binData: Array[Double], + nodeImpurity: Double): (Split, InformationGainStats) = { + + logDebug("node impurity = " + nodeImpurity) + + // Extract left right node aggregates. + val (leftNodeAgg, rightNodeAgg) = extractLeftRightNodeAggregates(binData) + + // Calculate gains for all splits. + val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity) + + val (bestFeatureIndex,bestSplitIndex, gainStats) = { + // Initialize with infeasible values. + var bestFeatureIndex = Int.MinValue + var bestSplitIndex = Int.MinValue + var bestGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, -1.0) + // Iterate over features. + var featureIndex = 0 + while (featureIndex < numFeatures) { + // Iterate over all splits. + var splitIndex = 0 + while (splitIndex < numBins - 1) { + val gainStats = gains(featureIndex)(splitIndex) + if (gainStats.gain > bestGainStats.gain) { + bestGainStats = gainStats + bestFeatureIndex = featureIndex + bestSplitIndex = splitIndex + } + splitIndex += 1 + } + featureIndex += 1 + } + (bestFeatureIndex, bestSplitIndex, bestGainStats) + } + + logDebug("best split bin = " + bins(bestFeatureIndex)(bestSplitIndex)) + logDebug("best split bin = " + splits(bestFeatureIndex)(bestSplitIndex)) + + (splits(bestFeatureIndex)(bestSplitIndex), gainStats) + } + + /** + * Get bin data for one node. + */ + def getBinDataForNode(node: Int): Array[Double] = { + strategy.algo match { + case Classification => + val shift = 2 * node * numBins * numFeatures + val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures) + binsForNode + case Regression => + val shift = 3 * node * numBins * numFeatures + val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures) + binsForNode + } + } + + // Calculate best splits for all nodes at a given level + val bestSplits = new Array[(Split, InformationGainStats)](numNodes) + // Iterating over all nodes at this level + var node = 0 + while (node < numNodes) { + val nodeImpurityIndex = scala.math.pow(2, level).toInt - 1 + node + val binsForNode: Array[Double] = getBinDataForNode(node) + logDebug("nodeImpurityIndex = " + nodeImpurityIndex) + val parentNodeImpurity = parentImpurities(nodeImpurityIndex) + logDebug("node impurity = " + parentNodeImpurity) + bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity) + node += 1 + } + + bestSplits + } + + /** + * Returns split and bins for decision tree calculation. + * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data + * for DecisionTree + * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing + * parameters for construction the DecisionTree + * @return a tuple of (splits,bins) where splits is an Array of [org.apache.spark.mllib.tree + * .model.Split] of size (numFeatures, numSplits-1) and bins is an Array of [org.apache + * .spark.mllib.tree.model.Bin] of size (numFeatures, numSplits1) + */ + protected[tree] def findSplitsBins( + input: RDD[LabeledPoint], + strategy: Strategy): (Array[Array[Split]], Array[Array[Bin]]) = { + val count = input.count() + + // Find the number of features by looking at the first sample + val numFeatures = input.take(1)(0).features.size + + val maxBins = strategy.maxBins + val numBins = if (maxBins <= count) maxBins else count.toInt + logDebug("numBins = " + numBins) + + /* + * TODO: Add a require statement ensuring #bins is always greater than the categories. + * It's a limitation of the current implementation but a reasonable trade-off since features + * with large number of categories get favored over continuous features. + */ + if (strategy.categoricalFeaturesInfo.size > 0) { + val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 + require(numBins >= maxCategoriesForFeatures) + } + + // Calculate the number of sample for approximate quantile calculation. + val requiredSamples = numBins*numBins + val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0 + logDebug("fraction of data used for calculating quantiles = " + fraction) + + // sampled input for RDD calculation + val sampledInput = input.sample(false, fraction, new XORShiftRandom().nextInt()).collect() + val numSamples = sampledInput.length + + val stride: Double = numSamples.toDouble / numBins + logDebug("stride = " + stride) + + strategy.quantileCalculationStrategy match { + case Sort => + val splits = Array.ofDim[Split](numFeatures, numBins - 1) + val bins = Array.ofDim[Bin](numFeatures, numBins) + + // Find all splits. + + // Iterate over all features. + var featureIndex = 0 + while (featureIndex < numFeatures){ + // Check whether the feature is continuous. + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { + val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted + val stride: Double = numSamples.toDouble / numBins + logDebug("stride = " + stride) + for (index <- 0 until numBins - 1) { + val sampleIndex = (index + 1) * stride.toInt + val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List()) + splits(featureIndex)(index) = split + } + } else { + val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex) + require(maxFeatureValue < numBins, "number of categories should be less than number " + + "of bins") + + // For categorical variables, each bin is a category. The bins are sorted and they + // are ordered by calculating the centroid of their corresponding labels. + val centroidForCategories = + sampledInput.map(lp => (lp.features(featureIndex),lp.label)) + .groupBy(_._1) + .mapValues(x => x.map(_._2).sum / x.map(_._1).length) + + // Check for missing categorical variables and putting them last in the sorted list. + val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]() + for (i <- 0 until maxFeatureValue) { + if (centroidForCategories.contains(i)) { + fullCentroidForCategories(i) = centroidForCategories(i) + } else { + fullCentroidForCategories(i) = Double.MaxValue + } + } + + // bins sorted by centroids + val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2) + + logDebug("centriod for categorical variable = " + categoriesSortedByCentroid) + + var categoriesForSplit = List[Double]() + categoriesSortedByCentroid.iterator.zipWithIndex.foreach { + case ((key, value), index) => + categoriesForSplit = key :: categoriesForSplit + splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical, + categoriesForSplit) + bins(featureIndex)(index) = { + if (index == 0) { + new Bin(new DummyCategoricalSplit(featureIndex, Categorical), + splits(featureIndex)(0), Categorical, key) + } else { + new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), + Categorical, key) + } + } + } + } + featureIndex += 1 + } + + // Find all bins. + featureIndex = 0 + while (featureIndex < numFeatures) { + val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty + if (isFeatureContinuous) { // Bins for categorical variables are already assigned. + bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous), + splits(featureIndex)(0), Continuous, Double.MinValue) + for (index <- 1 until numBins - 1){ + val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index), + Continuous, Double.MinValue) + bins(featureIndex)(index) = bin + } + bins(featureIndex)(numBins-1) = new Bin(splits(featureIndex)(numBins-2), + new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue) + } + featureIndex += 1 + } + (splits,bins) + case MinMax => + throw new UnsupportedOperationException("minmax not supported yet.") + case ApproxHist => + throw new UnsupportedOperationException("approximate histogram not supported yet.") + } + } + + private val usage = """ + Usage: DecisionTreeRunner [slices] --algo --trainDataDir path --testDataDir path --maxDepth num [--impurity ] [--maxBins num] + """ + + def main(args: Array[String]) { + + if (args.length < 2) { + System.err.println(usage) + System.exit(1) + } + + val sc = new SparkContext(args(0), "DecisionTree") + + val argList = args.toList.drop(1) + type OptionMap = Map[Symbol, Any] + + def nextOption(map : OptionMap, list: List[String]): OptionMap = { + list match { + case Nil => map + case "--algo" :: string :: tail => nextOption(map ++ Map('algo -> string), tail) + case "--impurity" :: string :: tail => nextOption(map ++ Map('impurity -> string), tail) + case "--maxDepth" :: string :: tail => nextOption(map ++ Map('maxDepth -> string), tail) + case "--maxBins" :: string :: tail => nextOption(map ++ Map('maxBins -> string), tail) + case "--trainDataDir" :: string :: tail => nextOption(map ++ Map('trainDataDir -> string) + , tail) + case "--testDataDir" :: string :: tail => nextOption(map ++ Map('testDataDir -> string), + tail) + case string :: Nil => nextOption(map ++ Map('infile -> string), list.tail) + case option :: tail => logError("Unknown option " + option) + sys.exit(1) + } + } + val options = nextOption(Map(), argList) + logDebug(options.toString()) + + // Load training data. + val trainData = loadLabeledData(sc, options.get('trainDataDir).get.toString) + + // Identify the type of algorithm. + val algoStr = options.get('algo).get.toString + val algo = algoStr match { + case "Classification" => Classification + case "Regression" => Regression + } + + // Identify the type of impurity. + val impurityStr = options.getOrElse('impurity, + if (algo == Classification) "Gini" else "Variance").toString + val impurity = impurityStr match { + case "Gini" => Gini + case "Entropy" => Entropy + case "Variance" => Variance + } + + val maxDepth = options.getOrElse('maxDepth, "1").toString.toInt + val maxBins = options.getOrElse('maxBins, "100").toString.toInt + + val strategy = new Strategy(algo, impurity, maxDepth, maxBins) + val model = DecisionTree.train(trainData, strategy) + + // Load test data. + val testData = loadLabeledData(sc, options.get('testDataDir).get.toString) + + // Measure algorithm accuracy + if (algo == Classification) { + val accuracy = accuracyScore(model, testData) + logDebug("accuracy = " + accuracy) + } + + if (algo == Regression) { + val mse = meanSquaredError(model, testData) + logDebug("mean square error = " + mse) + } + + sc.stop() + } + + /** + * Load labeled data from a file. The data format used here is + * , ..., + * where , are feature values in Double and is the corresponding label as Double. + * + * @param sc SparkContext + * @param dir Directory to the input data files. + * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is + * the label, and the second element represents the feature values (an array of Double). + */ + private def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { + sc.textFile(dir).map { line => + val parts = line.trim().split(",") + val label = parts(0).toDouble + val features = Vectors.dense(parts.slice(1,parts.length).map(_.toDouble)) + LabeledPoint(label, features) + } + } + + // TODO: Port this method to a generic metrics package. + /** + * Calculates the classifier accuracy. + */ + private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint], + threshold: Double = 0.5): Double = { + def predictedValue(features: Vector) = { + if (model.predict(features) < threshold) 0.0 else 1.0 + } + val correctCount = data.filter(y => predictedValue(y.features) == y.label).count() + val count = data.count() + logDebug("correct prediction count = " + correctCount) + logDebug("data count = " + count) + correctCount.toDouble / count + } + + // TODO: Port this method to a generic metrics package + /** + * Calculates the mean squared error for regression. + */ + private def meanSquaredError(tree: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { + data.map { y => + val err = tree.predict(y.features) - y.label + err * err + }.mean() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md b/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md new file mode 100644 index 0000000000000..0fd71aa9735bc --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/README.md @@ -0,0 +1,17 @@ +This package contains the default implementation of the decision tree algorithm. + +The decision tree algorithm supports: ++ Binary classification ++ Regression ++ Information loss calculation with entropy and gini for classification and variance for regression ++ Both continuous and categorical features + +# Tree improvements ++ Node model pruning ++ Printing to dot files + +# Future Ensemble Extensions + ++ Random forests ++ Boosting ++ Extremely randomized trees diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/TallSkinnyDenseMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala similarity index 74% rename from mllib/src/main/scala/org/apache/spark/mllib/linalg/TallSkinnyDenseMatrix.scala rename to mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala index e4ef3c58e8680..79a01f58319e8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/TallSkinnyDenseMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -15,16 +15,16 @@ * limitations under the License. */ -package org.apache.spark.mllib.linalg - -import org.apache.spark.rdd.RDD +package org.apache.spark.mllib.tree.configuration +import org.apache.spark.annotation.Experimental /** - * Class that represents a dense matrix - * - * @param rows RDD of rows - * @param m number of rows - * @param n number of columns + * :: Experimental :: + * Enum to select the algorithm for the decision tree */ -case class TallSkinnyDenseMatrix(val rows: RDD[MatrixRow], val m: Int, val n: Int) +@Experimental +object Algo extends Enumeration { + type Algo = Value + val Classification, Regression = Value +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/MatrixEntry.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala similarity index 72% rename from mllib/src/main/scala/org/apache/spark/mllib/linalg/MatrixEntry.scala rename to mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala index 416996fcbe760..f4c877232750f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/MatrixEntry.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/FeatureType.scala @@ -15,13 +15,16 @@ * limitations under the License. */ -package org.apache.spark.mllib.linalg +package org.apache.spark.mllib.tree.configuration + +import org.apache.spark.annotation.Experimental /** - * Class that represents an entry in a sparse matrix of doubles. - * - * @param i row index (0 indexing used) - * @param j column index (0 indexing used) - * @param mval value of entry in matrix + * :: Experimental :: + * Enum to describe whether a feature is "continuous" or "categorical" */ -case class MatrixEntry(val i: Int, val j: Int, val mval: Double) +@Experimental +object FeatureType extends Enumeration { + type FeatureType = Value + val Continuous, Categorical = Value +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SparseMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala similarity index 72% rename from mllib/src/main/scala/org/apache/spark/mllib/linalg/SparseMatrix.scala rename to mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala index cbd1a2a5a4bd8..7da976e55a722 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/SparseMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/QuantileStrategy.scala @@ -15,16 +15,16 @@ * limitations under the License. */ -package org.apache.spark.mllib.linalg - -import org.apache.spark.rdd.RDD +package org.apache.spark.mllib.tree.configuration +import org.apache.spark.annotation.Experimental /** - * Class that represents a sparse matrix - * - * @param data RDD of nonzero entries - * @param m number of rows - * @param n numner of columns + * :: Experimental :: + * Enum for selecting the quantile calculation strategy */ -case class SparseMatrix(val data: RDD[MatrixEntry], val m: Int, val n: Int) +@Experimental +object QuantileStrategy extends Enumeration { + type QuantileStrategy = Value + val Sort, MinMax, ApproxHist = Value +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala new file mode 100644 index 0000000000000..8767aca47cd5a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.configuration + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ + +/** + * :: Experimental :: + * Stores all the configuration options for tree construction + * @param algo classification or regression + * @param impurity criterion used for information gain calculation + * @param maxDepth maximum depth of the tree + * @param maxBins maximum number of bins used for splitting features + * @param quantileCalculationStrategy algorithm for calculating quantiles + * @param categoricalFeaturesInfo A map storing information about the categorical variables and the + * number of discrete values they take. For example, an entry (n -> + * k) implies the feature n is categorical with k categories 0, + * 1, 2, ... , k-1. It's important to note that features are + * zero-indexed. + */ +@Experimental +class Strategy ( + val algo: Algo, + val impurity: Impurity, + val maxDepth: Int, + val maxBins: Int = 100, + val quantileCalculationStrategy: QuantileStrategy = Sort, + val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int]()) extends Serializable diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala new file mode 100644 index 0000000000000..60f43e9278d2a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +import org.apache.spark.annotation.{DeveloperApi, Experimental} + +/** + * :: Experimental :: + * Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during + * binary classification. + */ +@Experimental +object Entropy extends Impurity { + + private[tree] def log2(x: Double) = scala.math.log(x) / scala.math.log(2) + + /** + * :: DeveloperApi :: + * entropy calculation + * @param c0 count of instances with label 0 + * @param c1 count of instances with label 1 + * @return entropy value + */ + @DeveloperApi + override def calculate(c0: Double, c1: Double): Double = { + if (c0 == 0 || c1 == 0) { + 0 + } else { + val total = c0 + c1 + val f0 = c0 / total + val f1 = c1 / total + -(f0 * log2(f0)) - (f1 * log2(f1)) + } + } + + override def calculate(count: Double, sum: Double, sumSquares: Double): Double = + throw new UnsupportedOperationException("Entropy.calculate") +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala new file mode 100644 index 0000000000000..c51d76d9b4c5b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +import org.apache.spark.annotation.{DeveloperApi, Experimental} + +/** + * :: Experimental :: + * Class for calculating the + * [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]] + * during binary classification. + */ +@Experimental +object Gini extends Impurity { + + /** + * :: DeveloperApi :: + * Gini coefficient calculation + * @param c0 count of instances with label 0 + * @param c1 count of instances with label 1 + * @return Gini coefficient value + */ + @DeveloperApi + override def calculate(c0: Double, c1: Double): Double = { + if (c0 == 0 || c1 == 0) { + 0 + } else { + val total = c0 + c1 + val f0 = c0 / total + val f1 = c1 / total + 1 - f0 * f0 - f1 * f1 + } + } + + override def calculate(count: Double, sum: Double, sumSquares: Double): Double = + throw new UnsupportedOperationException("Gini.calculate") +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala new file mode 100644 index 0000000000000..8eab247cf0932 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +import org.apache.spark.annotation.{DeveloperApi, Experimental} + +/** + * :: Experimental :: + * Trait for calculating information gain. + */ +@Experimental +trait Impurity extends Serializable { + + /** + * :: DeveloperApi :: + * information calculation for binary classification + * @param c0 count of instances with label 0 + * @param c1 count of instances with label 1 + * @return information value + */ + @DeveloperApi + def calculate(c0 : Double, c1 : Double): Double + + /** + * :: DeveloperApi :: + * information calculation for regression + * @param count number of instances + * @param sum sum of labels + * @param sumSquares summation of squares of the labels + * @return information value + */ + @DeveloperApi + def calculate(count: Double, sum: Double, sumSquares: Double): Double +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala new file mode 100644 index 0000000000000..47d07122af30f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impurity + +import org.apache.spark.annotation.{DeveloperApi, Experimental} + +/** + * :: Experimental :: + * Class for calculating variance during regression + */ +@Experimental +object Variance extends Impurity { + override def calculate(c0: Double, c1: Double): Double = + throw new UnsupportedOperationException("Variance.calculate") + + /** + * :: DeveloperApi :: + * variance calculation + * @param count number of instances + * @param sum sum of labels + * @param sumSquares summation of squares of the labels + */ + @DeveloperApi + override def calculate(count: Double, sum: Double, sumSquares: Double): Double = { + val squaredLoss = sumSquares - (sum * sum) / count + squaredLoss / count + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala new file mode 100644 index 0000000000000..2d71e1e366069 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.mllib.tree.configuration.FeatureType._ + +/** + * Used for "binning" the features bins for faster best split calculation. For a continuous + * feature, a bin is determined by a low and a high "split". For a categorical feature, + * the a bin is determined using a single label value (category). + * @param lowSplit signifying the lower threshold for the continuous feature to be + * accepted in the bin + * @param highSplit signifying the upper threshold for the continuous feature to be + * accepted in the bin + * @param featureType type of feature -- categorical or continuous + * @param category categorical label value accepted in the bin + */ +private[tree] +case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala new file mode 100644 index 0000000000000..bf692ca8c4bd7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.Vector + +/** + * :: Experimental :: + * Model to store the decision tree parameters + * @param topNode root node + * @param algo algorithm type -- classification or regression + */ +@Experimental +class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable { + + /** + * Predict values for a single data point using the model trained. + * + * @param features array representing a single data point + * @return Double prediction from the trained model + */ + def predict(features: Vector): Double = { + topNode.predictIfLeaf(features) + } + + /** + * Predict values for the given data set using the model trained. + * + * @param features RDD representing data points to be predicted + * @return RDD[Int] where each entry contains the corresponding prediction + */ + def predict(features: RDD[Vector]): RDD[Double] = { + features.map(x => predict(x)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala new file mode 100644 index 0000000000000..2deaf4ae8dcab --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Filter.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +/** + * Filter specifying a split and type of comparison to be applied on features + * @param split split specifying the feature index, type and threshold + * @param comparison integer specifying <,=,> + */ +private[tree] case class Filter(split: Split, comparison: Int) { + // Comparison -1,0,1 signifies <.=,> + override def toString = " split = " + split + "comparison = " + comparison +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala new file mode 100644 index 0000000000000..cc8a24cce9614 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.annotation.DeveloperApi + +/** + * :: DeveloperApi :: + * Information gain statistics for each split + * @param gain information gain value + * @param impurity current node impurity + * @param leftImpurity left node impurity + * @param rightImpurity right node impurity + * @param predict predicted value + */ +@DeveloperApi +class InformationGainStats( + val gain: Double, + val impurity: Double, + val leftImpurity: Double, + val rightImpurity: Double, + val predict: Double) extends Serializable { + + override def toString = { + "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f" + .format(gain, impurity, leftImpurity, rightImpurity, predict) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala new file mode 100644 index 0000000000000..682f213f411a7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.Logging +import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.mllib.linalg.Vector + +/** + * :: DeveloperApi :: + * Node in a decision tree + * @param id integer node id + * @param predict predicted value at the node + * @param isLeaf whether the leaf is a node + * @param split split to calculate left and right nodes + * @param leftNode left child + * @param rightNode right child + * @param stats information gain stats + */ +@DeveloperApi +class Node ( + val id: Int, + val predict: Double, + val isLeaf: Boolean, + val split: Option[Split], + var leftNode: Option[Node], + var rightNode: Option[Node], + val stats: Option[InformationGainStats]) extends Serializable with Logging { + + override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + + "split = " + split + ", stats = " + stats + + /** + * build the left node and right nodes if not leaf + * @param nodes array of nodes + */ + def build(nodes: Array[Node]): Unit = { + + logDebug("building node " + id + " at level " + + (scala.math.log(id + 1)/scala.math.log(2)).toInt ) + logDebug("id = " + id + ", split = " + split) + logDebug("stats = " + stats) + logDebug("predict = " + predict) + if (!isLeaf) { + val leftNodeIndex = id * 2 + 1 + val rightNodeIndex = id * 2 + 2 + leftNode = Some(nodes(leftNodeIndex)) + rightNode = Some(nodes(rightNodeIndex)) + leftNode.get.build(nodes) + rightNode.get.build(nodes) + } + } + + /** + * predict value if node is not leaf + * @param feature feature value + * @return predicted value + */ + def predictIfLeaf(feature: Vector) : Double = { + if (isLeaf) { + predict + } else{ + if (split.get.featureType == Continuous) { + if (feature(split.get.feature) <= split.get.threshold) { + leftNode.get.predictIfLeaf(feature) + } else { + rightNode.get.predictIfLeaf(feature) + } + } else { + if (split.get.categories.contains(feature(split.get.feature))) { + leftNode.get.predictIfLeaf(feature) + } else { + rightNode.get.predictIfLeaf(feature) + } + } + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala new file mode 100644 index 0000000000000..d7ffd386c05ee --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType + +/** + * :: DeveloperApi :: + * Split applied to a feature + * @param feature feature index + * @param threshold threshold for continuous feature + * @param featureType type of feature -- categorical or continuous + * @param categories accepted values for categorical variables + */ +@DeveloperApi +case class Split( + feature: Int, + threshold: Double, + featureType: FeatureType, + categories: List[Double]) { + + override def toString = + "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType + + ", categories = " + categories +} + +/** + * Split with minimum threshold for continuous features. Helps with the smallest bin creation. + * @param feature feature index + * @param featureType type of feature -- categorical or continuous + */ +private[tree] class DummyLowSplit(feature: Int, featureType: FeatureType) + extends Split(feature, Double.MinValue, featureType, List()) + +/** + * Split with maximum threshold for continuous features. Helps with the highest bin creation. + * @param feature feature index + * @param featureType type of feature -- categorical or continuous + */ +private[tree] class DummyHighSplit(feature: Int, featureType: FeatureType) + extends Split(feature, Double.MaxValue, featureType, List()) + +/** + * Split with no acceptable feature values for categorical features. Helps with the first bin + * creation. + * @param feature feature index + * @param featureType type of feature -- categorical or continuous + */ +private[tree] class DummyCategoricalSplit(feature: Int, featureType: FeatureType) + extends Split(feature, Double.MaxValue, featureType, List()) + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala index 8b55bce7c4bec..45f95482a1def 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala @@ -17,23 +17,24 @@ package org.apache.spark.mllib.util +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint /** + * :: DeveloperApi :: * A collection of methods used to validate data before applying ML algorithms. */ +@DeveloperApi object DataValidators extends Logging { /** * Function to check if labels used for classification are either zero or one. * - * @param data - input data set that needs to be checked - * * @return True if labels are all zero or one, false otherwise. */ - val classificationLabels: RDD[LabeledPoint] => Boolean = { data => + val binaryLabelValidator: RDD[LabeledPoint] => Boolean = { data => val numInvalid = data.filter(x => x.label != 1.0 && x.label != 0.0).count() if (numInvalid != 0) { logError("Classification labels should be 0 or 1. Found " + numInvalid + " invalid labels") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala index 9109189dff52f..6eaebaf7dba9f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/KMeansDataGenerator.scala @@ -19,15 +19,17 @@ package org.apache.spark.mllib.util import scala.util.Random +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD /** + * :: DeveloperApi :: * Generate test data for KMeans. This class first chooses k cluster centers * from a d-dimensional Gaussian distribution scaled by factor r and then creates a Gaussian * cluster with scale 1 around each center. */ - +@DeveloperApi object KMeansDataGenerator { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LAUtils.scala deleted file mode 100644 index afe081295bfae..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LAUtils.scala +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.util - -import org.apache.spark.SparkContext._ - -import org.apache.spark.mllib.linalg._ - -/** - * Helper methods for linear algebra - */ -object LAUtils { - /** - * Convert a SparseMatrix into a TallSkinnyDenseMatrix - * - * @param sp Sparse matrix to be converted - * @return dense version of the input - */ - def sparseToTallSkinnyDense(sp: SparseMatrix): TallSkinnyDenseMatrix = { - val m = sp.m - val n = sp.n - val rows = sp.data.map(x => (x.i, (x.j, x.mval))).groupByKey().map { - case (i, cols) => - val rowArray = Array.ofDim[Double](n) - var j = 0 - while (j < cols.size) { - rowArray(cols(j)._1) = cols(j)._2 - j += 1 - } - MatrixRow(i, rowArray) - } - TallSkinnyDenseMatrix(rows, m, n) - } - - /** - * Convert a TallSkinnyDenseMatrix to a SparseMatrix - * - * @param a matrix to be converted - * @return sparse version of the input - */ - def denseToSparse(a: TallSkinnyDenseMatrix): SparseMatrix = { - val m = a.m - val n = a.n - val data = a.rows.flatMap { - mrow => Array.tabulate(n)(j => MatrixEntry(mrow.i, j, mrow.data(j))) - .filter(x => x.mval != 0) - } - SparseMatrix(data, m, n) - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala new file mode 100644 index 0000000000000..f7966d3ebb613 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LabelParsers.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +/** Trait for label parsers. */ +trait LabelParser extends Serializable { + /** Parses a string label into a double label. */ + def parse(labelString: String): Double +} + +/** + * Label parser for binary labels, which outputs 1.0 (positive) if the value is greater than 0.5, + * or 0.0 (negative) otherwise. So it works with +1/-1 labeling and +1/0 labeling. + */ +object BinaryLabelParser extends LabelParser { + /** Gets the default instance of BinaryLabelParser. */ + def getInstance(): LabelParser = this + + /** + * Parses the input label into positive (1.0) if the value is greater than 0.5, + * or negative (0.0) otherwise. + */ + override def parse(labelString: String): Double = if (labelString.toDouble > 0.5) 1.0 else 0.0 +} + +/** + * Label parser for multiclass labels, which converts the input label to double. + */ +object MulticlassLabelParser extends LabelParser { + /** Gets the default instance of MulticlassLabelParser. */ + def getInstance(): LabelParser = this + + override def parse(labelString: String): Double = labelString.toDouble +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index 2e03684e62861..c8e160d00c2d6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -22,15 +22,19 @@ import scala.util.Random import org.jblas.DoubleMatrix +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint /** + * :: DeveloperApi :: * Generate sample data used for Linear Data. This class generates * uniformly random values for every feature and adds Gaussian noise with mean `eps` to the * response variable `Y`. */ +@DeveloperApi object LinearDataGenerator { /** @@ -74,7 +78,7 @@ object LinearDataGenerator { val y = x.map { xi => new DoubleMatrix(1, xi.length, xi: _*).dot(weightsMat) + intercept + eps * rnd.nextGaussian() } - y.zip(x).map(p => LabeledPoint(p._1, p._2)) + y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2))) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala index 52c4a71d621a1..c82cd8fd4641c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala @@ -19,15 +19,18 @@ package org.apache.spark.mllib.util import scala.util.Random +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.linalg.Vectors /** + * :: DeveloperApi :: * Generate test data for LogisticRegression. This class chooses positive labels * with probability `probOne` and scales features for positive examples by `eps`. */ - +@DeveloperApi object LogisticRegressionDataGenerator { /** @@ -54,7 +57,7 @@ object LogisticRegressionDataGenerator { val x = Array.fill[Double](nfeatures) { rnd.nextGaussian() + (y * eps) } - LabeledPoint(y, x) + LabeledPoint(y, Vectors.dense(x)) } data } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala index 348aba1dea5b6..3f413faca6bb4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala @@ -21,35 +21,36 @@ import scala.util.Random import org.jblas.DoubleMatrix +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD /** -* Generate RDD(s) containing data for Matrix Factorization. -* -* This method samples training entries according to the oversampling factor -* 'trainSampFact', which is a multiplicative factor of the number of -* degrees of freedom of the matrix: rank*(m+n-rank). -* -* It optionally samples entries for a testing matrix using -* 'testSampFact', the percentage of the number of training entries -* to use for testing. -* -* This method takes the following inputs: -* sparkMaster (String) The master URL. -* outputPath (String) Directory to save output. -* m (Int) Number of rows in data matrix. -* n (Int) Number of columns in data matrix. -* rank (Int) Underlying rank of data matrix. -* trainSampFact (Double) Oversampling factor. -* noise (Boolean) Whether to add gaussian noise to training data. -* sigma (Double) Standard deviation of added gaussian noise. -* test (Boolean) Whether to create testing RDD. -* testSampFact (Double) Percentage of training data to use as test data. -*/ - -object MFDataGenerator{ - + * :: DeveloperApi :: + * Generate RDD(s) containing data for Matrix Factorization. + * + * This method samples training entries according to the oversampling factor + * 'trainSampFact', which is a multiplicative factor of the number of + * degrees of freedom of the matrix: rank*(m+n-rank). + * + * It optionally samples entries for a testing matrix using + * 'testSampFact', the percentage of the number of training entries + * to use for testing. + * + * This method takes the following inputs: + * sparkMaster (String) The master URL. + * outputPath (String) Directory to save output. + * m (Int) Number of rows in data matrix. + * n (Int) Number of columns in data matrix. + * rank (Int) Underlying rank of data matrix. + * trainSampFact (Double) Oversampling factor. + * noise (Boolean) Whether to add gaussian noise to training data. + * sigma (Double) Standard deviation of added gaussian noise. + * test (Boolean) Whether to create testing RDD. + * testSampFact (Double) Percentage of training data to use as test data. + */ +@DeveloperApi +object MFDataGenerator { def main(args: Array[String]) { if (args.length < 2) { println("Usage: MFDataGenerator " + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 08cd9ab05547b..901c3180eac4c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -17,15 +17,13 @@ package org.apache.spark.mllib.util +import breeze.linalg.{Vector => BV, SparseVector => BSV, squaredDistance => breezeSquaredDistance} + +import org.apache.spark.annotation.Experimental import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ - -import org.jblas.DoubleMatrix - import org.apache.spark.mllib.regression.LabeledPoint - -import breeze.linalg.{Vector => BV, SparseVector => BSV, squaredDistance => breezeSquaredDistance} +import org.apache.spark.mllib.linalg.Vectors /** * Helper methods to load, save and pre-process data used in ML Lib. @@ -41,6 +39,90 @@ object MLUtils { } /** + * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint]. + * The LIBSVM format is a text-based format used by LIBSVM and LIBLINEAR. + * Each line represents a labeled sparse feature vector using the following format: + * {{{label index1:value1 index2:value2 ...}}} + * where the indices are one-based and in ascending order. + * This method parses each line into a [[org.apache.spark.mllib.regression.LabeledPoint]], + * where the feature indices are converted to zero-based. + * + * @param sc Spark context + * @param path file or directory path in any Hadoop-supported file system URI + * @param labelParser parser for labels, default: 1.0 if label > 0.5 or 0.0 otherwise + * @param numFeatures number of features, which will be determined from the input data if a + * negative value is given. The default value is -1. + * @param minSplits min number of partitions, default: sc.defaultMinSplits + * @return labeled data stored as an RDD[LabeledPoint] + */ + def loadLibSVMData( + sc: SparkContext, + path: String, + labelParser: LabelParser, + numFeatures: Int, + minSplits: Int): RDD[LabeledPoint] = { + val parsed = sc.textFile(path, minSplits) + .map(_.trim) + .filter(!_.isEmpty) + .map(_.split(' ')) + // Determine number of features. + val d = if (numFeatures >= 0) { + numFeatures + } else { + parsed.map { items => + if (items.length > 1) { + items.last.split(':')(0).toInt + } else { + 0 + } + }.reduce(math.max) + } + parsed.map { items => + val label = labelParser.parse(items.head) + val (indices, values) = items.tail.map { item => + val indexAndValue = item.split(':') + val index = indexAndValue(0).toInt - 1 + val value = indexAndValue(1).toDouble + (index, value) + }.unzip + LabeledPoint(label, Vectors.sparse(d, indices.toArray, values.toArray)) + } + } + + // Convenient methods for calling from Java. + + /** + * Loads binary labeled data in the LIBSVM format into an RDD[LabeledPoint], + * with number of features determined automatically and the default number of partitions. + */ + def loadLibSVMData(sc: SparkContext, path: String): RDD[LabeledPoint] = + loadLibSVMData(sc, path, BinaryLabelParser, -1, sc.defaultMinSplits) + + /** + * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], + * with the given label parser, number of features determined automatically, + * and the default number of partitions. + */ + def loadLibSVMData( + sc: SparkContext, + path: String, + labelParser: LabelParser): RDD[LabeledPoint] = + loadLibSVMData(sc, path, labelParser, -1, sc.defaultMinSplits) + + /** + * Loads labeled data in the LIBSVM format into an RDD[LabeledPoint], + * with the given label parser, number of features specified explicitly, + * and the default number of partitions. + */ + def loadLibSVMData( + sc: SparkContext, + path: String, + labelParser: LabelParser, + numFeatures: Int): RDD[LabeledPoint] = + loadLibSVMData(sc, path, labelParser, numFeatures, sc.defaultMinSplits) + + /** + * :: Experimental :: * Load labeled data from a file. The data format used here is * , ... * where , are feature values in Double and is the corresponding label as Double. @@ -50,16 +132,18 @@ object MLUtils { * @return An RDD of LabeledPoint. Each labeled point has two elements: the first element is * the label, and the second element represents the feature values (an array of Double). */ + @Experimental def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = { sc.textFile(dir).map { line => val parts = line.split(',') val label = parts(0).toDouble - val features = parts(1).trim().split(' ').map(_.toDouble) + val features = Vectors.dense(parts(1).trim().split(' ').map(_.toDouble)) LabeledPoint(label, features) } } /** + * :: Experimental :: * Save labeled data to a file. The data format used here is * , ... * where , are feature values in Double and is the corresponding label as Double. @@ -67,55 +151,12 @@ object MLUtils { * @param data An RDD of LabeledPoints containing data to be saved. * @param dir Directory to save the data. */ + @Experimental def saveLabeledData(data: RDD[LabeledPoint], dir: String) { - val dataStr = data.map(x => x.label + "," + x.features.mkString(" ")) + val dataStr = data.map(x => x.label + "," + x.features.toArray.mkString(" ")) dataStr.saveAsTextFile(dir) } - /** - * Utility function to compute mean and standard deviation on a given dataset. - * - * @param data - input data set whose statistics are computed - * @param nfeatures - number of features - * @param nexamples - number of examples in input dataset - * - * @return (yMean, xColMean, xColSd) - Tuple consisting of - * yMean - mean of the labels - * xColMean - Row vector with mean for every column (or feature) of the input data - * xColSd - Row vector standard deviation for every column (or feature) of the input data. - */ - def computeStats(data: RDD[LabeledPoint], nfeatures: Int, nexamples: Long): - (Double, DoubleMatrix, DoubleMatrix) = { - val yMean: Double = data.map { labeledPoint => labeledPoint.label }.reduce(_ + _) / nexamples - - // NOTE: We shuffle X by column here to compute column sum and sum of squares. - val xColSumSq: RDD[(Int, (Double, Double))] = data.flatMap { labeledPoint => - val nCols = labeledPoint.features.length - // Traverse over every column and emit (col, value, value^2) - Iterator.tabulate(nCols) { i => - (i, (labeledPoint.features(i), labeledPoint.features(i)*labeledPoint.features(i))) - } - }.reduceByKey { case(x1, x2) => - (x1._1 + x2._1, x1._2 + x2._2) - } - val xColSumsMap = xColSumSq.collectAsMap() - - val xColMean = DoubleMatrix.zeros(nfeatures, 1) - val xColSd = DoubleMatrix.zeros(nfeatures, 1) - - // Compute mean and unbiased variance using column sums - var col = 0 - while (col < nfeatures) { - xColMean.put(col, xColSumsMap(col)._1 / nexamples) - val variance = - (xColSumsMap(col)._2 - (math.pow(xColSumsMap(col)._1, 2) / nexamples)) / nexamples - xColSd.put(col, math.sqrt(variance)) - col += 1 - } - - (yMean, xColMean, xColSd) - } - /** * Returns the squared Euclidean distance between two vectors. The following formula will be used * if it does not introduce too much numerical error: @@ -144,6 +185,18 @@ object MLUtils { val sumSquaredNorm = norm1 * norm1 + norm2 * norm2 val normDiff = norm1 - norm2 var sqDist = 0.0 + /* + * The relative error is + *
      +     * EPSILON * ( \|a\|_2^2 + \|b\\_2^2 + 2 |a^T b|) / ( \|a - b\|_2^2 ),
      +     * 
      + * which is bounded by + *
      +     * 2.0 * EPSILON * ( \|a\|_2^2 + \|b\|_2^2 ) / ( (\|a\|_2 - \|b\|_2)^2 ).
      +     * 
      + * The bound doesn't need the inner product, so we can use it as a sufficient condition to + * check quickly whether the inner product approach is accurate. + */ val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON) if (precisionBound1 < precision) { sqDist = sumSquaredNorm - 2.0 * v1.dot(v2) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala index c96c94f70eef7..ba8190b0e07e8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/SVMDataGenerator.scala @@ -21,14 +21,18 @@ import scala.util.Random import org.jblas.DoubleMatrix +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint /** + * :: DeveloperApi :: * Generate sample data used for SVM. This class generates uniform random values * for the features and adds Gaussian noise with weight 0.1 to generate labels. */ +@DeveloperApi object SVMDataGenerator { def main(args: Array[String]) { @@ -58,7 +62,7 @@ object SVMDataGenerator { } val yD = new DoubleMatrix(1, x.length, x: _*).dot(trueWeights) + rnd.nextGaussian() * 0.1 val y = if (yD < 0) 0.0 else 1.0 - LabeledPoint(y, x) + LabeledPoint(y, Vectors.dense(x)) } MLUtils.saveLabeledData(data, outputPath) diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java index 073ded6f36933..c80b1134ed1b2 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java @@ -19,6 +19,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; import org.junit.After; import org.junit.Assert; @@ -45,12 +46,12 @@ public void tearDown() { } private static final List POINTS = Arrays.asList( - new LabeledPoint(0, new double[] {1.0, 0.0, 0.0}), - new LabeledPoint(0, new double[] {2.0, 0.0, 0.0}), - new LabeledPoint(1, new double[] {0.0, 1.0, 0.0}), - new LabeledPoint(1, new double[] {0.0, 2.0, 0.0}), - new LabeledPoint(2, new double[] {0.0, 0.0, 1.0}), - new LabeledPoint(2, new double[] {0.0, 0.0, 2.0}) + new LabeledPoint(0, Vectors.dense(1.0, 0.0, 0.0)), + new LabeledPoint(0, Vectors.dense(2.0, 0.0, 0.0)), + new LabeledPoint(1, Vectors.dense(0.0, 1.0, 0.0)), + new LabeledPoint(1, Vectors.dense(0.0, 2.0, 0.0)), + new LabeledPoint(2, Vectors.dense(0.0, 0.0, 1.0)), + new LabeledPoint(2, Vectors.dense(0.0, 0.0, 2.0)) ); private int validatePrediction(List points, NaiveBayesModel model) { diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java index 117e5eaa8b78e..4701a5e545020 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.mllib.classification; - import java.io.Serializable; import java.util.List; @@ -28,7 +27,6 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; - import org.apache.spark.mllib.regression.LabeledPoint; public class JavaSVMSuite implements Serializable { @@ -94,5 +92,4 @@ public void runSVMUsingStaticMethods() { int numAccurate = validatePrediction(validationData, model); Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0); } - } diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java index 2c4d795f96e4e..c6d8425ffc38d 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java @@ -19,10 +19,10 @@ import java.io.Serializable; -import com.google.common.collect.Lists; - import scala.Tuple2; +import com.google.common.collect.Lists; + import org.junit.Test; import static org.junit.Assert.*; @@ -36,7 +36,7 @@ public void denseArrayConstruction() { @Test public void sparseArrayConstruction() { - Vector v = Vectors.sparse(3, Lists.newArrayList( + Vector v = Vectors.sparse(3, Lists.>newArrayList( new Tuple2(0, 2.0), new Tuple2(2, 3.0))); assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java index f44b25cd44d19..f725924a2d971 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java @@ -59,7 +59,7 @@ int validatePrediction(List validationData, LassoModel model) { @Test public void runLassoUsingConstructor() { int nPoints = 10000; - double A = 2.0; + double A = 0.0; double[] weights = {-1.5, 1.0e-2}; JavaRDD testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, @@ -80,7 +80,7 @@ public void runLassoUsingConstructor() { @Test public void runLassoUsingStaticMethods() { int nPoints = 10000; - double A = 2.0; + double A = 0.0; double[] weights = {-1.5, 1.0e-2}; JavaRDD testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A, diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java index 2fdd5fc8fdca6..03714ae7e4d00 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java @@ -55,30 +55,27 @@ public void tearDown() { return errorSum / validationData.size(); } - List generateRidgeData(int numPoints, int nfeatures, double eps) { + List generateRidgeData(int numPoints, int numFeatures, double std) { org.jblas.util.Random.seed(42); // Pick weights as random values distributed uniformly in [-0.5, 0.5] - DoubleMatrix w = DoubleMatrix.rand(nfeatures, 1).subi(0.5); - // Set first two weights to eps - w.put(0, 0, eps); - w.put(1, 0, eps); - return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, eps); + DoubleMatrix w = DoubleMatrix.rand(numFeatures, 1).subi(0.5); + return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, std); } @Test public void runRidgeRegressionUsingConstructor() { - int nexamples = 200; - int nfeatures = 20; - double eps = 10.0; - List data = generateRidgeData(2*nexamples, nfeatures, eps); + int numExamples = 50; + int numFeatures = 20; + List data = generateRidgeData(2*numExamples, numFeatures, 10.0); - JavaRDD testRDD = sc.parallelize(data.subList(0, nexamples)); - List validationData = data.subList(nexamples, 2*nexamples); + JavaRDD testRDD = sc.parallelize(data.subList(0, numExamples)); + List validationData = data.subList(numExamples, 2 * numExamples); RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD(); - ridgeSGDImpl.optimizer().setStepSize(1.0) - .setRegParam(0.0) - .setNumIterations(200); + ridgeSGDImpl.optimizer() + .setStepSize(1.0) + .setRegParam(0.0) + .setNumIterations(200); RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd()); double unRegularizedErr = predictionError(validationData, model); @@ -91,13 +88,12 @@ public void runRidgeRegressionUsingConstructor() { @Test public void runRidgeRegressionUsingStaticMethods() { - int nexamples = 200; - int nfeatures = 20; - double eps = 10.0; - List data = generateRidgeData(2*nexamples, nfeatures, eps); + int numExamples = 50; + int numFeatures = 20; + List data = generateRidgeData(2 * numExamples, numFeatures, 10.0); - JavaRDD testRDD = sc.parallelize(data.subList(0, nexamples)); - List validationData = data.subList(nexamples, 2*nexamples); + JavaRDD testRDD = sc.parallelize(data.subList(0, numExamples)); + List validationData = data.subList(numExamples, 2 * numExamples); RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0); double unRegularizedErr = predictionError(validationData, model); diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala index 05322b024d5f6..1e03c9df820b0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala @@ -20,11 +20,10 @@ package org.apache.spark.mllib.classification import scala.util.Random import scala.collection.JavaConversions._ -import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite import org.scalatest.matchers.ShouldMatchers -import org.apache.spark.SparkContext +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.LocalSparkContext @@ -61,7 +60,7 @@ object LogisticRegressionSuite { if (yVal > 0) 1 else 0 } - val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Array(x1(i)))) + val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Vectors.dense(Array(x1(i))))) testData } @@ -113,7 +112,7 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Shoul val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42) val initialB = -1.0 - val initialWeights = Array(initialB) + val initialWeights = Vectors.dense(initialB) val testRDD = sc.parallelize(testData, 2) testRDD.cache() diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 9dd6c79ee6ad8..516895d04222d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.mllib.classification import scala.util.Random -import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.LocalSparkContext @@ -54,7 +54,7 @@ object NaiveBayesSuite { if (rnd.nextDouble() < _theta(y)(j)) 1 else 0 } - LabeledPoint(y, xi) + LabeledPoint(y, Vectors.dense(xi)) } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala index bc7abb568a172..dfacbfeee6fb4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.mllib.classification import scala.util.Random import scala.collection.JavaConversions._ -import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite import org.jblas.DoubleMatrix @@ -28,6 +27,7 @@ import org.jblas.DoubleMatrix import org.apache.spark.SparkException import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.linalg.Vectors object SVMSuite { @@ -54,7 +54,7 @@ object SVMSuite { intercept + 0.01 * rnd.nextGaussian() if (yD < 0) 0.0 else 1.0 } - y.zip(x).map(p => LabeledPoint(p._1, p._2)) + y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2))) } } @@ -110,7 +110,7 @@ class SVMSuite extends FunSuite with LocalSparkContext { val initialB = -1.0 val initialC = -1.0 - val initialWeights = Array(initialB,initialC) + val initialWeights = Vectors.dense(initialB, initialC) val testRDD = sc.parallelize(testData, 2) testRDD.cache() @@ -150,10 +150,10 @@ class SVMSuite extends FunSuite with LocalSparkContext { } intercept[SparkException] { - val model = SVMWithSGD.train(testRDDInvalid, 100) + SVMWithSGD.train(testRDDInvalid, 100) } // Turning off data validation should not throw an exception - val noValidationModel = new SVMWithSGD().setValidateData(false).run(testRDDInvalid) + new SVMWithSGD().setValidateData(false).run(testRDDInvalid) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala new file mode 100644 index 0000000000000..1c9844f289fe0 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.LocalSparkContext + +class AreaUnderCurveSuite extends FunSuite with LocalSparkContext { + test("auc computation") { + val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0)) + val auc = 4.0 + assert(AreaUnderCurve.of(curve) === auc) + val rddCurve = sc.parallelize(curve, 2) + assert(AreaUnderCurve.of(rddCurve) == auc) + } + + test("auc of an empty curve") { + val curve = Seq.empty[(Double, Double)] + assert(AreaUnderCurve.of(curve) === 0.0) + val rddCurve = sc.parallelize(curve, 2) + assert(AreaUnderCurve.of(rddCurve) === 0.0) + } + + test("auc of a curve with a single point") { + val curve = Seq((1.0, 1.0)) + assert(AreaUnderCurve.of(curve) === 0.0) + val rddCurve = sc.parallelize(curve, 2) + assert(AreaUnderCurve.of(rddCurve) === 0.0) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala new file mode 100644 index 0000000000000..173fdaefab3da --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/binary/BinaryClassificationMetricsSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.evaluation.binary + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.evaluation.AreaUnderCurve + +class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext { + test("binary evaluation metrics") { + val scoreAndLabels = sc.parallelize( + Seq((0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)), 2) + val metrics = new BinaryClassificationMetrics(scoreAndLabels) + val threshold = Seq(0.8, 0.6, 0.4, 0.1) + val numTruePositives = Seq(1, 3, 3, 4) + val numFalsePositives = Seq(0, 1, 2, 3) + val numPositives = 4 + val numNegatives = 3 + val precision = numTruePositives.zip(numFalsePositives).map { case (t, f) => + t.toDouble / (t + f) + } + val recall = numTruePositives.map(t => t.toDouble / numPositives) + val fpr = numFalsePositives.map(f => f.toDouble / numNegatives) + val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0)) + val pr = recall.zip(precision) + val prCurve = Seq((0.0, 1.0)) ++ pr + val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) } + val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)} + assert(metrics.thresholds().collect().toSeq === threshold) + assert(metrics.roc().collect().toSeq === rocCurve) + assert(metrics.areaUnderROC() === AreaUnderCurve.of(rocCurve)) + assert(metrics.pr().collect().toSeq === prCurve) + assert(metrics.areaUnderPR() === AreaUnderCurve.of(prCurve)) + assert(metrics.fMeasureByThreshold().collect().toSeq === threshold.zip(f1)) + assert(metrics.fMeasureByThreshold(2.0).collect().toSeq === threshold.zip(f2)) + assert(metrics.precisionByThreshold().collect().toSeq === threshold.zip(precision)) + assert(metrics.recallByThreshold().collect().toSeq === threshold.zip(recall)) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala new file mode 100644 index 0000000000000..82d49c76ed02b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg + +import org.scalatest.FunSuite + +import breeze.linalg.{DenseMatrix => BDM} + +class BreezeMatrixConversionSuite extends FunSuite { + test("dense matrix to breeze") { + val mat = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) + val breeze = mat.toBreeze.asInstanceOf[BDM[Double]] + assert(breeze.rows === mat.numRows) + assert(breeze.cols === mat.numCols) + assert(breeze.data.eq(mat.asInstanceOf[DenseMatrix].values), "should not copy data") + } + + test("dense breeze matrix to matrix") { + val breeze = new BDM[Double](3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) + val mat = Matrices.fromBreeze(breeze).asInstanceOf[DenseMatrix] + assert(mat.numRows === breeze.rows) + assert(mat.numCols === breeze.cols) + assert(mat.values.eq(breeze.data), "should not copy data") + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala new file mode 100644 index 0000000000000..9c66b4db9f16b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg + +import org.scalatest.FunSuite + +class MatricesSuite extends FunSuite { + test("dense matrix construction") { + val m = 3 + val n = 2 + val values = Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0) + val mat = Matrices.dense(m, n, values).asInstanceOf[DenseMatrix] + assert(mat.numRows === m) + assert(mat.numCols === n) + assert(mat.values.eq(values), "should not copy data") + assert(mat.toArray.eq(values), "toArray should not copy data") + } + + test("dense matrix construction with wrong dimension") { + intercept[RuntimeException] { + Matrices.dense(3, 2, Array(0.0, 1.0, 2.0)) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/PCASuite.scala deleted file mode 100644 index 5e5086b1bf73e..0000000000000 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/PCASuite.scala +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.linalg - -import scala.util.Random - -import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite - -import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.RDD - -import org.apache.spark.mllib.util._ - -import org.jblas._ - -class PCASuite extends FunSuite with BeforeAndAfterAll { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } - - val EPSILON = 1e-3 - - // Return jblas matrix from sparse matrix RDD - def getDenseMatrix(matrix: SparseMatrix) : DoubleMatrix = { - val data = matrix.data - val ret = DoubleMatrix.zeros(matrix.m, matrix.n) - matrix.data.collect().map(x => ret.put(x.i, x.j, x.mval)) - ret - } - - def assertMatrixApproximatelyEquals(a: DoubleMatrix, b: DoubleMatrix) { - assert(a.rows == b.rows && a.columns == b.columns, - "dimension mismatch: $a.rows vs $b.rows and $a.columns vs $b.columns") - for (i <- 0 until a.columns) { - val aCol = a.getColumn(i) - val bCol = b.getColumn(i) - val diff = Math.min(aCol.sub(bCol).norm1, aCol.add(bCol).norm1) - assert(diff < EPSILON, "matrix mismatch: " + diff) - } - } - - test("full rank matrix pca") { - val m = 5 - val n = 3 - val dataArr = Array.tabulate(m,n){ (a, b) => - MatrixEntry(a, b, Math.sin(a + b + a * b)) }.flatten - val data = sc.makeRDD(dataArr, 3) - val a = LAUtils.sparseToTallSkinnyDense(SparseMatrix(data, m, n)) - - val realPCAArray = Array((0,0,-0.2579), (0,1,-0.6602), (0,2,0.7054), - (1,0,-0.1448), (1,1,0.7483), (1,2,0.6474), - (2,0,0.9553), (2,1,-0.0649), (2,2,0.2886)) - val realPCA = sc.makeRDD(realPCAArray.map(x => MatrixEntry(x._1, x._2, x._3)), 3) - - val coeffs = new DoubleMatrix(new PCA().setK(n).compute(a)) - - assertMatrixApproximatelyEquals(getDenseMatrix(SparseMatrix(realPCA,n,n)), coeffs) - } - - test("sparse matrix full rank matrix pca") { - val m = 5 - val n = 3 - // the entry that gets dropped is zero to test sparse support - val dataArr = Array.tabulate(m,n){ (a, b) => - MatrixEntry(a, b, Math.sin(a + b + a * b)) }.flatten.drop(1) - val data = sc.makeRDD(dataArr, 3) - val a = LAUtils.sparseToTallSkinnyDense(SparseMatrix(data, m, n)) - - val realPCAArray = Array((0,0,-0.2579), (0,1,-0.6602), (0,2,0.7054), - (1,0,-0.1448), (1,1,0.7483), (1,2,0.6474), - (2,0,0.9553), (2,1,-0.0649), (2,2,0.2886)) - val realPCA = sc.makeRDD(realPCAArray.map(x => MatrixEntry(x._1, x._2, x._3))) - - val coeffs = new DoubleMatrix(new PCA().setK(n).compute(a)) - - assertMatrixApproximatelyEquals(getDenseMatrix(SparseMatrix(realPCA,n,n)), coeffs) - } - - test("truncated matrix pca") { - val m = 5 - val n = 3 - val dataArr = Array.tabulate(m,n){ (a, b) => - MatrixEntry(a, b, Math.sin(a + b + a * b)) }.flatten - - val data = sc.makeRDD(dataArr, 3) - val a = LAUtils.sparseToTallSkinnyDense(SparseMatrix(data, m, n)) - - val realPCAArray = Array((0,0,-0.2579), (0,1,-0.6602), - (1,0,-0.1448), (1,1,0.7483), - (2,0,0.9553), (2,1,-0.0649)) - val realPCA = sc.makeRDD(realPCAArray.map(x => MatrixEntry(x._1, x._2, x._3))) - - val k = 2 - val coeffs = new DoubleMatrix(new PCA().setK(k).compute(a)) - - assertMatrixApproximatelyEquals(getDenseMatrix(SparseMatrix(realPCA,n,k)), coeffs) - } -} - - diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/SVDSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/SVDSuite.scala deleted file mode 100644 index 20e2b0f84be06..0000000000000 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/SVDSuite.scala +++ /dev/null @@ -1,194 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.linalg - -import scala.util.Random - -import org.scalatest.BeforeAndAfterAll -import org.scalatest.FunSuite - -import org.jblas.{DoubleMatrix, Singular, MatrixFunctions} - -import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.RDD - -import org.apache.spark.mllib.util._ - -import org.jblas._ - -class SVDSuite extends FunSuite with BeforeAndAfterAll { - @transient private var sc: SparkContext = _ - - override def beforeAll() { - sc = new SparkContext("local", "test") - } - - override def afterAll() { - sc.stop() - System.clearProperty("spark.driver.port") - } - - val EPSILON = 1e-4 - - // Return jblas matrix from sparse matrix RDD - def getDenseMatrix(matrix: SparseMatrix) : DoubleMatrix = { - val data = matrix.data - val m = matrix.m - val n = matrix.n - val ret = DoubleMatrix.zeros(m, n) - matrix.data.collect().map(x => ret.put(x.i, x.j, x.mval)) - ret - } - - def assertMatrixApproximatelyEquals(a: DoubleMatrix, b: DoubleMatrix) { - assert(a.rows == b.rows && a.columns == b.columns, - "dimension mismatch: $a.rows vs $b.rows and $a.columns vs $b.columns") - for (i <- 0 until a.columns) { - val aCol = a.getColumn(i) - val bCol = b.getColumn(i) - val diff = Math.min(aCol.sub(bCol).norm1, aCol.add(bCol).norm1) - assert(diff < EPSILON, "matrix mismatch: " + diff) - } - } - - test("full rank matrix svd") { - val m = 10 - val n = 3 - val datarr = Array.tabulate(m,n){ (a, b) => - MatrixEntry(a, b, (a + 2).toDouble * (b + 1) / (1 + a + b)) }.flatten - val data = sc.makeRDD(datarr, 3) - - val a = SparseMatrix(data, m, n) - - val decomposed = new SVD().setK(n).compute(a) - val u = decomposed.U - val v = decomposed.V - val s = decomposed.S - - val denseA = getDenseMatrix(a) - val svd = Singular.sparseSVD(denseA) - - val retu = getDenseMatrix(u) - val rets = getDenseMatrix(s) - val retv = getDenseMatrix(v) - - - // check individual decomposition - assertMatrixApproximatelyEquals(retu, svd(0)) - assertMatrixApproximatelyEquals(rets, DoubleMatrix.diag(svd(1))) - assertMatrixApproximatelyEquals(retv, svd(2)) - - // check multiplication guarantee - assertMatrixApproximatelyEquals(retu.mmul(rets).mmul(retv.transpose), denseA) - } - - test("dense full rank matrix svd") { - val m = 10 - val n = 3 - val datarr = Array.tabulate(m,n){ (a, b) => - MatrixEntry(a, b, (a + 2).toDouble * (b + 1) / (1 + a + b)) }.flatten - val data = sc.makeRDD(datarr, 3) - - val a = LAUtils.sparseToTallSkinnyDense(SparseMatrix(data, m, n)) - - val decomposed = new SVD().setK(n).setComputeU(true).compute(a) - val u = LAUtils.denseToSparse(decomposed.U) - val v = decomposed.V - val s = decomposed.S - - val denseA = getDenseMatrix(LAUtils.denseToSparse(a)) - val svd = Singular.sparseSVD(denseA) - - val retu = getDenseMatrix(u) - val rets = DoubleMatrix.diag(new DoubleMatrix(s)) - val retv = new DoubleMatrix(v) - - - // check individual decomposition - assertMatrixApproximatelyEquals(retu, svd(0)) - assertMatrixApproximatelyEquals(rets, DoubleMatrix.diag(svd(1))) - assertMatrixApproximatelyEquals(retv, svd(2)) - - // check multiplication guarantee - assertMatrixApproximatelyEquals(retu.mmul(rets).mmul(retv.transpose), denseA) - } - - test("rank one matrix svd") { - val m = 10 - val n = 3 - val data = sc.makeRDD(Array.tabulate(m, n){ (a,b) => - MatrixEntry(a, b, 1.0) }.flatten ) - val k = 1 - - val a = SparseMatrix(data, m, n) - - val decomposed = new SVD().setK(k).compute(a) - val u = decomposed.U - val s = decomposed.S - val v = decomposed.V - val retrank = s.data.collect().length - - assert(retrank == 1, "rank returned not one") - - val denseA = getDenseMatrix(a) - val svd = Singular.sparseSVD(denseA) - - val retu = getDenseMatrix(u) - val rets = getDenseMatrix(s) - val retv = getDenseMatrix(v) - - // check individual decomposition - assertMatrixApproximatelyEquals(retu, svd(0).getColumn(0)) - assertMatrixApproximatelyEquals(rets, DoubleMatrix.diag(svd(1).getRow(0))) - assertMatrixApproximatelyEquals(retv, svd(2).getColumn(0)) - - // check multiplication guarantee - assertMatrixApproximatelyEquals(retu.mmul(rets).mmul(retv.transpose), denseA) - } - - test("truncated with k") { - val m = 10 - val n = 3 - val data = sc.makeRDD(Array.tabulate(m,n){ (a, b) => - MatrixEntry(a, b, (a + 2).toDouble * (b + 1)/(1 + a + b)) }.flatten ) - val a = SparseMatrix(data, m, n) - - val k = 1 // only one svalue above this - - val decomposed = new SVD().setK(k).compute(a) - val u = decomposed.U - val s = decomposed.S - val v = decomposed.V - val retrank = s.data.collect().length - - val denseA = getDenseMatrix(a) - val svd = Singular.sparseSVD(denseA) - - val retu = getDenseMatrix(u) - val rets = getDenseMatrix(s) - val retv = getDenseMatrix(v) - - assert(retrank == 1, "rank returned not one") - - // check individual decomposition - assertMatrixApproximatelyEquals(retu, svd(0).getColumn(0)) - assertMatrixApproximatelyEquals(rets, DoubleMatrix.diag(svd(1).getRow(0))) - assertMatrixApproximatelyEquals(retv, svd(2).getColumn(0)) - } -} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala new file mode 100644 index 0000000000000..cd45438fb628f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg.distributed + +import org.scalatest.FunSuite + +import breeze.linalg.{DenseMatrix => BDM} + +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.linalg.Vectors + +class CoordinateMatrixSuite extends FunSuite with LocalSparkContext { + + val m = 5 + val n = 4 + var mat: CoordinateMatrix = _ + + override def beforeAll() { + super.beforeAll() + val entries = sc.parallelize(Seq( + (0, 0, 1.0), + (0, 1, 2.0), + (1, 1, 3.0), + (1, 2, 4.0), + (2, 2, 5.0), + (2, 3, 6.0), + (3, 0, 7.0), + (3, 3, 8.0), + (4, 1, 9.0)), 3).map { case (i, j, value) => + MatrixEntry(i, j, value) + } + mat = new CoordinateMatrix(entries) + } + + test("size") { + assert(mat.numRows() === m) + assert(mat.numCols() === n) + } + + test("empty entries") { + val entries = sc.parallelize(Seq[MatrixEntry](), 1) + val emptyMat = new CoordinateMatrix(entries) + intercept[RuntimeException] { + emptyMat.numCols() + } + intercept[RuntimeException] { + emptyMat.numRows() + } + } + + test("toBreeze") { + val expected = BDM( + (1.0, 2.0, 0.0, 0.0), + (0.0, 3.0, 4.0, 0.0), + (0.0, 0.0, 5.0, 6.0), + (7.0, 0.0, 0.0, 8.0), + (0.0, 9.0, 0.0, 0.0)) + assert(mat.toBreeze() === expected) + } + + test("toIndexedRowMatrix") { + val indexedRowMatrix = mat.toIndexedRowMatrix() + val expected = BDM( + (1.0, 2.0, 0.0, 0.0), + (0.0, 3.0, 4.0, 0.0), + (0.0, 0.0, 5.0, 6.0), + (7.0, 0.0, 0.0, 8.0), + (0.0, 9.0, 0.0, 0.0)) + assert(indexedRowMatrix.toBreeze() === expected) + } + + test("toRowMatrix") { + val rowMatrix = mat.toRowMatrix() + val rows = rowMatrix.rows.collect().toSet + val expected = Set( + Vectors.dense(1.0, 2.0, 0.0, 0.0), + Vectors.dense(0.0, 3.0, 4.0, 0.0), + Vectors.dense(0.0, 0.0, 5.0, 6.0), + Vectors.dense(7.0, 0.0, 0.0, 8.0), + Vectors.dense(0.0, 9.0, 0.0, 0.0)) + assert(rows === expected) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala new file mode 100644 index 0000000000000..f7c46f23b746d --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg.distributed + +import org.scalatest.FunSuite + +import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV} + +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.{Matrices, Vectors} + +class IndexedRowMatrixSuite extends FunSuite with LocalSparkContext { + + val m = 4 + val n = 3 + val data = Seq( + (0L, Vectors.dense(0.0, 1.0, 2.0)), + (1L, Vectors.dense(3.0, 4.0, 5.0)), + (3L, Vectors.dense(9.0, 0.0, 1.0)) + ).map(x => IndexedRow(x._1, x._2)) + var indexedRows: RDD[IndexedRow] = _ + + override def beforeAll() { + super.beforeAll() + indexedRows = sc.parallelize(data, 2) + } + + test("size") { + val mat1 = new IndexedRowMatrix(indexedRows) + assert(mat1.numRows() === m) + assert(mat1.numCols() === n) + + val mat2 = new IndexedRowMatrix(indexedRows, 5, 0) + assert(mat2.numRows() === 5) + assert(mat2.numCols() === n) + } + + test("empty rows") { + val rows = sc.parallelize(Seq[IndexedRow](), 1) + val mat = new IndexedRowMatrix(rows) + intercept[RuntimeException] { + mat.numRows() + } + intercept[RuntimeException] { + mat.numCols() + } + } + + test("toBreeze") { + val mat = new IndexedRowMatrix(indexedRows) + val expected = BDM( + (0.0, 1.0, 2.0), + (3.0, 4.0, 5.0), + (0.0, 0.0, 0.0), + (9.0, 0.0, 1.0)) + assert(mat.toBreeze() === expected) + } + + test("toRowMatrix") { + val idxRowMat = new IndexedRowMatrix(indexedRows) + val rowMat = idxRowMat.toRowMatrix() + assert(rowMat.numCols() === n) + assert(rowMat.numRows() === 3, "should drop empty rows") + assert(rowMat.rows.collect().toSeq === data.map(_.vector).toSeq) + } + + test("multiply a local matrix") { + val A = new IndexedRowMatrix(indexedRows) + val B = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) + val C = A.multiply(B) + val localA = A.toBreeze() + val localC = C.toBreeze() + val expected = localA * B.toBreeze.asInstanceOf[BDM[Double]] + assert(localC === expected) + } + + test("gram") { + val A = new IndexedRowMatrix(indexedRows) + val G = A.computeGramianMatrix() + val expected = BDM( + (90.0, 12.0, 24.0), + (12.0, 17.0, 22.0), + (24.0, 22.0, 30.0)) + assert(G.toBreeze === expected) + } + + test("svd") { + val A = new IndexedRowMatrix(indexedRows) + val svd = A.computeSVD(n, computeU = true) + assert(svd.U.isInstanceOf[IndexedRowMatrix]) + val localA = A.toBreeze() + val U = svd.U.toBreeze() + val s = svd.s.toBreeze.asInstanceOf[BDV[Double]] + val V = svd.V.toBreeze.asInstanceOf[BDM[Double]] + assert(closeToZero(U.t * U - BDM.eye[Double](n))) + assert(closeToZero(V.t * V - BDM.eye[Double](n))) + assert(closeToZero(U * brzDiag(s) * V.t - localA)) + } + + def closeToZero(G: BDM[Double]): Boolean = { + G.valuesIterator.map(math.abs).sum < 1e-6 + } +} + diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala new file mode 100644 index 0000000000000..c9f9acf4c1335 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg.distributed + +import org.scalatest.FunSuite + +import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd} + +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector} + +class RowMatrixSuite extends FunSuite with LocalSparkContext { + + val m = 4 + val n = 3 + val arr = Array(0.0, 3.0, 6.0, 9.0, 1.0, 4.0, 7.0, 0.0, 2.0, 5.0, 8.0, 1.0) + val denseData = Seq( + Vectors.dense(0.0, 1.0, 2.0), + Vectors.dense(3.0, 4.0, 5.0), + Vectors.dense(6.0, 7.0, 8.0), + Vectors.dense(9.0, 0.0, 1.0) + ) + val sparseData = Seq( + Vectors.sparse(3, Seq((1, 1.0), (2, 2.0))), + Vectors.sparse(3, Seq((0, 3.0), (1, 4.0), (2, 5.0))), + Vectors.sparse(3, Seq((0, 6.0), (1, 7.0), (2, 8.0))), + Vectors.sparse(3, Seq((0, 9.0), (2, 1.0))) + ) + + val principalComponents = BDM( + (0.0, 1.0, 0.0), + (math.sqrt(2.0) / 2.0, 0.0, math.sqrt(2.0) / 2.0), + (math.sqrt(2.0) / 2.0, 0.0, - math.sqrt(2.0) / 2.0)) + + var denseMat: RowMatrix = _ + var sparseMat: RowMatrix = _ + + override def beforeAll() { + super.beforeAll() + denseMat = new RowMatrix(sc.parallelize(denseData, 2)) + sparseMat = new RowMatrix(sc.parallelize(sparseData, 2)) + } + + test("size") { + assert(denseMat.numRows() === m) + assert(denseMat.numCols() === n) + assert(sparseMat.numRows() === m) + assert(sparseMat.numCols() === n) + } + + test("empty rows") { + val rows = sc.parallelize(Seq[Vector](), 1) + val emptyMat = new RowMatrix(rows) + intercept[RuntimeException] { + emptyMat.numCols() + } + intercept[RuntimeException] { + emptyMat.numRows() + } + } + + test("toBreeze") { + val expected = BDM( + (0.0, 1.0, 2.0), + (3.0, 4.0, 5.0), + (6.0, 7.0, 8.0), + (9.0, 0.0, 1.0)) + for (mat <- Seq(denseMat, sparseMat)) { + assert(mat.toBreeze() === expected) + } + } + + test("gram") { + val expected = + Matrices.dense(n, n, Array(126.0, 54.0, 72.0, 54.0, 66.0, 78.0, 72.0, 78.0, 94.0)) + for (mat <- Seq(denseMat, sparseMat)) { + val G = mat.computeGramianMatrix() + assert(G.toBreeze === expected.toBreeze) + } + } + + test("svd of a full-rank matrix") { + for (mat <- Seq(denseMat, sparseMat)) { + val localMat = mat.toBreeze() + val (localU, localSigma, localVt) = brzSvd(localMat) + val localV: BDM[Double] = localVt.t.toDenseMatrix + for (k <- 1 to n) { + val svd = mat.computeSVD(k, computeU = true) + val U = svd.U + val s = svd.s + val V = svd.V + assert(U.numRows() === m) + assert(U.numCols() === k) + assert(s.size === k) + assert(V.numRows === n) + assert(V.numCols === k) + assertColumnEqualUpToSign(U.toBreeze(), localU, k) + assertColumnEqualUpToSign(V.toBreeze.asInstanceOf[BDM[Double]], localV, k) + assert(closeToZero(s.toBreeze.asInstanceOf[BDV[Double]] - localSigma(0 until k))) + } + val svdWithoutU = mat.computeSVD(n) + assert(svdWithoutU.U === null) + } + } + + test("svd of a low-rank matrix") { + val rows = sc.parallelize(Array.fill(4)(Vectors.dense(1.0, 1.0)), 2) + val mat = new RowMatrix(rows, 4, 2) + val svd = mat.computeSVD(2, computeU = true) + assert(svd.s.size === 1, "should not return zero singular values") + assert(svd.U.numRows() === 4) + assert(svd.U.numCols() === 1) + assert(svd.V.numRows === 2) + assert(svd.V.numCols === 1) + } + + def closeToZero(G: BDM[Double]): Boolean = { + G.valuesIterator.map(math.abs).sum < 1e-6 + } + + def closeToZero(v: BDV[Double]): Boolean = { + brzNorm(v, 1.0) < 1e-6 + } + + def assertColumnEqualUpToSign(A: BDM[Double], B: BDM[Double], k: Int) { + assert(A.rows === B.rows) + for (j <- 0 until k) { + val aj = A(::, j) + val bj = B(::, j) + assert(closeToZero(aj - bj) || closeToZero(aj + bj), + s"The $j-th columns mismatch: $aj and $bj") + } + } + + test("pca") { + for (mat <- Seq(denseMat, sparseMat); k <- 1 to n) { + val pc = denseMat.computePrincipalComponents(k) + assert(pc.numRows === n) + assert(pc.numCols === k) + assertColumnEqualUpToSign(pc.toBreeze.asInstanceOf[BDM[Double]], principalComponents, k) + } + } + + test("multiply a local matrix") { + val B = Matrices.dense(n, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) + for (mat <- Seq(denseMat, sparseMat)) { + val AB = mat.multiply(B) + assert(AB.numRows() === m) + assert(AB.numCols() === 2) + assert(AB.rows.collect().toSeq === Seq( + Vectors.dense(5.0, 14.0), + Vectors.dense(14.0, 50.0), + Vectors.dense(23.0, 86.0), + Vectors.dense(2.0, 32.0) + )) + } + } + + test("compute column summary statistics") { + for (mat <- Seq(denseMat, sparseMat)) { + val summary = mat.computeColumnSummaryStatistics() + // Run twice to make sure no internal states are changed. + for (k <- 0 to 1) { + assert(summary.mean === Vectors.dense(4.5, 3.0, 4.0), "mean mismatch") + assert(summary.variance === Vectors.dense(15.0, 10.0, 10.0), "variance mismatch") + assert(summary.count === m, "count mismatch.") + assert(summary.numNonzeros === Vectors.dense(3.0, 3.0, 4.0), "nnz mismatch") + assert(summary.max === Vectors.dense(9.0, 7.0, 8.0), "max mismatch") + assert(summary.min === Vectors.dense(0.0, 0.0, 1.0), "column mismatch.") + } + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index 631d0e2ad9cdb..c4b433499a091 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -20,13 +20,12 @@ package org.apache.spark.mllib.optimization import scala.util.Random import scala.collection.JavaConversions._ -import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite import org.scalatest.matchers.ShouldMatchers -import org.apache.spark.SparkContext import org.apache.spark.mllib.regression._ import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.linalg.Vectors object GradientDescentSuite { @@ -58,8 +57,7 @@ object GradientDescentSuite { if (yVal > 0) 1 else 0 } - val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Array(x1(i)))) - testData + (0 until nPoints).map(i => LabeledPoint(y(i), Vectors.dense(x1(i)))) } } @@ -83,11 +81,11 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMa // Add a extra variable consisting of all 1.0's for the intercept. val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42) val data = testData.map { case LabeledPoint(label, features) => - label -> Array(1.0, features: _*) + label -> Vectors.dense(1.0, features.toArray: _*) } val dataRDD = sc.parallelize(data, 2).cache() - val initialWeightsWithIntercept = Array(1.0, initialWeights: _*) + val initialWeightsWithIntercept = Vectors.dense(1.0, initialWeights: _*) val (_, loss) = GradientDescent.runMiniBatchSGD( dataRDD, @@ -113,13 +111,13 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMa // Add a extra variable consisting of all 1.0's for the intercept. val testData = GradientDescentSuite.generateGDInput(2.0, -1.5, 10000, 42) val data = testData.map { case LabeledPoint(label, features) => - label -> Array(1.0, features: _*) + label -> Vectors.dense(1.0, features.toArray: _*) } val dataRDD = sc.parallelize(data, 2).cache() // Prepare non-zero weights - val initialWeightsWithIntercept = Array(1.0, 0.5) + val initialWeightsWithIntercept = Vectors.dense(1.0, 0.5) val regParam0 = 0 val (newWeights0, loss0) = GradientDescent.runMiniBatchSGD( diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala new file mode 100644 index 0000000000000..3f3b10dfff35e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.rdd + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.mllib.rdd.RDDFunctions._ + +class RDDFunctionsSuite extends FunSuite with LocalSparkContext { + + test("sliding") { + val data = 0 until 6 + for (numPartitions <- 1 to 8) { + val rdd = sc.parallelize(data, numPartitions) + for (windowSize <- 1 to 6) { + val sliding = rdd.sliding(windowSize).collect().map(_.toList).toList + val expected = data.sliding(windowSize).map(_.toList).toList + assert(sliding === expected) + } + assert(rdd.sliding(7).collect().isEmpty, + "Should return an empty RDD if the window size is greater than the number of items.") + } + } + + test("sliding with empty partitions") { + val data = Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4), Seq.empty[Int], Seq(5, 6, 7)) + val rdd = sc.parallelize(data, data.length).flatMap(s => s) + assert(rdd.partitions.size === data.length) + val sliding = rdd.sliding(3) + val expected = data.flatMap(x => x).sliding(3).toList + assert(sliding.collect().toList === expected) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index 2cebac943e15f..6aad9eb84e13c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.regression import org.scalatest.FunSuite +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} class LassoSuite extends FunSuite with LocalSparkContext { @@ -33,29 +34,33 @@ class LassoSuite extends FunSuite with LocalSparkContext { } test("Lasso local random SGD") { - val nPoints = 10000 + val nPoints = 1000 val A = 2.0 val B = -1.5 val C = 1.0e-2 - val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 42) - - val testRDD = sc.parallelize(testData, 2) - testRDD.cache() + val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B, C), nPoints, 42) + .map { case LabeledPoint(label, features) => + LabeledPoint(label, Vectors.dense(1.0 +: features.toArray)) + } + val testRDD = sc.parallelize(testData, 2).cache() val ls = new LassoWithSGD() - ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(20) + ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40) val model = ls.run(testRDD) - val weight0 = model.weights(0) val weight1 = model.weights(1) - assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") - assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") - assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]") + val weight2 = model.weights(2) + assert(weight0 >= 1.9 && weight0 <= 2.1, weight0 + " not in [1.9, 2.1]") + assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]") + assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]") val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17) + .map { case LabeledPoint(label, features) => + LabeledPoint(label, Vectors.dense(1.0 +: features.toArray)) + } val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. @@ -66,33 +71,39 @@ class LassoSuite extends FunSuite with LocalSparkContext { } test("Lasso local random SGD with initial weights") { - val nPoints = 10000 + val nPoints = 1000 val A = 2.0 val B = -1.5 val C = 1.0e-2 - val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 42) + val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B, C), nPoints, 42) + .map { case LabeledPoint(label, features) => + LabeledPoint(label, Vectors.dense(1.0 +: features.toArray)) + } + val initialA = -1.0 val initialB = -1.0 val initialC = -1.0 - val initialWeights = Array(initialB,initialC) + val initialWeights = Vectors.dense(initialA, initialB, initialC) - val testRDD = sc.parallelize(testData, 2) - testRDD.cache() + val testRDD = sc.parallelize(testData, 2).cache() val ls = new LassoWithSGD() - ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(20) + ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40) val model = ls.run(testRDD, initialWeights) - val weight0 = model.weights(0) val weight1 = model.weights(1) - assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") - assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") - assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]") + val weight2 = model.weights(2) + assert(weight0 >= 1.9 && weight0 <= 2.1, weight0 + " not in [1.9, 2.1]") + assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]") + assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]") val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17) + .map { case LabeledPoint(label, features) => + LabeledPoint(label, Vectors.dense(1.0 +: features.toArray)) + } val validationRDD = sc.parallelize(validationData,2) // Test prediction on RDD. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index 5d251bcbf35db..2f7d30708ce17 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.mllib.regression import org.scalatest.FunSuite +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} class LinearRegressionSuite extends FunSuite with LocalSparkContext { @@ -40,11 +41,12 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext { linReg.optimizer.setNumIterations(1000).setStepSize(1.0) val model = linReg.run(testRDD) - assert(model.intercept >= 2.5 && model.intercept <= 3.5) - assert(model.weights.length === 2) - assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0) - assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0) + + val weights = model.weights + assert(weights.size === 2) + assert(weights(0) >= 9.0 && weights(0) <= 11.0) + assert(weights(1) >= 9.0 && weights(1) <= 11.0) val validationData = LinearDataGenerator.generateLinearInput( 3.0, Array(10.0, 10.0), 100, 17) @@ -67,9 +69,11 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext { val model = linReg.run(testRDD) assert(model.intercept === 0.0) - assert(model.weights.length === 2) - assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0) - assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0) + + val weights = model.weights + assert(weights.size === 2) + assert(weights(0) >= 9.0 && weights(0) <= 11.0) + assert(weights(1) >= 9.0 && weights(1) <= 11.0) val validationData = LinearDataGenerator.generateLinearInput( 0.0, Array(10.0, 10.0), 100, 17) @@ -81,4 +85,40 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext { // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + // Test if we can correctly learn Y = 10*X1 + 10*X10000 + test("sparse linear regression without intercept") { + val denseRDD = sc.parallelize( + LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42), 2) + val sparseRDD = denseRDD.map { case LabeledPoint(label, v) => + val sv = Vectors.sparse(10000, Seq((0, v(0)), (9999, v(1)))) + LabeledPoint(label, sv) + }.cache() + val linReg = new LinearRegressionWithSGD().setIntercept(false) + linReg.optimizer.setNumIterations(1000).setStepSize(1.0) + + val model = linReg.run(sparseRDD) + + assert(model.intercept === 0.0) + + val weights = model.weights + assert(weights.size === 10000) + assert(weights(0) >= 9.0 && weights(0) <= 11.0) + assert(weights(9999) >= 9.0 && weights(9999) <= 11.0) + + val validationData = LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 17) + val sparseValidationData = validationData.map { case LabeledPoint(label, v) => + val sv = Vectors.sparse(10000, Seq((0, v(0)), (9999, v(1)))) + LabeledPoint(label, sv) + } + val sparseValidationRDD = sc.parallelize(sparseValidationData, 2) + + // Test prediction on RDD. + validatePrediction( + model.predict(sparseValidationRDD.map(_.features)).collect(), sparseValidationData) + + // Test prediction on Array. + validatePrediction( + sparseValidationData.map(row => model.predict(row.features)), sparseValidationData) + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index b2044ed0d8066..f66fc6ea6c1ec 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.mllib.regression -import org.jblas.DoubleMatrix import org.scalatest.FunSuite +import org.jblas.DoubleMatrix + import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} class RidgeRegressionSuite extends FunSuite with LocalSparkContext { @@ -30,22 +31,22 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext { }.reduceLeft(_ + _) / predictions.size } - test("regularization with skewed weights") { - val nexamples = 200 - val nfeatures = 20 - val eps = 10 + test("ridge regression can help avoid overfitting") { + + // For small number of examples and large variance of error distribution, + // ridge regression should give smaller generalization error that linear regression. + + val numExamples = 50 + val numFeatures = 20 org.jblas.util.Random.seed(42) // Pick weights as random values distributed uniformly in [-0.5, 0.5] - val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5) - // Set first two weights to eps - w.put(0, 0, eps) - w.put(1, 0, eps) + val w = DoubleMatrix.rand(numFeatures, 1).subi(0.5) // Use half of data for training and other half for validation - val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, 2*nexamples, 42, eps) - val testData = data.take(nexamples) - val validationData = data.takeRight(nexamples) + val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, 2 * numExamples, 42, 10.0) + val testData = data.take(numExamples) + val validationData = data.takeRight(numExamples) val testRDD = sc.parallelize(testData, 2).cache() val validationRDD = sc.parallelize(validationData, 2).cache() @@ -67,7 +68,7 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext { val ridgeErr = predictionError( ridgeModel.predict(validationRDD.map(_.features)).collect(), validationData) - // Ridge CV-error should be lower than linear regression + // Ridge validation error should be lower than linear regression. assert(ridgeErr < linearErr, "ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")") } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala new file mode 100644 index 0000000000000..350130c914f26 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -0,0 +1,426 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} +import org.apache.spark.mllib.tree.model.Filter +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.mllib.linalg.Vectors + +class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll { + + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + override def afterAll() { + sc.stop() + System.clearProperty("spark.driver.port") + } + + test("split and bin calculation") { + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Gini, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(bins.length === 2) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + } + + test("split and bin calculation for categorical variables") { + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(bins.length === 2) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + // Check splits. + + assert(splits(0)(0).feature === 0) + assert(splits(0)(0).threshold === Double.MinValue) + assert(splits(0)(0).featureType === Categorical) + assert(splits(0)(0).categories.length === 1) + assert(splits(0)(0).categories.contains(1.0)) + + assert(splits(0)(1).feature === 0) + assert(splits(0)(1).threshold === Double.MinValue) + assert(splits(0)(1).featureType === Categorical) + assert(splits(0)(1).categories.length === 2) + assert(splits(0)(1).categories.contains(1.0)) + assert(splits(0)(1).categories.contains(0.0)) + + assert(splits(0)(2) === null) + + assert(splits(1)(0).feature === 1) + assert(splits(1)(0).threshold === Double.MinValue) + assert(splits(1)(0).featureType === Categorical) + assert(splits(1)(0).categories.length === 1) + assert(splits(1)(0).categories.contains(0.0)) + + assert(splits(1)(1).feature === 1) + assert(splits(1)(1).threshold === Double.MinValue) + assert(splits(1)(1).featureType === Categorical) + assert(splits(1)(1).categories.length === 2) + assert(splits(1)(1).categories.contains(1.0)) + assert(splits(1)(1).categories.contains(0.0)) + + assert(splits(1)(2) === null) + + // Check bins. + + assert(bins(0)(0).category === 1.0) + assert(bins(0)(0).lowSplit.categories.length === 0) + assert(bins(0)(0).highSplit.categories.length === 1) + assert(bins(0)(0).highSplit.categories.contains(1.0)) + + assert(bins(0)(1).category === 0.0) + assert(bins(0)(1).lowSplit.categories.length === 1) + assert(bins(0)(1).lowSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.length === 2) + assert(bins(0)(1).highSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.contains(0.0)) + + assert(bins(0)(2) === null) + + assert(bins(1)(0).category === 0.0) + assert(bins(1)(0).lowSplit.categories.length === 0) + assert(bins(1)(0).highSplit.categories.length === 1) + assert(bins(1)(0).highSplit.categories.contains(0.0)) + + assert(bins(1)(1).category === 1.0) + assert(bins(1)(1).lowSplit.categories.length === 1) + assert(bins(1)(1).lowSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.length === 2) + assert(bins(1)(1).highSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.contains(1.0)) + + assert(bins(1)(2) === null) + } + + test("split and bin calculations for categorical variables with no sample for one category") { + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + + // Check splits. + + assert(splits(0)(0).feature === 0) + assert(splits(0)(0).threshold === Double.MinValue) + assert(splits(0)(0).featureType === Categorical) + assert(splits(0)(0).categories.length === 1) + assert(splits(0)(0).categories.contains(1.0)) + + assert(splits(0)(1).feature === 0) + assert(splits(0)(1).threshold === Double.MinValue) + assert(splits(0)(1).featureType === Categorical) + assert(splits(0)(1).categories.length === 2) + assert(splits(0)(1).categories.contains(1.0)) + assert(splits(0)(1).categories.contains(0.0)) + + assert(splits(0)(2).feature === 0) + assert(splits(0)(2).threshold === Double.MinValue) + assert(splits(0)(2).featureType === Categorical) + assert(splits(0)(2).categories.length === 3) + assert(splits(0)(2).categories.contains(1.0)) + assert(splits(0)(2).categories.contains(0.0)) + assert(splits(0)(2).categories.contains(2.0)) + + assert(splits(0)(3) === null) + + assert(splits(1)(0).feature === 1) + assert(splits(1)(0).threshold === Double.MinValue) + assert(splits(1)(0).featureType === Categorical) + assert(splits(1)(0).categories.length === 1) + assert(splits(1)(0).categories.contains(0.0)) + + assert(splits(1)(1).feature === 1) + assert(splits(1)(1).threshold === Double.MinValue) + assert(splits(1)(1).featureType === Categorical) + assert(splits(1)(1).categories.length === 2) + assert(splits(1)(1).categories.contains(1.0)) + assert(splits(1)(1).categories.contains(0.0)) + + assert(splits(1)(2).feature === 1) + assert(splits(1)(2).threshold === Double.MinValue) + assert(splits(1)(2).featureType === Categorical) + assert(splits(1)(2).categories.length === 3) + assert(splits(1)(2).categories.contains(1.0)) + assert(splits(1)(2).categories.contains(0.0)) + assert(splits(1)(2).categories.contains(2.0)) + + assert(splits(1)(3) === null) + + // Check bins. + + assert(bins(0)(0).category === 1.0) + assert(bins(0)(0).lowSplit.categories.length === 0) + assert(bins(0)(0).highSplit.categories.length === 1) + assert(bins(0)(0).highSplit.categories.contains(1.0)) + + assert(bins(0)(1).category === 0.0) + assert(bins(0)(1).lowSplit.categories.length === 1) + assert(bins(0)(1).lowSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.length === 2) + assert(bins(0)(1).highSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.contains(0.0)) + + assert(bins(0)(2).category === 2.0) + assert(bins(0)(2).lowSplit.categories.length === 2) + assert(bins(0)(2).lowSplit.categories.contains(1.0)) + assert(bins(0)(2).lowSplit.categories.contains(0.0)) + assert(bins(0)(2).highSplit.categories.length === 3) + assert(bins(0)(2).highSplit.categories.contains(1.0)) + assert(bins(0)(2).highSplit.categories.contains(0.0)) + assert(bins(0)(2).highSplit.categories.contains(2.0)) + + assert(bins(0)(3) === null) + + assert(bins(1)(0).category === 0.0) + assert(bins(1)(0).lowSplit.categories.length === 0) + assert(bins(1)(0).highSplit.categories.length === 1) + assert(bins(1)(0).highSplit.categories.contains(0.0)) + + assert(bins(1)(1).category === 1.0) + assert(bins(1)(1).lowSplit.categories.length === 1) + assert(bins(1)(1).lowSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.length === 2) + assert(bins(1)(1).highSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.contains(1.0)) + + assert(bins(1)(2).category === 2.0) + assert(bins(1)(2).lowSplit.categories.length === 2) + assert(bins(1)(2).lowSplit.categories.contains(0.0)) + assert(bins(1)(2).lowSplit.categories.contains(1.0)) + assert(bins(1)(2).highSplit.categories.length === 3) + assert(bins(1)(2).highSplit.categories.contains(0.0)) + assert(bins(1)(2).highSplit.categories.contains(1.0)) + assert(bins(1)(2).highSplit.categories.contains(2.0)) + + assert(bins(1)(3) === null) + } + + test("classification stump with all categorical variables") { + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + Array[List[Filter]](), splits, bins) + + val split = bestSplits(0)._1 + assert(split.categories.length === 1) + assert(split.categories.contains(1.0)) + assert(split.featureType === Categorical) + assert(split.threshold === Double.MinValue) + + val stats = bestSplits(0)._2 + assert(stats.gain > 0) + assert(stats.predict > 0.4) + assert(stats.predict < 0.5) + assert(stats.impurity > 0.2) + } + + test("regression stump with all categorical variables") { + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Regression, + Variance, + maxDepth = 3, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy) + val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + Array[List[Filter]](), splits, bins) + + val split = bestSplits(0)._1 + assert(split.categories.length === 1) + assert(split.categories.contains(1.0)) + assert(split.featureType === Categorical) + assert(split.threshold === Double.MinValue) + + val stats = bestSplits(0)._2 + assert(stats.gain > 0) + assert(stats.predict > 0.4) + assert(stats.predict < 0.5) + assert(stats.impurity > 0.2) + } + + test("stump with fixed label 0 for Gini") { + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Gini, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0, + Array[List[Filter]](), splits, bins) + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) + } + + test("stump with fixed label 1 for Gini") { + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Gini, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + Array[List[Filter]](), splits, bins) + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) + assert(bestSplits(0)._2.predict === 1) + } + + test("stump with fixed label 0 for Entropy") { + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Entropy, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + Array[List[Filter]](), splits, bins) + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) + assert(bestSplits(0)._2.predict === 0) + } + + test("stump with fixed label 1 for Entropy") { + val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy(Classification, Entropy, 3, 100) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + assert(splits.length === 2) + assert(splits(0).length === 99) + assert(bins.length === 2) + assert(bins(0).length === 100) + assert(splits(0).length === 99) + assert(bins(0).length === 100) + + val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0, + Array[List[Filter]](), splits, bins) + assert(bestSplits.length === 1) + assert(bestSplits(0)._1.feature === 0) + assert(bestSplits(0)._1.threshold === 10) + assert(bestSplits(0)._2.gain === 0) + assert(bestSplits(0)._2.leftImpurity === 0) + assert(bestSplits(0)._2.rightImpurity === 0) + assert(bestSplits(0)._2.predict === 1) + } +} + +object DecisionTreeSuite { + + def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) + for (i <- 0 until 1000){ + val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i)) + arr(i) = lp + } + arr + } + + def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) + for (i <- 0 until 1000){ + val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i)) + arr(i) = lp + } + arr + } + + def generateCategoricalDataPoints(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](1000) + for (i <- 0 until 1000){ + if (i < 600){ + arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)) + } else { + arr(i) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0)) + } + } + arr + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LabelParsersSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LabelParsersSuite.scala new file mode 100644 index 0000000000000..ac85677f2f014 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LabelParsersSuite.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.util + +import org.scalatest.FunSuite + +class LabelParsersSuite extends FunSuite { + test("binary label parser") { + for (parser <- Seq(BinaryLabelParser, BinaryLabelParser.getInstance())) { + assert(parser.parse("+1") === 1.0) + assert(parser.parse("1") === 1.0) + assert(parser.parse("0") === 0.0) + assert(parser.parse("-1") === 0.0) + } + } + + test("multiclass label parser") { + for (parser <- Seq(MulticlassLabelParser, MulticlassLabelParser.getInstance())) { + assert(parser.parse("0") == 0.0) + assert(parser.parse("+1") === 1.0) + assert(parser.parse("1") === 1.0) + assert(parser.parse("2") === 2.0) + assert(parser.parse("3") === 3.0) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala index 60f053b381305..812a8434784be 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala @@ -17,14 +17,19 @@ package org.apache.spark.mllib.util +import java.io.File + import org.scalatest.FunSuite import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => breezeNorm, squaredDistance => breezeSquaredDistance} +import com.google.common.base.Charsets +import com.google.common.io.Files +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLUtils._ -class MLUtilsSuite extends FunSuite { +class MLUtilsSuite extends FunSuite with LocalSparkContext { test("epsilon computation") { assert(1.0 + EPSILON > 1.0, s"EPSILON is too small: $EPSILON.") @@ -49,4 +54,43 @@ class MLUtilsSuite extends FunSuite { assert((fastSquaredDist2 - squaredDist) <= precision * squaredDist, s"failed with m = $m") } } + + test("loadLibSVMData") { + val lines = + """ + |+1 1:1.0 3:2.0 5:3.0 + |-1 + |-1 2:4.0 4:5.0 6:6.0 + """.stripMargin + val tempDir = Files.createTempDir() + val file = new File(tempDir.getPath, "part-00000") + Files.write(lines, file, Charsets.US_ASCII) + val path = tempDir.toURI.toString + + val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, BinaryLabelParser, 6).collect() + val pointsWithoutNumFeatures = MLUtils.loadLibSVMData(sc, path).collect() + + for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) { + assert(points.length === 3) + assert(points(0).label === 1.0) + assert(points(0).features === Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0)))) + assert(points(1).label == 0.0) + assert(points(1).features == Vectors.sparse(6, Seq())) + assert(points(2).label === 0.0) + assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0)))) + } + + val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MulticlassLabelParser).collect() + assert(multiclassPoints.length === 3) + assert(multiclassPoints(0).label === 1.0) + assert(multiclassPoints(1).label === -1.0) + assert(multiclassPoints(2).label === -1.0) + + try { + file.delete() + tempDir.delete() + } catch { + case t: Throwable => + } + } } diff --git a/pom.xml b/pom.xml index 09a449d81453f..5f66cbe768592 100644 --- a/pom.xml +++ b/pom.xml @@ -21,7 +21,7 @@ org.apache apache - 13 + 14 org.apache.spark spark-parent @@ -54,11 +54,11 @@ JIRA - https://spark-project.atlassian.net/browse/SPARK + https://issues.apache.org/jira/browse/SPARK - 3.0.0 + 3.0.4 @@ -110,7 +110,7 @@ 1.6 - 2.10.3 + 2.10.4 2.10 0.13.0 org.spark-project.akka @@ -123,6 +123,10 @@ 0.94.6 0.12.0 1.3.2 + 1.2.3 + 8.1.14.v20131031 + 0.3.1 + 3.0.0 64m 512m @@ -192,22 +196,22 @@ org.eclipse.jetty jetty-util - 8.1.14.v20131031 + ${jetty.version} org.eclipse.jetty jetty-security - 8.1.14.v20131031 + ${jetty.version} org.eclipse.jetty jetty-plus - 8.1.14.v20131031 + ${jetty.version} org.eclipse.jetty jetty-server - 8.1.14.v20131031 + ${jetty.version} com.google.guava @@ -273,7 +277,7 @@ com.twitter chill_${scala.binary.version} - 0.3.1 + ${chill.version} org.ow2.asm @@ -288,7 +292,7 @@ com.twitter chill-java - 0.3.1 + ${chill.version} org.ow2.asm @@ -344,11 +348,6 @@ - - it.unimi.dsi - fastutil - 6.4.4 - colt colt @@ -373,14 +372,13 @@ org.apache.derby derby 10.4.2.0 - test net.liftweb lift-json_${scala.binary.version} 2.5.1 @@ -392,27 +390,27 @@ com.codahale.metrics metrics-core - 3.0.0 + ${codahale.metrics.version} com.codahale.metrics metrics-jvm - 3.0.0 + ${codahale.metrics.version} com.codahale.metrics metrics-json - 3.0.0 + ${codahale.metrics.version} com.codahale.metrics metrics-ganglia - 3.0.0 + ${codahale.metrics.version} com.codahale.metrics metrics-graphite - 3.0.0 + ${codahale.metrics.version} org.scala-lang @@ -576,6 +574,12 @@ + + + org.codehaus.jackson + jackson-mapper-asl + 1.8.8 + @@ -585,7 +589,7 @@ org.apache.maven.plugins maven-enforcer-plugin - 1.1.1 + 1.3.1 enforce-versions @@ -595,7 +599,7 @@ - 3.0.0 + 3.0.4 ${java.version} @@ -608,12 +612,12 @@ org.codehaus.mojo build-helper-maven-plugin - 1.7 + 1.8 net.alchim31.maven scala-maven-plugin - 3.1.5 + 3.1.6 scala-compile-first @@ -674,7 +678,7 @@ org.apache.maven.plugins maven-surefire-plugin - 2.12.4 + 2.17 true @@ -713,7 +717,7 @@ org.apache.maven.plugins maven-shade-plugin - 2.0 + 2.2 org.apache.maven.plugins @@ -810,7 +814,6 @@ org.apache.maven.plugins maven-jar-plugin - 2.4 diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index e7c9c47c960fa..9cb31d70444ff 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -58,17 +58,25 @@ object MimaBuild { SparkBuild.SPARK_VERSION match { case v if v.startsWith("1.0") => Seq( - excludePackage("org.apache.spark.api.java"), - excludePackage("org.apache.spark.streaming.api.java"), - excludePackage("org.apache.spark.mllib") - ) ++ - excludeSparkClass("rdd.ClassTags") ++ - excludeSparkClass("util.XORShiftRandom") ++ - excludeSparkClass("mllib.recommendation.MFDataGenerator") ++ - excludeSparkClass("mllib.optimization.SquaredGradient") ++ - excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++ - excludeSparkClass("mllib.regression.LassoWithSGD") ++ - excludeSparkClass("mllib.regression.LinearRegressionWithSGD") + excludePackage("org.apache.spark.api.java"), + excludePackage("org.apache.spark.streaming.api.java"), + excludePackage("org.apache.spark.streaming.scheduler"), + excludePackage("org.apache.spark.mllib") + ) ++ + excludeSparkClass("rdd.ClassTags") ++ + excludeSparkClass("util.XORShiftRandom") ++ + excludeSparkClass("mllib.recommendation.MFDataGenerator") ++ + excludeSparkClass("mllib.optimization.SquaredGradient") ++ + excludeSparkClass("mllib.regression.RidgeRegressionWithSGD") ++ + excludeSparkClass("mllib.regression.LassoWithSGD") ++ + excludeSparkClass("mllib.regression.LinearRegressionWithSGD") ++ + excludeSparkClass("streaming.dstream.NetworkReceiver") ++ + excludeSparkClass("streaming.dstream.NetworkReceiver#NetworkReceiverActor") ++ + excludeSparkClass("streaming.dstream.NetworkReceiver#BlockGenerator") ++ + excludeSparkClass("streaming.dstream.NetworkReceiver#BlockGenerator#Block") ++ + excludeSparkClass("streaming.dstream.ReportError") ++ + excludeSparkClass("streaming.dstream.ReportBlock") ++ + excludeSparkClass("streaming.dstream.DStream") case _ => Seq() } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 7457ff456ade4..a6058bba3d211 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -30,7 +30,7 @@ import scala.collection.JavaConversions._ // import com.jsuereth.pgp.sbtplugin.PgpKeys._ object SparkBuild extends Build { - val SPARK_VERSION = "1.0.0-SNAPSHOT" + val SPARK_VERSION = "1.0.0-SNAPSHOT" // Hadoop version to build against. For example, "1.0.4" for Apache releases, or // "2.0.0-mr1-cdh4.2.0" for Cloudera Hadoop. Note that these variables can be set @@ -43,6 +43,8 @@ object SparkBuild extends Build { val DEFAULT_YARN = false + val DEFAULT_HIVE = false + // HBase version; set as appropriate. val HBASE_VERSION = "0.94.6" @@ -67,15 +69,17 @@ object SparkBuild extends Build { lazy val sql = Project("sql", file("sql/core"), settings = sqlCoreSettings) dependsOn(core, catalyst) - // Since hive is its own assembly, it depends on all of the modules. - lazy val hive = Project("hive", file("sql/hive"), settings = hiveSettings) dependsOn(sql, graphx, bagel, mllib, streaming, repl) + lazy val hive = Project("hive", file("sql/hive"), settings = hiveSettings) dependsOn(sql) + + lazy val maybeHive: Seq[ClasspathDependency] = if (isHiveEnabled) Seq(hive) else Seq() + lazy val maybeHiveRef: Seq[ProjectReference] = if (isHiveEnabled) Seq(hive) else Seq() lazy val streaming = Project("streaming", file("streaming"), settings = streamingSettings) dependsOn(core) lazy val mllib = Project("mllib", file("mllib"), settings = mllibSettings) dependsOn(core) lazy val assemblyProj = Project("assembly", file("assembly"), settings = assemblyProjSettings) - .dependsOn(core, graphx, bagel, mllib, streaming, repl, sql) dependsOn(maybeYarn: _*) dependsOn(maybeGanglia: _*) + .dependsOn(core, graphx, bagel, mllib, streaming, repl, sql) dependsOn(maybeYarn: _*) dependsOn(maybeHive: _*) dependsOn(maybeGanglia: _*) lazy val assembleDeps = TaskKey[Unit]("assemble-deps", "Build assembly of dependencies and packages Spark projects") @@ -101,6 +105,11 @@ object SparkBuild extends Build { lazy val hadoopClient = if (hadoopVersion.startsWith("0.20.") || hadoopVersion == "1.0.0") "hadoop-core" else "hadoop-client" val maybeAvro = if (hadoopVersion.startsWith("0.23.") && isYarnEnabled) Seq("org.apache.avro" % "avro" % "1.7.4") else Seq() + lazy val isHiveEnabled = Properties.envOrNone("SPARK_HIVE") match { + case None => DEFAULT_HIVE + case Some(v) => v.toBoolean + } + // Include Ganglia integration if the user has enabled Ganglia // This is isolated from the normal build due to LGPL-licensed code in the library lazy val isGangliaEnabled = Properties.envOrNone("SPARK_GANGLIA_LGPL").isDefined @@ -141,18 +150,18 @@ object SparkBuild extends Build { lazy val allExternalRefs = Seq[ProjectReference](externalTwitter, externalKafka, externalFlume, externalZeromq, externalMqtt) lazy val examples = Project("examples", file("examples"), settings = examplesSettings) - .dependsOn(core, mllib, graphx, bagel, streaming, externalTwitter, hive) dependsOn(allExternal: _*) + .dependsOn(core, mllib, graphx, bagel, streaming, hive) dependsOn(allExternal: _*) // Everything except assembly, hive, tools, java8Tests and examples belong to packageProjects - lazy val packageProjects = Seq[ProjectReference](core, repl, bagel, streaming, mllib, graphx, catalyst, sql) ++ maybeYarnRef ++ maybeGangliaRef + lazy val packageProjects = Seq[ProjectReference](core, repl, bagel, streaming, mllib, graphx, catalyst, sql) ++ maybeYarnRef ++ maybeHiveRef ++ maybeGangliaRef lazy val allProjects = packageProjects ++ allExternalRefs ++ - Seq[ProjectReference](examples, tools, assemblyProj, hive) ++ maybeJava8Tests + Seq[ProjectReference](examples, tools, assemblyProj) ++ maybeJava8Tests def sharedSettings = Defaults.defaultSettings ++ MimaBuild.mimaSettings(file(sparkHome)) ++ Seq( organization := "org.apache.spark", version := SPARK_VERSION, - scalaVersion := "2.10.3", + scalaVersion := "2.10.4", scalacOptions := Seq("-Xmax-classfile-name", "120", "-unchecked", "-deprecation", "-target:" + SCALAC_JVM_VERSION), javacOptions := Seq("-target", JAVAC_JVM_VERSION, "-source", JAVAC_JVM_VERSION), @@ -169,6 +178,7 @@ object SparkBuild extends Build { fork := true, javaOptions in Test += "-Dspark.home=" + sparkHome, javaOptions in Test += "-Dspark.testing=1", + javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark").map { case (k,v) => s"-D$k=$v" }.toSeq, javaOptions += "-Xmx3g", // Show full stack trace and duration in test cases. @@ -190,10 +200,10 @@ object SparkBuild extends Build { "Apache Repository" at "https://repository.apache.org/content/repositories/releases", "JBoss Repository" at "https://repository.jboss.org/nexus/content/repositories/releases/", "MQTT Repository" at "https://repo.eclipse.org/content/repositories/paho-releases/", - "Cloudera Repository" at "https://repository.cloudera.com/artifactory/cloudera-repos/", + "Cloudera Repository" at "http://repository.cloudera.com/artifactory/cloudera-repos/", // For Sonatype publishing - //"sonatype-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots", - //"sonatype-staging" at "https://oss.sonatype.org/service/local/staging/deploy/maven2/", + // "sonatype-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots", + // "sonatype-staging" at "https://oss.sonatype.org/service/local/staging/deploy/maven2/", // also check the local Maven repository ~/.m2 Resolver.mavenLocal ), @@ -249,10 +259,10 @@ object SparkBuild extends Build { libraryDependencies ++= Seq( "io.netty" % "netty-all" % "4.0.17.Final", - "org.eclipse.jetty" % "jetty-server" % "8.1.14.v20131031", - "org.eclipse.jetty" % "jetty-util" % "8.1.14.v20131031", - "org.eclipse.jetty" % "jetty-plus" % "8.1.14.v20131031", - "org.eclipse.jetty" % "jetty-security" % "8.1.14.v20131031", + "org.eclipse.jetty" % "jetty-server" % jettyVersion, + "org.eclipse.jetty" % "jetty-util" % jettyVersion, + "org.eclipse.jetty" % "jetty-plus" % jettyVersion, + "org.eclipse.jetty" % "jetty-security" % jettyVersion, /** Workaround for SPARK-959. Dependency used by org.eclipse.jetty. Fixed in ivy 2.3.0. */ "org.eclipse.jetty.orbit" % "javax.servlet" % "3.0.0.v201112011016" artifacts Artifact("javax.servlet", "jar", "jar"), "org.scalatest" %% "scalatest" % "1.9.1" % "test", @@ -277,16 +287,28 @@ object SparkBuild extends Build { publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn ) ++ net.virtualvoid.sbt.graph.Plugin.graphSettings ++ ScalaStyleSettings + val akkaVersion = "2.2.3-shaded-protobuf" + val chillVersion = "0.3.1" + val codahaleMetricsVersion = "3.0.0" + val jblasVersion = "1.2.3" + val jettyVersion = "8.1.14.v20131031" + val hiveVersion = "0.12.0" + val parquetVersion = "1.3.2" val slf4jVersion = "1.7.5" val excludeNetty = ExclusionRule(organization = "org.jboss.netty") + val excludeEclipseJetty = ExclusionRule(organization = "org.eclipse.jetty") val excludeAsm = ExclusionRule(organization = "org.ow2.asm") val excludeOldAsm = ExclusionRule(organization = "asm") val excludeCommonsLogging = ExclusionRule(organization = "commons-logging") val excludeSLF4J = ExclusionRule(organization = "org.slf4j") val excludeScalap = ExclusionRule(organization = "org.scala-lang", artifact = "scalap") + val excludeHadoop = ExclusionRule(organization = "org.apache.hadoop") + val excludeCurator = ExclusionRule(organization = "org.apache.curator") + val excludePowermock = ExclusionRule(organization = "org.powermock") + - def sparkPreviousArtifact(id: String, organization: String = "org.apache.spark", + def sparkPreviousArtifact(id: String, organization: String = "org.apache.spark", version: String = "0.9.0-incubating", crossVersion: String = "2.10"): Option[sbt.ModuleID] = { val fullId = if (crossVersion.isEmpty) id else id + "_" + crossVersion Some(organization % fullId % version) // the artifact to compare binary compatibility with @@ -305,11 +327,10 @@ object SparkBuild extends Build { "commons-daemon" % "commons-daemon" % "1.0.10", // workaround for bug HADOOP-9407 "com.ning" % "compress-lzf" % "1.0.0", "org.xerial.snappy" % "snappy-java" % "1.0.5", - "org.spark-project.akka" %% "akka-remote" % "2.2.3-shaded-protobuf" excludeAll(excludeNetty), - "org.spark-project.akka" %% "akka-slf4j" % "2.2.3-shaded-protobuf" excludeAll(excludeNetty), - "org.spark-project.akka" %% "akka-testkit" % "2.2.3-shaded-protobuf" % "test", + "org.spark-project.akka" %% "akka-remote" % akkaVersion excludeAll(excludeNetty), + "org.spark-project.akka" %% "akka-slf4j" % akkaVersion excludeAll(excludeNetty), + "org.spark-project.akka" %% "akka-testkit" % akkaVersion % "test", "org.json4s" %% "json4s-jackson" % "3.2.6" excludeAll(excludeScalap), - "it.unimi.dsi" % "fastutil" % "6.4.4", "colt" % "colt" % "1.2.0", "org.apache.mesos" % "mesos" % "0.13.0", "commons-net" % "commons-net" % "2.2", @@ -317,12 +338,13 @@ object SparkBuild extends Build { "org.apache.derby" % "derby" % "10.4.2.0" % "test", "org.apache.hadoop" % hadoopClient % hadoopVersion excludeAll(excludeNetty, excludeAsm, excludeCommonsLogging, excludeSLF4J, excludeOldAsm), "org.apache.curator" % "curator-recipes" % "2.4.0" excludeAll(excludeNetty), - "com.codahale.metrics" % "metrics-core" % "3.0.0", - "com.codahale.metrics" % "metrics-jvm" % "3.0.0", - "com.codahale.metrics" % "metrics-json" % "3.0.0", - "com.codahale.metrics" % "metrics-graphite" % "3.0.0", - "com.twitter" %% "chill" % "0.3.1" excludeAll(excludeAsm), - "com.twitter" % "chill-java" % "0.3.1" excludeAll(excludeAsm), + "com.codahale.metrics" % "metrics-core" % codahaleMetricsVersion, + "com.codahale.metrics" % "metrics-jvm" % codahaleMetricsVersion, + "com.codahale.metrics" % "metrics-json" % codahaleMetricsVersion, + "com.codahale.metrics" % "metrics-graphite" % codahaleMetricsVersion, + "com.twitter" %% "chill" % chillVersion excludeAll(excludeAsm), + "com.twitter" % "chill-java" % chillVersion excludeAll(excludeAsm), + "org.tachyonproject" % "tachyon" % "0.4.1-thrift" excludeAll(excludeHadoop, excludeCurator, excludeEclipseJetty, excludePowermock), "com.clearspring.analytics" % "stream" % "2.5.1" ), libraryDependencies ++= maybeAvro @@ -365,7 +387,7 @@ object SparkBuild extends Build { name := "spark-graphx", previousArtifact := sparkPreviousArtifact("spark-graphx"), libraryDependencies ++= Seq( - "org.jblas" % "jblas" % "1.2.3" + "org.jblas" % "jblas" % jblasVersion ) ) @@ -378,7 +400,7 @@ object SparkBuild extends Build { name := "spark-mllib", previousArtifact := sparkPreviousArtifact("spark-mllib"), libraryDependencies ++= Seq( - "org.jblas" % "jblas" % "1.2.3", + "org.jblas" % "jblas" % jblasVersion, "org.scalanlp" %% "breeze" % "0.7" ) ) @@ -398,22 +420,20 @@ object SparkBuild extends Build { def sqlCoreSettings = sharedSettings ++ Seq( name := "spark-sql", libraryDependencies ++= Seq( - "com.twitter" % "parquet-column" % "1.3.2", - "com.twitter" % "parquet-hadoop" % "1.3.2" + "com.twitter" % "parquet-column" % parquetVersion, + "com.twitter" % "parquet-hadoop" % parquetVersion ) ) // Since we don't include hive in the main assembly this project also acts as an alternative // assembly jar. - def hiveSettings = sharedSettings ++ assemblyProjSettings ++ Seq( + def hiveSettings = sharedSettings ++ Seq( name := "spark-hive", - jarName in assembly <<= version map { v => "spark-hive-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" }, - jarName in packageDependency <<= version map { v => "spark-hive-assembly-" + v + "-hadoop" + hadoopVersion + "-deps.jar" }, javaOptions += "-XX:MaxPermSize=1g", libraryDependencies ++= Seq( - "org.apache.hive" % "hive-metastore" % "0.12.0", - "org.apache.hive" % "hive-exec" % "0.12.0", - "org.apache.hive" % "hive-serde" % "0.12.0" + "org.apache.hive" % "hive-metastore" % hiveVersion, + "org.apache.hive" % "hive-exec" % hiveVersion, + "org.apache.hive" % "hive-serde" % hiveVersion ), // Multiple queries rely on the TestHive singleton. See comments there for more details. parallelExecution in Test := false, @@ -544,7 +564,7 @@ object SparkBuild extends Build { name := "spark-streaming-zeromq", previousArtifact := sparkPreviousArtifact("spark-streaming-zeromq"), libraryDependencies ++= Seq( - "org.spark-project.akka" %% "akka-zeromq" % "2.2.3-shaded-protobuf" excludeAll(excludeNetty) + "org.spark-project.akka" %% "akka-zeromq" % akkaVersion excludeAll(excludeNetty) ) ) diff --git a/project/plugins.sbt b/project/plugins.sbt index 5aa8a1ec2409b..d787237ddc540 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,4 +1,4 @@ -scalaVersion := "2.10.3" +scalaVersion := "2.10.4" resolvers += Resolver.url("artifactory", url("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases"))(Resolver.ivyStylePatterns) diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala index 5a307044ba123..0142256e90fb7 100644 --- a/project/project/SparkPluginBuild.scala +++ b/project/project/SparkPluginBuild.scala @@ -32,7 +32,7 @@ object SparkPluginDef extends Build { name := "spark-style", organization := "org.apache.spark", version := sparkVersion, - scalaVersion := "2.10.3", + scalaVersion := "2.10.4", scalacOptions := Seq("-unchecked", "-deprecation"), libraryDependencies ++= Dependencies.scalaStyle ) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index bf2454fd7e38e..d8667e84fedff 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -28,7 +28,8 @@ from pyspark.conf import SparkConf from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway -from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer +from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ + PairDeserializer from pyspark.storagelevel import StorageLevel from pyspark import rdd from pyspark.rdd import RDD @@ -257,6 +258,45 @@ def textFile(self, name, minSplits=None): return RDD(self._jsc.textFile(name, minSplits), self, UTF8Deserializer()) + def wholeTextFiles(self, path): + """ + Read a directory of text files from HDFS, a local file system + (available on all nodes), or any Hadoop-supported file system + URI. Each file is read as a single record and returned in a + key-value pair, where the key is the path of each file, the + value is the content of each file. + + For example, if you have the following files:: + + hdfs://a-hdfs-path/part-00000 + hdfs://a-hdfs-path/part-00001 + ... + hdfs://a-hdfs-path/part-nnnnn + + Do C{rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")}, + then C{rdd} contains:: + + (a-hdfs-path/part-00000, its content) + (a-hdfs-path/part-00001, its content) + ... + (a-hdfs-path/part-nnnnn, its content) + + NOTE: Small files are preferred, as each file will be loaded + fully in memory. + + >>> dirPath = os.path.join(tempdir, "files") + >>> os.mkdir(dirPath) + >>> with open(os.path.join(dirPath, "1.txt"), "w") as file1: + ... file1.write("1") + >>> with open(os.path.join(dirPath, "2.txt"), "w") as file2: + ... file2.write("2") + >>> textFiles = sc.wholeTextFiles(dirPath) + >>> sorted(textFiles.collect()) + [(u'.../1.txt', u'1'), (u'.../2.txt', u'2')] + """ + return RDD(self._jsc.wholeTextFiles(path), self, + PairDeserializer(UTF8Deserializer(), UTF8Deserializer())) + def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) return RDD(jrdd, self, input_deserializer) @@ -383,8 +423,11 @@ def _getJavaStorageLevel(self, storageLevel): raise Exception("storageLevel must be of type pyspark.StorageLevel") newStorageLevel = self._jvm.org.apache.spark.storage.StorageLevel - return newStorageLevel(storageLevel.useDisk, storageLevel.useMemory, - storageLevel.deserialized, storageLevel.replication) + return newStorageLevel(storageLevel.useDisk, + storageLevel.useMemory, + storageLevel.useOffHeap, + storageLevel.deserialized, + storageLevel.replication) def setJobGroup(self, groupId, description): """ @@ -425,7 +468,7 @@ def _test(): globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) globs['tempdir'] = tempfile.mkdtemp() atexit.register(lambda: shutil.rmtree(globs['tempdir'])) - (failure_count, test_count) = doctest.testmod(globs=globs) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() if failure_count: exit(-1) diff --git a/python/pyspark/join.py b/python/pyspark/join.py index 5f4294fb1b777..6f94d26ef86a9 100644 --- a/python/pyspark/join.py +++ b/python/pyspark/join.py @@ -31,11 +31,12 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """ +from pyspark.resultiterable import ResultIterable def _do_python_join(rdd, other, numPartitions, dispatch): vs = rdd.map(lambda (k, v): (k, (1, v))) ws = other.map(lambda (k, v): (k, (2, v))) - return vs.union(ws).groupByKey(numPartitions).flatMapValues(dispatch) + return vs.union(ws).groupByKey(numPartitions).flatMapValues(lambda x : dispatch(x.__iter__())) def python_join(rdd, other, numPartitions): @@ -88,5 +89,5 @@ def dispatch(seq): vbuf.append(v) elif n == 2: wbuf.append(v) - return (vbuf, wbuf) + return (ResultIterable(vbuf), ResultIterable(wbuf)) return vs.union(ws).groupByKey(numPartitions).mapValues(dispatch) diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py index b420d7a7f23ba..538ff26ce7c33 100644 --- a/python/pyspark/mllib/__init__.py +++ b/python/pyspark/mllib/__init__.py @@ -19,11 +19,7 @@ Python bindings for MLlib. """ -# MLlib currently needs Python 2.7+ and NumPy 1.7+, so complain if lower - -import sys -if sys.version_info[0:2] < (2, 7): - raise Exception("MLlib requires Python 2.7+") +# MLlib currently needs and NumPy 1.7+, so complain if lower import numpy if numpy.version.version < '1.7': diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index 20a0e309d1494..7ef251d24c77e 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -15,8 +15,9 @@ # limitations under the License. # -from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape +from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape, complex, issubdtype from pyspark import SparkContext, RDD +import numpy as np from pyspark.serializers import Serializer import struct @@ -47,13 +48,22 @@ def _deserialize_byte_array(shape, ba, offset): return ar.copy() def _serialize_double_vector(v): - """Serialize a double vector into a mutually understood format.""" + """Serialize a double vector into a mutually understood format. + + >>> x = array([1,2,3]) + >>> y = _deserialize_double_vector(_serialize_double_vector(x)) + >>> array_equal(y, array([1.0, 2.0, 3.0])) + True + """ if type(v) != ndarray: raise TypeError("_serialize_double_vector called on a %s; " "wanted ndarray" % type(v)) + """complex is only datatype that can't be converted to float64""" + if issubdtype(v.dtype, complex): + raise TypeError("_serialize_double_vector called on a %s; " + "wanted ndarray" % type(v)) if v.dtype != float64: - raise TypeError("_serialize_double_vector called on an ndarray of %s; " - "wanted ndarray of float64" % v.dtype) + v = v.astype(float64) if v.ndim != 1: raise TypeError("_serialize_double_vector called on a %ddarray; " "wanted a 1darray" % v.ndim) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 19b90dfd6e167..d2f9cdb3f4298 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -87,18 +87,19 @@ class NaiveBayesModel(object): >>> data = array([0.0, 0.0, 1.0, 0.0, 0.0, 2.0, 1.0, 1.0, 0.0]).reshape(3,3) >>> model = NaiveBayes.train(sc.parallelize(data)) >>> model.predict(array([0.0, 1.0])) - 0 + 0.0 >>> model.predict(array([1.0, 0.0])) - 1 + 1.0 """ - def __init__(self, pi, theta): + def __init__(self, labels, pi, theta): + self.labels = labels self.pi = pi self.theta = theta def predict(self, x): """Return the most likely class for a data vector x""" - return numpy.argmax(self.pi + dot(x, self.theta)) + return self.labels[numpy.argmax(self.pi + dot(x, self.theta))] class NaiveBayes(object): @classmethod @@ -122,7 +123,8 @@ def train(cls, data, lambda_=1.0): ans = sc._jvm.PythonMLLibAPI().trainNaiveBayes(dataBytes._jrdd, lambda_) return NaiveBayesModel( _deserialize_double_vector(ans[0]), - _deserialize_double_matrix(ans[1])) + _deserialize_double_vector(ans[1]), + _deserialize_double_matrix(ans[2])) def _test(): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 019c249699c2d..91fc7e637e2c6 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -29,7 +29,7 @@ from tempfile import NamedTemporaryFile from threading import Thread import warnings -from heapq import heappush, heappop, heappushpop +import heapq from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long @@ -38,12 +38,13 @@ from pyspark.statcounter import StatCounter from pyspark.rddsampler import RDDSampler from pyspark.storagelevel import StorageLevel +from pyspark.resultiterable import ResultIterable from py4j.java_collections import ListConverter, MapConverter - __all__ = ["RDD"] + def _extract_concise_traceback(): """ This function returns the traceback info for a callsite, returns a dict @@ -91,6 +92,73 @@ def __exit__(self, type, value, tb): if _spark_stack_depth == 0: self._context._jsc.setCallSite(None) +class MaxHeapQ(object): + """ + An implementation of MaxHeap. + >>> import pyspark.rdd + >>> heap = pyspark.rdd.MaxHeapQ(5) + >>> [heap.insert(i) for i in range(10)] + [None, None, None, None, None, None, None, None, None, None] + >>> sorted(heap.getElements()) + [0, 1, 2, 3, 4] + >>> heap = pyspark.rdd.MaxHeapQ(5) + >>> [heap.insert(i) for i in range(9, -1, -1)] + [None, None, None, None, None, None, None, None, None, None] + >>> sorted(heap.getElements()) + [0, 1, 2, 3, 4] + >>> heap = pyspark.rdd.MaxHeapQ(1) + >>> [heap.insert(i) for i in range(9, -1, -1)] + [None, None, None, None, None, None, None, None, None, None] + >>> heap.getElements() + [0] + """ + + def __init__(self, maxsize): + # we start from q[1], this makes calculating children as trivial as 2 * k + self.q = [0] + self.maxsize = maxsize + + def _swim(self, k): + while (k > 1) and (self.q[k/2] < self.q[k]): + self._swap(k, k/2) + k = k/2 + + def _swap(self, i, j): + t = self.q[i] + self.q[i] = self.q[j] + self.q[j] = t + + def _sink(self, k): + N = self.size() + while 2 * k <= N: + j = 2 * k + # Here we test if both children are greater than parent + # if not swap with larger one. + if j < N and self.q[j] < self.q[j + 1]: + j = j + 1 + if(self.q[k] > self.q[j]): + break + self._swap(k, j) + k = j + + def size(self): + return len(self.q) - 1 + + def insert(self, value): + if (self.size()) < self.maxsize: + self.q.append(value) + self._swim(self.size()) + else: + self._replaceRoot(value) + + def getElements(self): + return self.q[1:] + + def _replaceRoot(self, value): + if(self.q[1] > value): + self.q[1] = value + self._sink(1) + class RDD(object): """ A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. @@ -696,16 +764,16 @@ def top(self, num): Note: It returns the list sorted in descending order. >>> sc.parallelize([10, 4, 2, 12, 3]).top(1) [12] - >>> sc.parallelize([2, 3, 4, 5, 6]).cache().top(2) + >>> sc.parallelize([2, 3, 4, 5, 6], 2).cache().top(2) [6, 5] """ def topIterator(iterator): q = [] for k in iterator: if len(q) < num: - heappush(q, k) + heapq.heappush(q, k) else: - heappushpop(q, k) + heapq.heappushpop(q, k) yield q def merge(a, b): @@ -713,6 +781,36 @@ def merge(a, b): return sorted(self.mapPartitions(topIterator).reduce(merge), reverse=True) + def takeOrdered(self, num, key=None): + """ + Get the N elements from a RDD ordered in ascending order or as specified + by the optional key function. + + >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6) + [1, 2, 3, 4, 5, 6] + >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7], 2).takeOrdered(6, key=lambda x: -x) + [10, 9, 7, 6, 5, 4] + """ + + def topNKeyedElems(iterator, key_=None): + q = MaxHeapQ(num) + for k in iterator: + if key_ != None: + k = (key_(k), k) + q.insert(k) + yield q.getElements() + + def unKey(x, key_=None): + if key_ != None: + x = [i[1] for i in x] + return x + + def merge(a, b): + return next(topNKeyedElems(a + b)) + result = self.mapPartitions(lambda i: topNKeyedElems(i, key)).reduce(merge) + return sorted(unKey(result, key), key=key) + + def take(self, num): """ Take the first num elements of the RDD. @@ -1021,7 +1119,7 @@ def groupByKey(self, numPartitions=None): Hash-partitions the resulting RDD with into numPartitions partitions. >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) - >>> sorted(x.groupByKey().collect()) + >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect())) [('a', [1, 1]), ('b', [1])] """ @@ -1036,7 +1134,7 @@ def mergeCombiners(a, b): return a + b return self.combineByKey(createCombiner, mergeValue, mergeCombiners, - numPartitions) + numPartitions).mapValues(lambda x: ResultIterable(x)) # TODO: add tests def flatMapValues(self, f): @@ -1083,7 +1181,7 @@ def cogroup(self, other, numPartitions=None): >>> x = sc.parallelize([("a", 1), ("b", 4)]) >>> y = sc.parallelize([("a", 2)]) - >>> sorted(x.cogroup(y).collect()) + >>> map((lambda (x,y): (x, (list(y[0]), list(y[1])))), sorted(list(x.cogroup(y).collect()))) [('a', ([1], [2])), ('b', ([4], []))] """ return python_cogroup(self, other, numPartitions) @@ -1120,7 +1218,7 @@ def keyBy(self, f): >>> x = sc.parallelize(range(0,3)).keyBy(lambda x: x*x) >>> y = sc.parallelize(zip(range(0,5), range(0,5))) - >>> sorted(x.cogroup(y).collect()) + >>> map((lambda (x,y): (x, (list(y[0]), (list(y[1]))))), sorted(x.cogroup(y).collect())) [(0, ([0], [0])), (1, ([1], [1])), (2, ([], [2])), (3, ([], [3])), (4, ([2], [4]))] """ return self.map(lambda x: (f(x), x)) @@ -1205,11 +1303,12 @@ def getStorageLevel(self): Get the RDD's current storage level. >>> rdd1 = sc.parallelize([1,2]) >>> rdd1.getStorageLevel() - StorageLevel(False, False, False, 1) + StorageLevel(False, False, False, False, 1) """ java_storage_level = self._jrdd.getStorageLevel() storage_level = StorageLevel(java_storage_level.useDisk(), java_storage_level.useMemory(), + java_storage_level.useOffHeap(), java_storage_level.deserialized(), java_storage_level.replication()) return storage_level @@ -1219,7 +1318,6 @@ def getStorageLevel(self): # keys in the pairs. This could be an expensive operation, since those # hashes aren't retained. - class PipelinedRDD(RDD): """ Pipelined maps: diff --git a/python/pyspark/resultiterable.py b/python/pyspark/resultiterable.py new file mode 100644 index 0000000000000..7f418f8d2e29a --- /dev/null +++ b/python/pyspark/resultiterable.py @@ -0,0 +1,33 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +__all__ = ["ResultIterable"] + +import collections + +class ResultIterable(collections.Iterable): + """ + A special result iterable. This is used because the standard iterator can not be pickled + """ + def __init__(self, data): + self.data = data + self.index = 0 + self.maxindex = len(data) + def __iter__(self): + return iter(self.data) + def __len__(self): + return len(self.data) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 12c63f186a2b7..b253807974a2e 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -64,6 +64,7 @@ from itertools import chain, izip, product import marshal import struct +import sys from pyspark import cloudpickle @@ -113,6 +114,11 @@ class FramedSerializer(Serializer): where C{length} is a 32-bit integer and data is C{length} bytes. """ + def __init__(self): + # On Python 2.6, we can't write bytearrays to streams, so we need to convert them + # to strings first. Check if the version number is that old. + self._only_write_strings = sys.version_info[0:2] <= (2, 6) + def dump_stream(self, iterator, stream): for obj in iterator: self._write_with_length(obj, stream) @@ -127,7 +133,10 @@ def load_stream(self, stream): def _write_with_length(self, obj, stream): serialized = self.dumps(obj) write_int(len(serialized), stream) - stream.write(serialized) + if self._only_write_strings: + stream.write(str(serialized)) + else: + stream.write(serialized) def _read_with_length(self, stream): length = read_int(stream) @@ -290,7 +299,7 @@ class MarshalSerializer(FramedSerializer): class UTF8Deserializer(Serializer): """ - Deserializes streams written by getBytes. + Deserializes streams written by String.getBytes. """ def loads(self, stream): diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py index 3d779faf1fa44..61613dbed8dce 100644 --- a/python/pyspark/shell.py +++ b/python/pyspark/shell.py @@ -29,7 +29,10 @@ # this is the equivalent of ADD_JARS add_files = os.environ.get("ADD_FILES").split(',') if os.environ.get("ADD_FILES") != None else None -sc = SparkContext(os.environ.get("MASTER", "local"), "PySparkShell", pyFiles=add_files) +if os.environ.get("SPARK_EXECUTOR_URI"): + SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"]) + +sc = SparkContext(os.environ.get("MASTER", "local[*]"), "PySparkShell", pyFiles=add_files) print """Welcome to ____ __ diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py index c3e3a44e8e7ab..7b6660eab231b 100644 --- a/python/pyspark/storagelevel.py +++ b/python/pyspark/storagelevel.py @@ -25,23 +25,25 @@ class StorageLevel: Also contains static constants for some commonly used storage levels, such as MEMORY_ONLY. """ - def __init__(self, useDisk, useMemory, deserialized, replication = 1): + def __init__(self, useDisk, useMemory, useOffHeap, deserialized, replication = 1): self.useDisk = useDisk self.useMemory = useMemory + self.useOffHeap = useOffHeap self.deserialized = deserialized self.replication = replication def __repr__(self): - return "StorageLevel(%s, %s, %s, %s)" % ( - self.useDisk, self.useMemory, self.deserialized, self.replication) + return "StorageLevel(%s, %s, %s, %s, %s)" % ( + self.useDisk, self.useMemory, self.useOffHeap, self.deserialized, self.replication) -StorageLevel.DISK_ONLY = StorageLevel(True, False, False) -StorageLevel.DISK_ONLY_2 = StorageLevel(True, False, False, 2) -StorageLevel.MEMORY_ONLY = StorageLevel(False, True, True) -StorageLevel.MEMORY_ONLY_2 = StorageLevel(False, True, True, 2) -StorageLevel.MEMORY_ONLY_SER = StorageLevel(False, True, False) -StorageLevel.MEMORY_ONLY_SER_2 = StorageLevel(False, True, False, 2) -StorageLevel.MEMORY_AND_DISK = StorageLevel(True, True, True) -StorageLevel.MEMORY_AND_DISK_2 = StorageLevel(True, True, True, 2) -StorageLevel.MEMORY_AND_DISK_SER = StorageLevel(True, True, False) -StorageLevel.MEMORY_AND_DISK_SER_2 = StorageLevel(True, True, False, 2) +StorageLevel.DISK_ONLY = StorageLevel(True, False, False, False) +StorageLevel.DISK_ONLY_2 = StorageLevel(True, False, False, False, 2) +StorageLevel.MEMORY_ONLY = StorageLevel(False, True, False, True) +StorageLevel.MEMORY_ONLY_2 = StorageLevel(False, True, False, True, 2) +StorageLevel.MEMORY_ONLY_SER = StorageLevel(False, True, False, False) +StorageLevel.MEMORY_ONLY_SER_2 = StorageLevel(False, True, False, False, 2) +StorageLevel.MEMORY_AND_DISK = StorageLevel(True, True, False, True) +StorageLevel.MEMORY_AND_DISK_2 = StorageLevel(True, True, False, True, 2) +StorageLevel.MEMORY_AND_DISK_SER = StorageLevel(True, True, False, False) +StorageLevel.MEMORY_AND_DISK_SER_2 = StorageLevel(True, True, False, False, 2) +StorageLevel.OFF_HEAP = StorageLevel(False, False, True, False, 1) \ No newline at end of file diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index bf73800388ebf..687e85ca94d3c 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -26,21 +26,23 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.SparkEnv import org.apache.spark.util.Utils - +import org.apache.spark.util.ParentClassLoader import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm._ import com.esotericsoftware.reflectasm.shaded.org.objectweb.asm.Opcodes._ - /** * A ClassLoader that reads classes from a Hadoop FileSystem or HTTP URI, - * used to load classes defined by the interpreter when the REPL is used + * used to load classes defined by the interpreter when the REPL is used. + * Allows the user to specify if user class path should be first */ -class ExecutorClassLoader(classUri: String, parent: ClassLoader) -extends ClassLoader(parent) { +class ExecutorClassLoader(classUri: String, parent: ClassLoader, + userClassPathFirst: Boolean) extends ClassLoader { val uri = new URI(classUri) val directory = uri.getPath + val parentLoader = new ParentClassLoader(parent) + // Hadoop FileSystem object for our URI, if it isn't using HTTP var fileSystem: FileSystem = { if (uri.getScheme() == "http") { @@ -49,8 +51,27 @@ extends ClassLoader(parent) { FileSystem.get(uri, new Configuration()) } } - + override def findClass(name: String): Class[_] = { + userClassPathFirst match { + case true => findClassLocally(name).getOrElse(parentLoader.loadClass(name)) + case false => { + try { + parentLoader.loadClass(name) + } catch { + case e: ClassNotFoundException => { + val classOption = findClassLocally(name) + classOption match { + case None => throw new ClassNotFoundException(name, e) + case Some(a) => a + } + } + } + } + } + } + + def findClassLocally(name: String): Option[Class[_]] = { try { val pathInDirectory = name.replace('.', '/') + ".class" val inputStream = { @@ -68,12 +89,12 @@ extends ClassLoader(parent) { } val bytes = readAndTransformClass(name, inputStream) inputStream.close() - return defineClass(name, bytes, 0, bytes.length) + Some(defineClass(name, bytes, 0, bytes.length)) } catch { - case e: Exception => throw new ClassNotFoundException(name, e) + case e: Exception => None } } - + def readAndTransformClass(name: String, in: InputStream): Array[Byte] = { if (name.startsWith("line") && name.endsWith("$iw$")) { // Class seems to be an interpreter "wrapper" object storing a val or var. diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 9b1da195002c2..beb40e87024bd 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -39,6 +39,7 @@ import scala.reflect.api.{Mirror, TypeCreator, Universe => ApiUniverse} import org.apache.spark.Logging import org.apache.spark.SparkConf import org.apache.spark.SparkContext +import org.apache.spark.util.Utils /** The Scala interactive shell. It provides a read-eval-print loop * around the Interpreter class. @@ -130,7 +131,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, def history = in.history /** The context class loader at the time this object was created */ - protected val originalClassLoader = Thread.currentThread.getContextClassLoader + protected val originalClassLoader = Utils.getContextOrSparkClassLoader // classpath entries added via :cp var addedClasspath: String = "" @@ -177,7 +178,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, override lazy val formatting = new Formatting { def prompt = SparkILoop.this.prompt } - override protected def parentClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse(classOf[SparkILoop].getClassLoader) + override protected def parentClassLoader = SparkHelper.explicitParentLoader(settings).getOrElse(classOf[SparkILoop].getClassLoader) } /** Create a new interpreter. */ @@ -871,7 +872,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } val u: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe - val m = u.runtimeMirror(getClass.getClassLoader) + val m = u.runtimeMirror(Utils.getSparkClassLoader) private def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] = u.TypeTag[T]( m, @@ -963,7 +964,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, case Some(m) => m case None => { val prop = System.getenv("MASTER") - if (prop != null) prop else "local" + if (prop != null) prop else "local[*]" } } master diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala index 3ebf288130fb6..910b31d209e13 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala @@ -116,14 +116,14 @@ trait SparkILoopInit { } } - def initializeSpark() { + def initializeSpark() { intp.beQuietDuring { command(""" @transient val sc = org.apache.spark.repl.Main.interp.createSparkContext(); """) command("import org.apache.spark.SparkContext._") } - echo("Spark context available as sc.") + echo("Spark context available as sc.") } // code to be executed only after the interpreter is initialized diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala b/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala index 8f61a5e835044..419796b68b113 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala @@ -187,7 +187,7 @@ trait SparkImports { if (currentImps contains imv) addWrapper() val objName = req.lineRep.readPath val valName = "$VAL" + newValId(); - + if(!code.toString.endsWith(".`" + imv + "`;\n")) { // Which means already imported code.append("val " + valName + " = " + objName + ".INSTANCE;\n") code.append("import " + valName + req.accessPath + ".`" + imv + "`;\n") diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala b/repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala index 946e71039088d..0db26c3407dff 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala @@ -7,8 +7,10 @@ package org.apache.spark.repl +import scala.reflect.io.{Path, File} import scala.tools.nsc._ import scala.tools.nsc.interpreter._ +import scala.tools.nsc.interpreter.session.JLineHistory.JLineFileHistory import scala.tools.jline.console.ConsoleReader import scala.tools.jline.console.completer._ @@ -25,7 +27,7 @@ class SparkJLineReader(_completion: => Completion) extends InteractiveReader { val consoleReader = new JLineConsoleReader() lazy val completion = _completion - lazy val history: JLineHistory = JLineHistory() + lazy val history: JLineHistory = new SparkJLineHistory private def term = consoleReader.getTerminal() def reset() = term.reset() @@ -78,3 +80,11 @@ class SparkJLineReader(_completion: => Completion) extends InteractiveReader { def readOneLine(prompt: String) = consoleReader readLine prompt def readOneKey(prompt: String) = consoleReader readOneKey prompt } + +/** Changes the default history file to not collide with the scala repl's. */ +class SparkJLineHistory extends JLineFileHistory { + import Properties.userHome + + def defaultFileName = ".spark_history" + override protected lazy val historyFile = File(Path(userHome) / defaultFileName) +} diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala new file mode 100644 index 0000000000000..336df988a1b7f --- /dev/null +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.repl + +import java.io.File +import java.net.URLClassLoader + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import com.google.common.io.Files + +import org.apache.spark.TestUtils + +class ExecutorClassLoaderSuite extends FunSuite with BeforeAndAfterAll { + + val childClassNames = List("ReplFakeClass1", "ReplFakeClass2") + val parentClassNames = List("ReplFakeClass1", "ReplFakeClass2", "ReplFakeClass3") + val tempDir1 = Files.createTempDir() + val tempDir2 = Files.createTempDir() + val url1 = "file://" + tempDir1 + val urls2 = List(tempDir2.toURI.toURL).toArray + + override def beforeAll() { + childClassNames.foreach(TestUtils.createCompiledClass(_, tempDir1, "1")) + parentClassNames.foreach(TestUtils.createCompiledClass(_, tempDir2, "2")) + } + + test("child first") { + val parentLoader = new URLClassLoader(urls2, null) + val classLoader = new ExecutorClassLoader(url1, parentLoader, true) + val fakeClass = classLoader.loadClass("ReplFakeClass2").newInstance() + val fakeClassVersion = fakeClass.toString + assert(fakeClassVersion === "1") + } + + test("parent first") { + val parentLoader = new URLClassLoader(urls2, null) + val classLoader = new ExecutorClassLoader(url1, parentLoader, false) + val fakeClass = classLoader.loadClass("ReplFakeClass1").newInstance() + val fakeClassVersion = fakeClass.toString + assert(fakeClassVersion === "2") + } + + test("child first can fall back") { + val parentLoader = new URLClassLoader(urls2, null) + val classLoader = new ExecutorClassLoader(url1, parentLoader, true) + val fakeClass = classLoader.loadClass("ReplFakeClass3").newInstance() + val fakeClassVersion = fakeClass.toString + assert(fakeClassVersion === "2") + } + + test("child first can fail") { + val parentLoader = new URLClassLoader(urls2, null) + val classLoader = new ExecutorClassLoader(url1, parentLoader, true) + intercept[java.lang.ClassNotFoundException] { + classLoader.loadClass("ReplFakeClassDoesNotExist").newInstance() + } + } + +} diff --git a/sbin/start-history-server.sh b/sbin/start-history-server.sh new file mode 100755 index 0000000000000..4a90c68763b68 --- /dev/null +++ b/sbin/start-history-server.sh @@ -0,0 +1,37 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Starts the history server on the machine this script is executed on. +# +# Usage: start-history-server.sh [] +# Example: ./start-history-server.sh --dir /tmp/spark-events --port 18080 +# + +sbin=`dirname "$0"` +sbin=`cd "$sbin"; pwd` + +if [ $# -lt 1 ]; then + echo "Usage: ./start-history-server.sh " + echo "Example: ./start-history-server.sh /tmp/spark-events" + exit +fi + +LOG_DIR=$1 + +"$sbin"/spark-daemon.sh start org.apache.spark.deploy.history.HistoryServer 1 --dir "$LOG_DIR" diff --git a/sbin/stop-history-server.sh b/sbin/stop-history-server.sh new file mode 100755 index 0000000000000..c0034ad641cbe --- /dev/null +++ b/sbin/stop-history-server.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Stops the history server on the machine this script is executed on. + +sbin=`dirname "$0"` +sbin=`cd "$sbin"; pwd` + +"$sbin"/spark-daemon.sh stop org.apache.spark.deploy.history.HistoryServer 1 diff --git a/sql/README.md b/sql/README.md index 4192fecb92fb0..14d5555f0c713 100644 --- a/sql/README.md +++ b/sql/README.md @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.TestHive._ -Welcome to Scala version 2.10.3 (Java HotSpot(TM) 64-Bit Server VM, Java 1.7.0_45). +Welcome to Scala version 2.10.4 (Java HotSpot(TM) 64-Bit Server VM, Java 1.7.0_45). Type in expressions to have them evaluated. Type :help for more information. diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 0edce55a93338..9d5c6a857bb00 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -44,6 +44,10 @@ + + org.scala-lang + scala-reflect + org.apache.spark spark-core_${scala.binary.version} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 976dda8d7e59a..446d0e0bd7f54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst +import java.sql.Timestamp + import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.LocalRelation @@ -43,15 +45,26 @@ object ScalaReflection { val params = t.member("": TermName).asMethod.paramss StructType( params.head.map(p => StructField(p.name.toString, schemaFor(p.typeSignature), true))) + // Need to decide if we actually need a special type here. + case t if t <:< typeOf[Array[Byte]] => BinaryType + case t if t <:< typeOf[Array[_]] => + sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") case t if t <:< typeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t ArrayType(schemaFor(elementType)) + case t if t <:< typeOf[Map[_,_]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + MapType(schemaFor(keyType), schemaFor(valueType)) case t if t <:< typeOf[String] => StringType + case t if t <:< typeOf[Timestamp] => TimestampType + case t if t <:< typeOf[BigDecimal] => DecimalType case t if t <:< definitions.IntTpe => IntegerType case t if t <:< definitions.LongTpe => LongType case t if t <:< definitions.DoubleTpe => DoubleType + case t if t <:< definitions.FloatTpe => FloatType case t if t <:< definitions.ShortTpe => ShortType case t if t <:< definitions.ByteTpe => ByteType + case t if t <:< definitions.BooleanTpe => BooleanType } implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 0c851c2ee2183..5b6aea81cb7d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -106,6 +106,8 @@ class SqlParser extends StandardTokenParsers { protected val IF = Keyword("IF") protected val IN = Keyword("IN") protected val INNER = Keyword("INNER") + protected val INSERT = Keyword("INSERT") + protected val INTO = Keyword("INTO") protected val IS = Keyword("IS") protected val JOIN = Keyword("JOIN") protected val LEFT = Keyword("LEFT") @@ -114,6 +116,7 @@ class SqlParser extends StandardTokenParsers { protected val NULL = Keyword("NULL") protected val ON = Keyword("ON") protected val OR = Keyword("OR") + protected val OVERWRITE = Keyword("OVERWRITE") protected val LIKE = Keyword("LIKE") protected val RLIKE = Keyword("RLIKE") protected val REGEXP = Keyword("REGEXP") @@ -162,7 +165,7 @@ class SqlParser extends StandardTokenParsers { select * ( UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } | UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } - ) + ) | insert protected lazy val select: Parser[LogicalPlan] = SELECT ~> opt(DISTINCT) ~ projections ~ @@ -181,10 +184,17 @@ class SqlParser extends StandardTokenParsers { val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection) val withHaving = h.map(h => Filter(h, withDistinct)).getOrElse(withDistinct) val withOrder = o.map(o => Sort(o, withHaving)).getOrElse(withHaving) - val withLimit = l.map { l => StopAfter(l, withOrder) }.getOrElse(withOrder) + val withLimit = l.map { l => Limit(l, withOrder) }.getOrElse(withOrder) withLimit } + protected lazy val insert: Parser[LogicalPlan] = + INSERT ~> opt(OVERWRITE) ~ inTo ~ select <~ opt(";") ^^ { + case o ~ r ~ s => + val overwrite: Boolean = o.getOrElse("") == "OVERWRITE" + InsertIntoTable(r, Map[String, Option[String]](), s, overwrite) + } + protected lazy val projections: Parser[Seq[Expression]] = repsep(projection, ",") protected lazy val projection: Parser[Expression] = @@ -195,6 +205,8 @@ class SqlParser extends StandardTokenParsers { protected lazy val from: Parser[LogicalPlan] = FROM ~> relations + protected lazy val inTo: Parser[LogicalPlan] = INTO ~> relation + // Based very loosely on the MySQL Grammar. // http://dev.mysql.com/doc/refman/5.0/en/join.html protected lazy val relations: Parser[LogicalPlan] = @@ -207,7 +219,7 @@ class SqlParser extends StandardTokenParsers { protected lazy val relationFactor: Parser[LogicalPlan] = ident ~ (opt(AS) ~> opt(ident)) ^^ { - case ident ~ alias => UnresolvedRelation(alias, ident) + case tableName ~ alias => UnresolvedRelation(None, tableName, alias) } | "(" ~> query ~ ")" ~ opt(AS) ~ ident ^^ { case s ~ _ ~ _ ~ a => Subquery(a, s) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index e09182dd8d5df..f30b5d816703a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -31,18 +31,33 @@ trait Catalog { alias: Option[String] = None): LogicalPlan def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit + + def unregisterTable(databaseName: Option[String], tableName: String): Unit + + def unregisterAllTables(): Unit } class SimpleCatalog extends Catalog { val tables = new mutable.HashMap[String, LogicalPlan]() - def registerTable(databaseName: Option[String],tableName: String, plan: LogicalPlan): Unit = { + override def registerTable( + databaseName: Option[String], + tableName: String, + plan: LogicalPlan): Unit = { tables += ((tableName, plan)) } - def dropTable(tableName: String) = tables -= tableName + override def unregisterTable( + databaseName: Option[String], + tableName: String) = { + tables -= tableName + } - def lookupRelation( + override def unregisterAllTables() = { + tables.clear() + } + + override def lookupRelation( databaseName: Option[String], tableName: String, alias: Option[String] = None): LogicalPlan = { @@ -87,6 +102,14 @@ trait OverrideCatalog extends Catalog { plan: LogicalPlan): Unit = { overrides.put((databaseName, tableName), plan) } + + override def unregisterTable(databaseName: Option[String], tableName: String): Unit = { + overrides.remove((databaseName, tableName)) + } + + override def unregisterAllTables(): Unit = { + overrides.clear() + } } /** @@ -104,4 +127,10 @@ object EmptyCatalog extends Catalog { def registerTable(databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = { throw new UnsupportedOperationException } + + def unregisterTable(databaseName: Option[String], tableName: String): Unit = { + throw new UnsupportedOperationException + } + + override def unregisterAllTables(): Unit = {} } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 41e9bcef3cd7f..d629172a7426e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.{errors, trees} -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.BaseRelation import org.apache.spark.sql.catalyst.trees.TreeNode @@ -36,7 +37,7 @@ case class UnresolvedRelation( databaseName: Option[String], tableName: String, alias: Option[String] = None) extends BaseRelation { - def output = Nil + override def output = Nil override lazy val resolved = false } @@ -44,26 +45,33 @@ case class UnresolvedRelation( * Holds the name of an attribute that has yet to be resolved. */ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNode[Expression] { - def exprId = throw new UnresolvedException(this, "exprId") - def dataType = throw new UnresolvedException(this, "dataType") - def nullable = throw new UnresolvedException(this, "nullable") - def qualifiers = throw new UnresolvedException(this, "qualifiers") + override def exprId = throw new UnresolvedException(this, "exprId") + override def dataType = throw new UnresolvedException(this, "dataType") + override def nullable = throw new UnresolvedException(this, "nullable") + override def qualifiers = throw new UnresolvedException(this, "qualifiers") override lazy val resolved = false - def newInstance = this - def withQualifiers(newQualifiers: Seq[String]) = this + override def newInstance = this + override def withQualifiers(newQualifiers: Seq[String]) = this + + // Unresolved attributes are transient at compile time and don't get evaluated during execution. + override def eval(input: Row = null): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") override def toString: String = s"'$name" } case class UnresolvedFunction(name: String, children: Seq[Expression]) extends Expression { - def exprId = throw new UnresolvedException(this, "exprId") - def dataType = throw new UnresolvedException(this, "dataType") + override def dataType = throw new UnresolvedException(this, "dataType") override def foldable = throw new UnresolvedException(this, "foldable") - def nullable = throw new UnresolvedException(this, "nullable") - def qualifiers = throw new UnresolvedException(this, "qualifiers") - def references = children.flatMap(_.references).toSet + override def nullable = throw new UnresolvedException(this, "nullable") + override def references = children.flatMap(_.references).toSet override lazy val resolved = false + + // Unresolved functions are transient at compile time and don't get evaluated during execution. + override def eval(input: Row = null): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + override def toString = s"'$name(${children.mkString(",")})" } @@ -79,15 +87,15 @@ case class Star( mapFunction: Attribute => Expression = identity[Attribute]) extends Attribute with trees.LeafNode[Expression] { - def name = throw new UnresolvedException(this, "exprId") - def exprId = throw new UnresolvedException(this, "exprId") - def dataType = throw new UnresolvedException(this, "dataType") - def nullable = throw new UnresolvedException(this, "nullable") - def qualifiers = throw new UnresolvedException(this, "qualifiers") + override def name = throw new UnresolvedException(this, "exprId") + override def exprId = throw new UnresolvedException(this, "exprId") + override def dataType = throw new UnresolvedException(this, "dataType") + override def nullable = throw new UnresolvedException(this, "nullable") + override def qualifiers = throw new UnresolvedException(this, "qualifiers") override lazy val resolved = false - def newInstance = this - def withQualifiers(newQualifiers: Seq[String]) = this + override def newInstance = this + override def withQualifiers(newQualifiers: Seq[String]) = this def expand(input: Seq[Attribute]): Seq[NamedExpression] = { val expandedAttributes: Seq[Attribute] = table match { @@ -104,5 +112,9 @@ case class Star( mappedAttributes } + // Star gets expanded at runtime so we never evaluate a Star. + override def eval(input: Row = null): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + override def toString = table.map(_ + ".").getOrElse("") + "*" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 44abe671c07a4..987befe8e22ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst +import java.sql.Timestamp + import scala.language.implicitConversions import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute @@ -68,10 +70,11 @@ package object dsl { def > (other: Expression) = GreaterThan(expr, other) def >= (other: Expression) = GreaterThanOrEqual(expr, other) def === (other: Expression) = Equals(expr, other) - def != (other: Expression) = Not(Equals(expr, other)) + def !== (other: Expression) = Not(Equals(expr, other)) def like(other: Expression) = Like(expr, other) def rlike(other: Expression) = RLike(expr, other) + def cast(to: DataType) = Cast(expr, to) def asc = SortOrder(expr, Ascending) def desc = SortOrder(expr, Descending) @@ -84,17 +87,24 @@ package object dsl { def expr = e } + implicit def booleanToLiteral(b: Boolean) = Literal(b) + implicit def byteToLiteral(b: Byte) = Literal(b) + implicit def shortToLiteral(s: Short) = Literal(s) implicit def intToLiteral(i: Int) = Literal(i) implicit def longToLiteral(l: Long) = Literal(l) implicit def floatToLiteral(f: Float) = Literal(f) implicit def doubleToLiteral(d: Double) = Literal(d) implicit def stringToLiteral(s: String) = Literal(s) + implicit def decimalToLiteral(d: BigDecimal) = Literal(d) + implicit def timestampToLiteral(t: Timestamp) = Literal(t) + implicit def binaryToLiteral(a: Array[Byte]) = Literal(a) implicit def symbolToUnresolvedAttribute(s: Symbol) = analysis.UnresolvedAttribute(s.name) implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s = sym.name } + // TODO more implicit class for literal? implicit class DslString(val s: String) extends ImplicitOperators { - def expr: Expression = Literal(s) + override def expr: Expression = Literal(s) def attr = analysis.UnresolvedAttribute(s) } @@ -103,11 +113,38 @@ package object dsl { def expr = attr def attr = analysis.UnresolvedAttribute(s) - /** Creates a new typed attributes of type int */ + /** Creates a new AttributeReference of type boolean */ + def boolean = AttributeReference(s, BooleanType, nullable = false)() + + /** Creates a new AttributeReference of type byte */ + def byte = AttributeReference(s, ByteType, nullable = false)() + + /** Creates a new AttributeReference of type short */ + def short = AttributeReference(s, ShortType, nullable = false)() + + /** Creates a new AttributeReference of type int */ def int = AttributeReference(s, IntegerType, nullable = false)() - /** Creates a new typed attributes of type string */ + /** Creates a new AttributeReference of type long */ + def long = AttributeReference(s, LongType, nullable = false)() + + /** Creates a new AttributeReference of type float */ + def float = AttributeReference(s, FloatType, nullable = false)() + + /** Creates a new AttributeReference of type double */ + def double = AttributeReference(s, DoubleType, nullable = false)() + + /** Creates a new AttributeReference of type string */ def string = AttributeReference(s, StringType, nullable = false)() + + /** Creates a new AttributeReference of type decimal */ + def decimal = AttributeReference(s, DecimalType, nullable = false)() + + /** Creates a new AttributeReference of type timestamp */ + def timestamp = AttributeReference(s, TimestampType, nullable = false)() + + /** Creates a new AttributeReference of type binary */ + def binary = AttributeReference(s, BinaryType, nullable = false)() } implicit class DslAttribute(a: AttributeReference) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index f70e80b7f27f2..4ebf6c4584b94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -45,14 +45,20 @@ case class BoundReference(ordinal: Int, baseReference: Attribute) override def toString = s"$baseReference:$ordinal" - override def apply(input: Row): Any = input(ordinal) + override def eval(input: Row): Any = input(ordinal) } +/** + * Used to denote operators that do their own binding of attributes internally. + */ +trait NoBind { self: trees.TreeNode[_] => } + class BindReferences[TreeNode <: QueryPlan[TreeNode]] extends Rule[TreeNode] { import BindReferences._ def apply(plan: TreeNode): TreeNode = { plan.transform { + case n: NoBind => n.asInstanceOf[TreeNode] case leafNode if leafNode.children.isEmpty => leafNode case unaryNode if unaryNode.children.size == 1 => unaryNode.transformExpressions { case e => bindReference(e, unaryNode.children.head.output) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index c26fc3d0f305f..1f3fab09e9566 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.Timestamp + import org.apache.spark.sql.catalyst.types._ /** Cast the child expression to the target data type. */ @@ -27,51 +29,175 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { type EvaluatedType = Any - lazy val castingFunction: Any => Any = (child.dataType, dataType) match { - case (BinaryType, StringType) => a: Any => new String(a.asInstanceOf[Array[Byte]]) - case (StringType, BinaryType) => a: Any => a.asInstanceOf[String].getBytes - case (_, StringType) => a: Any => a.toString - case (StringType, IntegerType) => a: Any => castOrNull(a, _.toInt) - case (StringType, DoubleType) => a: Any => castOrNull(a, _.toDouble) - case (StringType, FloatType) => a: Any => castOrNull(a, _.toFloat) - case (StringType, LongType) => a: Any => castOrNull(a, _.toLong) - case (StringType, ShortType) => a: Any => castOrNull(a, _.toShort) - case (StringType, ByteType) => a: Any => castOrNull(a, _.toByte) - case (StringType, DecimalType) => a: Any => castOrNull(a, BigDecimal(_)) - case (BooleanType, ByteType) => { - case null => null - case true => 1.toByte - case false => 0.toByte - } - case (dt, IntegerType) => - a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a) - case (dt, DoubleType) => - a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toDouble(a) - case (dt, FloatType) => - a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toFloat(a) - case (dt, LongType) => - a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toLong(a) - case (dt, ShortType) => - a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a).toShort - case (dt, ByteType) => - a: Any => dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toInt(a).toByte - case (dt, DecimalType) => - a: Any => - BigDecimal(dt.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].toDouble(a)) - } - - @inline - protected def castOrNull[A](a: Any, f: String => A) = - try f(a.asInstanceOf[String]) catch { - case _: java.lang.NumberFormatException => null - } + def nullOrCast[T](a: Any, func: T => Any): Any = if(a == null) { + null + } else { + func(a.asInstanceOf[T]) + } + + // UDFToString + def castToString: Any => Any = child.dataType match { + case BinaryType => nullOrCast[Array[Byte]](_, new String(_, "UTF-8")) + case _ => nullOrCast[Any](_, _.toString) + } + + // BinaryConverter + def castToBinary: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, _.getBytes("UTF-8")) + } + + // UDFToBoolean + def castToBoolean: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, _.length() != 0) + case TimestampType => nullOrCast[Timestamp](_, b => {(b.getTime() != 0 || b.getNanos() != 0)}) + case LongType => nullOrCast[Long](_, _ != 0) + case IntegerType => nullOrCast[Int](_, _ != 0) + case ShortType => nullOrCast[Short](_, _ != 0) + case ByteType => nullOrCast[Byte](_, _ != 0) + case DecimalType => nullOrCast[BigDecimal](_, _ != 0) + case DoubleType => nullOrCast[Double](_, _ != 0) + case FloatType => nullOrCast[Float](_, _ != 0) + } + + // TimestampConverter + def castToTimestamp: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => { + // Throw away extra if more than 9 decimal places + val periodIdx = s.indexOf("."); + var n = s + if (periodIdx != -1) { + if (n.length() - periodIdx > 9) { + n = n.substring(0, periodIdx + 10) + } + } + try Timestamp.valueOf(n) catch { case _: java.lang.IllegalArgumentException => null} + }) + case BooleanType => nullOrCast[Boolean](_, b => new Timestamp((if(b) 1 else 0) * 1000)) + case LongType => nullOrCast[Long](_, l => new Timestamp(l * 1000)) + case IntegerType => nullOrCast[Int](_, i => new Timestamp(i * 1000)) + case ShortType => nullOrCast[Short](_, s => new Timestamp(s * 1000)) + case ByteType => nullOrCast[Byte](_, b => new Timestamp(b * 1000)) + // TimestampWritable.decimalToTimestamp + case DecimalType => nullOrCast[BigDecimal](_, d => decimalToTimestamp(d)) + // TimestampWritable.doubleToTimestamp + case DoubleType => nullOrCast[Double](_, d => decimalToTimestamp(d)) + // TimestampWritable.floatToTimestamp + case FloatType => nullOrCast[Float](_, f => decimalToTimestamp(f)) + } + + private def decimalToTimestamp(d: BigDecimal) = { + val seconds = d.longValue() + val bd = (d - seconds) * 1000000000 + val nanos = bd.intValue() + + // Convert to millis + val millis = seconds * 1000 + val t = new Timestamp(millis) + + // remaining fractional portion as nanos + t.setNanos(nanos) + t + } + + // Timestamp to long, converting milliseconds to seconds + private def timestampToLong(ts: Timestamp) = ts.getTime / 1000 + + private def timestampToDouble(ts: Timestamp) = { + // First part is the seconds since the beginning of time, followed by nanosecs. + ts.getTime / 1000 + ts.getNanos.toDouble / 1000000000 + } + + def castToLong: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => try s.toLong catch { + case _: NumberFormatException => null + }) + case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) + case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t)) + case DecimalType => nullOrCast[BigDecimal](_, _.toLong) + case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toLong(b) + } + + def castToInt: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => try s.toInt catch { + case _: NumberFormatException => null + }) + case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) + case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toInt) + case DecimalType => nullOrCast[BigDecimal](_, _.toInt) + case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b) + } + + def castToShort: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => try s.toShort catch { + case _: NumberFormatException => null + }) + case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) + case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toShort) + case DecimalType => nullOrCast[BigDecimal](_, _.toShort) + case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toShort + } + + def castToByte: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => try s.toByte catch { + case _: NumberFormatException => null + }) + case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) + case TimestampType => nullOrCast[Timestamp](_, t => timestampToLong(t).toByte) + case DecimalType => nullOrCast[BigDecimal](_, _.toByte) + case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toInt(b).toByte + } + + def castToDecimal: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => try BigDecimal(s.toDouble) catch { + case _: NumberFormatException => null + }) + case BooleanType => nullOrCast[Boolean](_, b => if(b) BigDecimal(1) else BigDecimal(0)) + case TimestampType => + // Note that we lose precision here. + nullOrCast[Timestamp](_, t => BigDecimal(timestampToDouble(t))) + case x: NumericType => b => BigDecimal(x.numeric.asInstanceOf[Numeric[Any]].toDouble(b)) + } + + def castToDouble: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => try s.toDouble catch { + case _: NumberFormatException => null + }) + case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) + case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t)) + case DecimalType => nullOrCast[BigDecimal](_, _.toDouble) + case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toDouble(b) + } + + def castToFloat: Any => Any = child.dataType match { + case StringType => nullOrCast[String](_, s => try s.toFloat catch { + case _: NumberFormatException => null + }) + case BooleanType => nullOrCast[Boolean](_, b => if(b) 1 else 0) + case TimestampType => nullOrCast[Timestamp](_, t => timestampToDouble(t).toFloat) + case DecimalType => nullOrCast[BigDecimal](_, _.toFloat) + case x: NumericType => b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) + } + + def cast: Any => Any = dataType match { + case StringType => castToString + case BinaryType => castToBinary + case DecimalType => castToDecimal + case TimestampType => castToTimestamp + case BooleanType => castToBoolean + case ByteType => castToByte + case ShortType => castToShort + case IntegerType => castToInt + case FloatType => castToFloat + case LongType => castToLong + case DoubleType => castToDouble + } - override def apply(input: Row): Any = { - val evaluated = child.apply(input) + override def eval(input: Row): Any = { + val evaluated = child.eval(input) if (evaluated == null) { null } else { - castingFunction(evaluated) + cast(evaluated) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 81fd160e00ca1..dd9332ada80dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.catalyst.types.{DataType, FractionalType, IntegralType, NumericType} +import org.apache.spark.sql.catalyst.types.{DataType, FractionalType, IntegralType, NumericType, NativeType} abstract class Expression extends TreeNode[Expression] { self: Product => @@ -50,8 +50,7 @@ abstract class Expression extends TreeNode[Expression] { def references: Set[Attribute] /** Returns the result of evaluating this expression on a given input Row */ - def apply(input: Row = null): EvaluatedType = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + def eval(input: Row = null): EvaluatedType /** * Returns `true` if this expression and all its children have been resolved to a specific schema @@ -73,7 +72,7 @@ abstract class Expression extends TreeNode[Expression] { */ @inline def n1(e: Expression, i: Row, f: ((Numeric[Any], Any) => Any)): Any = { - val evalE = e.apply(i) + val evalE = e.eval(i) if (evalE == null) { null } else { @@ -86,6 +85,11 @@ abstract class Expression extends TreeNode[Expression] { } } + /** + * Evaluation helper function for 2 Numeric children expressions. Those expressions are supposed + * to be in the same data type, and also the return type. + * Either one of the expressions result is null, the evaluation result should be null. + */ @inline protected final def n2( i: Row, @@ -97,11 +101,11 @@ abstract class Expression extends TreeNode[Expression] { throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}") } - val evalE1 = e1.apply(i) + val evalE1 = e1.eval(i) if(evalE1 == null) { null } else { - val evalE2 = e2.apply(i) + val evalE2 = e2.eval(i) if (evalE2 == null) { null } else { @@ -115,6 +119,11 @@ abstract class Expression extends TreeNode[Expression] { } } + /** + * Evaluation helper function for 2 Fractional children expressions. Those expressions are + * supposed to be in the same data type, and also the return type. + * Either one of the expressions result is null, the evaluation result should be null. + */ @inline protected final def f2( i: Row, @@ -125,11 +134,11 @@ abstract class Expression extends TreeNode[Expression] { throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}") } - val evalE1 = e1.apply(i: Row) + val evalE1 = e1.eval(i: Row) if(evalE1 == null) { null } else { - val evalE2 = e2.apply(i: Row) + val evalE2 = e2.eval(i: Row) if (evalE2 == null) { null } else { @@ -143,6 +152,11 @@ abstract class Expression extends TreeNode[Expression] { } } + /** + * Evaluation helper function for 2 Integral children expressions. Those expressions are + * supposed to be in the same data type, and also the return type. + * Either one of the expressions result is null, the evaluation result should be null. + */ @inline protected final def i2( i: Row, @@ -153,11 +167,11 @@ abstract class Expression extends TreeNode[Expression] { throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}") } - val evalE1 = e1.apply(i) + val evalE1 = e1.eval(i) if(evalE1 == null) { null } else { - val evalE2 = e2.apply(i) + val evalE2 = e2.eval(i) if (evalE2 == null) { null } else { @@ -170,6 +184,43 @@ abstract class Expression extends TreeNode[Expression] { } } } + + /** + * Evaluation helper function for 2 Comparable children expressions. Those expressions are + * supposed to be in the same data type, and the return type should be Integer: + * Negative value: 1st argument less than 2nd argument + * Zero: 1st argument equals 2nd argument + * Positive value: 1st argument greater than 2nd argument + * + * Either one of the expressions result is null, the evaluation result should be null. + */ + @inline + protected final def c2( + i: Row, + e1: Expression, + e2: Expression, + f: ((Ordering[Any], Any, Any) => Any)): Any = { + if (e1.dataType != e2.dataType) { + throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}") + } + + val evalE1 = e1.eval(i) + if(evalE1 == null) { + null + } else { + val evalE2 = e2.eval(i) + if (evalE2 == null) { + null + } else { + e1.dataType match { + case i: NativeType => + f.asInstanceOf[(Ordering[i.JvmType], i.JvmType, i.JvmType) => Boolean]( + i.ordering, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType]) + case other => sys.error(s"Type $other does not support ordered operations") + } + } + } + } } abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] { @@ -179,7 +230,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def foldable = left.foldable && right.foldable - def references = left.references ++ right.references + override def references = left.references ++ right.references override def toString = s"($left $symbol $right)" } @@ -191,5 +242,5 @@ abstract class LeafExpression extends Expression with trees.LeafNode[Expression] abstract class UnaryExpression extends Expression with trees.UnaryNode[Expression] { self: Product => - def references = child.references + override def references = child.references } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 38542d3fc7290..c9b7cea6a3e5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -27,11 +27,12 @@ class Projection(expressions: Seq[Expression]) extends (Row => Row) { this(expressions.map(BindReferences.bindReference(_, inputSchema))) protected val exprArray = expressions.toArray + def apply(input: Row): Row = { - val outputArray = new Array[Any](exprArray.size) + val outputArray = new Array[Any](exprArray.length) var i = 0 - while (i < exprArray.size) { - outputArray(i) = exprArray(i).apply(input) + while (i < exprArray.length) { + outputArray(i) = exprArray(i).eval(input) i += 1 } new GenericRow(outputArray) @@ -57,8 +58,8 @@ case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) def apply(input: Row): Row = { var i = 0 - while (i < exprArray.size) { - mutableRow(i) = exprArray(i).apply(input) + while (i < exprArray.length) { + mutableRow(i) = exprArray(i).eval(input) i += 1 } mutableRow diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala index 0bde621602944..38f836f0a1a0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Rand.scala @@ -17,11 +17,18 @@ package org.apache.spark.sql.catalyst.expressions +import java.util.Random import org.apache.spark.sql.catalyst.types.DoubleType + case object Rand extends LeafExpression { - def dataType = DoubleType - def nullable = false - def references = Set.empty + override def dataType = DoubleType + override def nullable = false + override def references = Set.empty + + private[this] lazy val rand = new Random + + override def eval(input: Row = null) = rand.nextDouble().asInstanceOf[EvaluatedType] + override def toString = "RAND()" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index 6f939e6c41f6b..77b5429bad432 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -19,6 +19,21 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.types.NativeType +object Row { + /** + * This method can be used to extract fields from a [[Row]] object in a pattern match. Example: + * {{{ + * import org.apache.spark.sql._ + * + * val pairs = sql("SELECT key, value FROM src").rdd.map { + * case Row(key: Int, value: String) => + * key -> value + * } + * }}} + */ + def unapplySeq(row: Row): Some[Seq[Any]] = Some(row) +} + /** * Represents one row of output from a relational operator. Allows both generic access by ordinal, * which will incur boxing overhead for primitives, as well as native primitive access. @@ -75,7 +90,7 @@ trait MutableRow extends Row { def setString(ordinal: Int, value: String) /** - * EXPERIMENTAL + * Experimental * * Returns a mutable string builder for the specified column. A given row should return the * result of any mutations made to the returned buffer next time getString is called for the same @@ -197,8 +212,8 @@ class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { var i = 0 while (i < ordering.size) { val order = ordering(i) - val left = order.child.apply(a) - val right = order.child.apply(b) + val left = order.child.eval(a) + val right = order.child.eval(b) if (left == null && right == null) { // Both null, continue looking. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index f53d8504b083f..5e089f7618e0a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -27,13 +27,13 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi def references = children.flatMap(_.references).toSet def nullable = true - override def apply(input: Row): Any = { + override def eval(input: Row): Any = { children.size match { - case 1 => function.asInstanceOf[(Any) => Any](children(0).apply(input)) + case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input)) case 2 => function.asInstanceOf[(Any, Any) => Any]( - children(0).apply(input), - children(1).apply(input)) + children(0).eval(input), + children(1).eval(input)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index d5d93778f4b8d..08b2f11d20f5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.errors.TreeNodeException + abstract sealed class SortDirection case object Ascending extends SortDirection case object Descending extends SortDirection @@ -26,7 +28,12 @@ case object Descending extends SortDirection * transformations over expression will descend into its child. */ case class SortOrder(child: Expression, direction: SortDirection) extends UnaryExpression { - def dataType = child.dataType - def nullable = child.nullable + override def dataType = child.dataType + override def nullable = child.nullable + + // SortOrder itself is never evaluated. + override def eval(input: Row = null): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + override def toString = s"$child ${if (direction == Ascending) "ASC" else "DESC"}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala index 9828d0b9bd8b2..e787c59e75723 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala @@ -30,7 +30,7 @@ case class WrapDynamic(children: Seq[Attribute]) extends Expression { def references = children.toSet def dataType = DynamicType - override def apply(input: Row): DynamicRow = input match { + override def eval(input: Row): DynamicRow = input match { // Avoid copy for generic rows. case g: GenericRow => new DynamicRow(children, g.values) case otherRowType => new DynamicRow(children, otherRowType.toArray) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 7303b155cae3d..b152f95f96c70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.errors.TreeNodeException abstract class AggregateExpression extends Expression { self: Product => @@ -27,7 +28,14 @@ abstract class AggregateExpression extends Expression { * Creates a new instance that can be used to compute this aggregate expression for a group * of input rows/ */ - def newInstance: AggregateFunction + def newInstance(): AggregateFunction + + /** + * [[AggregateExpression.eval]] should never be invoked because [[AggregateExpression]]'s are + * replaced with a physical aggregate operator at runtime. + */ + override def eval(input: Row = null): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } /** @@ -43,7 +51,7 @@ case class SplitEvaluation( partialEvaluations: Seq[NamedExpression]) /** - * An [[AggregateExpression]] that can be partially computed without seeing all relevent tuples. + * An [[AggregateExpression]] that can be partially computed without seeing all relevant tuples. * These partial evaluations can then be combined to compute the actual answer. */ abstract class PartialAggregate extends AggregateExpression { @@ -63,48 +71,48 @@ abstract class AggregateFunction extends AggregateExpression with Serializable with trees.LeafNode[Expression] { self: Product => - type EvaluatedType = Any + override type EvaluatedType = Any /** Base should return the generic aggregate expression that this function is computing */ val base: AggregateExpression - def references = base.references - def nullable = base.nullable - def dataType = base.dataType + override def references = base.references + override def nullable = base.nullable + override def dataType = base.dataType def update(input: Row): Unit - override def apply(input: Row): Any + override def eval(input: Row): Any // Do we really need this? - def newInstance = makeCopy(productIterator.map { case a: AnyRef => a }.toArray) + override def newInstance() = makeCopy(productIterator.map { case a: AnyRef => a }.toArray) } case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - def references = child.references - def nullable = false - def dataType = IntegerType + override def references = child.references + override def nullable = false + override def dataType = IntegerType override def toString = s"COUNT($child)" - def asPartial: SplitEvaluation = { + override def asPartial: SplitEvaluation = { val partialCount = Alias(Count(child), "PartialCount")() SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil) } - override def newInstance = new CountFunction(child, this) + override def newInstance()= new CountFunction(child, this) } case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpression { - def children = expressions - def references = expressions.flatMap(_.references).toSet - def nullable = false - def dataType = IntegerType + override def children = expressions + override def references = expressions.flatMap(_.references).toSet + override def nullable = false + override def dataType = IntegerType override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")}})" - override def newInstance = new CountDistinctFunction(expressions, this) + override def newInstance()= new CountDistinctFunction(expressions, this) } case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - def references = child.references - def nullable = false - def dataType = DoubleType + override def references = child.references + override def nullable = false + override def dataType = DoubleType override def toString = s"AVG($child)" override def asPartial: SplitEvaluation = { @@ -118,13 +126,13 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN partialCount :: partialSum :: Nil) } - override def newInstance = new AverageFunction(child, this) + override def newInstance()= new AverageFunction(child, this) } case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - def references = child.references - def nullable = false - def dataType = child.dataType + override def references = child.references + override def nullable = false + override def dataType = child.dataType override def toString = s"SUM($child)" override def asPartial: SplitEvaluation = { @@ -134,24 +142,24 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ partialSum :: Nil) } - override def newInstance = new SumFunction(child, this) + override def newInstance()= new SumFunction(child, this) } case class SumDistinct(child: Expression) extends AggregateExpression with trees.UnaryNode[Expression] { - def references = child.references - def nullable = false - def dataType = child.dataType + override def references = child.references + override def nullable = false + override def dataType = child.dataType override def toString = s"SUM(DISTINCT $child)" - override def newInstance = new SumDistinctFunction(child, this) + override def newInstance()= new SumDistinctFunction(child, this) } case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - def references = child.references - def nullable = child.nullable - def dataType = child.dataType + override def references = child.references + override def nullable = child.nullable + override def dataType = child.dataType override def toString = s"FIRST($child)" override def asPartial: SplitEvaluation = { @@ -160,7 +168,7 @@ case class First(child: Expression) extends PartialAggregate with trees.UnaryNod First(partialFirst.toAttribute), partialFirst :: Nil) } - override def newInstance = new FirstFunction(child, this) + override def newInstance()= new FirstFunction(child, this) } case class AverageFunction(expr: Expression, base: AggregateExpression) @@ -169,17 +177,15 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) def this() = this(null, null) // Required for serialization. private var count: Long = _ - private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).apply(EmptyRow)) + private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).eval(EmptyRow)) private val sumAsDouble = Cast(sum, DoubleType) - - private val addFunction = Add(sum, expr) - override def apply(input: Row): Any = - sumAsDouble.apply(EmptyRow).asInstanceOf[Double] / count.toDouble + override def eval(input: Row): Any = + sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble - def update(input: Row): Unit = { + override def update(input: Row): Unit = { count += 1 sum.update(addFunction, input) } @@ -190,28 +196,28 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag var count: Int = _ - def update(input: Row): Unit = { - val evaluatedExpr = expr.map(_.apply(input)) + override def update(input: Row): Unit = { + val evaluatedExpr = expr.map(_.eval(input)) if (evaluatedExpr.map(_ != null).reduceLeft(_ || _)) { count += 1 } } - override def apply(input: Row): Any = count + override def eval(input: Row): Any = count } case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { def this() = this(null, null) // Required for serialization. - private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).apply(null)) + private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).eval(null)) private val addFunction = Add(sum, expr) - def update(input: Row): Unit = { + override def update(input: Row): Unit = { sum.update(addFunction, input) } - override def apply(input: Row): Any = sum.apply(null) + override def eval(input: Row): Any = sum.eval(null) } case class SumDistinctFunction(expr: Expression, base: AggregateExpression) @@ -219,16 +225,16 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression) def this() = this(null, null) // Required for serialization. - val seen = new scala.collection.mutable.HashSet[Any]() + private val seen = new scala.collection.mutable.HashSet[Any]() - def update(input: Row): Unit = { - val evaluatedExpr = expr.apply(input) + override def update(input: Row): Unit = { + val evaluatedExpr = expr.eval(input) if (evaluatedExpr != null) { seen += evaluatedExpr } } - override def apply(input: Row): Any = + override def eval(input: Row): Any = seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus) } @@ -239,14 +245,14 @@ case class CountDistinctFunction(expr: Seq[Expression], base: AggregateExpressio val seen = new scala.collection.mutable.HashSet[Any]() - def update(input: Row): Unit = { - val evaluatedExpr = expr.map(_.apply(input)) + override def update(input: Row): Unit = { + val evaluatedExpr = expr.map(_.eval(input)) if (evaluatedExpr.map(_ != null).reduceLeft(_ && _)) { seen += evaluatedExpr } } - override def apply(input: Row): Any = seen.size + override def eval(input: Row): Any = seen.size } case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -254,11 +260,11 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag var result: Any = null - def update(input: Row): Unit = { + override def update(input: Row): Unit = { if (result == null) { - result = expr.apply(input) + result = expr.eval(input) } } - override def apply(input: Row): Any = result + override def eval(input: Row): Any = result } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index fba056e7c07e3..c79c1847cedf5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -28,7 +28,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression { def nullable = child.nullable override def toString = s"-$child" - override def apply(input: Row): Any = { + override def eval(input: Row): Any = { n1(child, input, _.negate(_)) } } @@ -55,25 +55,25 @@ abstract class BinaryArithmetic extends BinaryExpression { case class Add(left: Expression, right: Expression) extends BinaryArithmetic { def symbol = "+" - override def apply(input: Row): Any = n2(input, left, right, _.plus(_, _)) + override def eval(input: Row): Any = n2(input, left, right, _.plus(_, _)) } case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { def symbol = "-" - override def apply(input: Row): Any = n2(input, left, right, _.minus(_, _)) + override def eval(input: Row): Any = n2(input, left, right, _.minus(_, _)) } case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { def symbol = "*" - override def apply(input: Row): Any = n2(input, left, right, _.times(_, _)) + override def eval(input: Row): Any = n2(input, left, right, _.times(_, _)) } case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { def symbol = "/" - override def apply(input: Row): Any = dataType match { + override def eval(input: Row): Any = dataType match { case _: FractionalType => f2(input, left, right, _.div(_, _)) case _: IntegralType => i2(input, left , right, _.quot(_, _)) } @@ -83,5 +83,5 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { def symbol = "%" - override def apply(input: Row): Any = i2(input, left, right, _.rem(_, _)) + override def eval(input: Row): Any = i2(input, left, right, _.rem(_, _)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index ab96618d73df7..c947155cb701c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -39,10 +39,10 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { override def toString = s"$child[$ordinal]" - override def apply(input: Row): Any = { + override def eval(input: Row): Any = { if (child.dataType.isInstanceOf[ArrayType]) { - val baseValue = child.apply(input).asInstanceOf[Seq[_]] - val o = ordinal.apply(input).asInstanceOf[Int] + val baseValue = child.eval(input).asInstanceOf[Seq[_]] + val o = ordinal.eval(input).asInstanceOf[Int] if (baseValue == null) { null } else if (o >= baseValue.size || o < 0) { @@ -51,8 +51,8 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { baseValue(o) } } else { - val baseValue = child.apply(input).asInstanceOf[Map[Any, _]] - val key = ordinal.apply(input) + val baseValue = child.eval(input).asInstanceOf[Map[Any, _]] + val key = ordinal.eval(input) if (baseValue == null) { null } else { @@ -85,8 +85,8 @@ case class GetField(child: Expression, fieldName: String) extends UnaryExpressio override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[StructType] - override def apply(input: Row): Any = { - val baseValue = child.apply(input).asInstanceOf[Row] + override def eval(input: Row): Any = { + val baseValue = child.eval(input).asInstanceOf[Row] if (baseValue == null) null else baseValue(ordinal) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index e9b491b10a5f2..dd78614754e12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -35,17 +35,17 @@ import org.apache.spark.sql.catalyst.types._ * requested. The attributes produced by this function will be automatically copied anytime rules * result in changes to the Generator or its children. */ -abstract class Generator extends Expression with (Row => TraversableOnce[Row]) { +abstract class Generator extends Expression { self: Product => - type EvaluatedType = TraversableOnce[Row] + override type EvaluatedType = TraversableOnce[Row] - lazy val dataType = + override lazy val dataType = ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable)))) - def nullable = false + override def nullable = false - def references = children.flatMap(_.references).toSet + override def references = children.flatMap(_.references).toSet /** * Should be overridden by specific generators. Called only once for each instance to ensure @@ -63,7 +63,7 @@ abstract class Generator extends Expression with (Row => TraversableOnce[Row]) { } /** Should be implemented by child classes to perform specific Generators. */ - def apply(input: Row): TraversableOnce[Row] + override def eval(input: Row): TraversableOnce[Row] /** Overridden `makeCopy` also copies the attributes that are produced by this generator. */ override def makeCopy(newArgs: Array[AnyRef]): this.type = { @@ -83,7 +83,7 @@ case class Explode(attributeNames: Seq[String], child: Expression) child.resolved && (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) - lazy val elementTypes = child.dataType match { + private lazy val elementTypes = child.dataType match { case ArrayType(et) => et :: Nil case MapType(kt,vt) => kt :: vt :: Nil } @@ -100,13 +100,13 @@ case class Explode(attributeNames: Seq[String], child: Expression) } } - override def apply(input: Row): TraversableOnce[Row] = { + override def eval(input: Row): TraversableOnce[Row] = { child.dataType match { case ArrayType(_) => - val inputArray = child.apply(input).asInstanceOf[Seq[Any]] + val inputArray = child.eval(input).asInstanceOf[Seq[Any]] if (inputArray == null) Nil else inputArray.map(v => new GenericRow(Array(v))) case MapType(_, _) => - val inputMap = child.apply(input).asInstanceOf[Map[Any,Any]] + val inputMap = child.eval(input).asInstanceOf[Map[Any,Any]] if (inputMap == null) Nil else inputMap.map { case (k,v) => new GenericRow(Array(k,v)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index b82a12e0f754e..e15e16d633365 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.Timestamp + import org.apache.spark.sql.catalyst.types._ object Literal { @@ -29,6 +31,9 @@ object Literal { case s: Short => Literal(s, ShortType) case s: String => Literal(s, StringType) case b: Boolean => Literal(b, BooleanType) + case d: BigDecimal => Literal(d, DecimalType) + case t: Timestamp => Literal(t, TimestampType) + case a: Array[Byte] => Literal(a, BinaryType) case null => Literal(null, NullType) } } @@ -52,7 +57,7 @@ case class Literal(value: Any, dataType: DataType) extends LeafExpression { override def toString = if (value != null) value.toString else "null" type EvaluatedType = Any - override def apply(input: Row):Any = value + override def eval(input: Row):Any = value } // TODO: Specialize @@ -64,8 +69,8 @@ case class MutableLiteral(var value: Any, nullable: Boolean = true) extends Leaf def references = Set.empty def update(expression: Expression, input: Row) = { - value = expression.apply(input) + value = expression.eval(input) } - override def apply(input: Row) = value + override def eval(input: Row) = value } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 69c8bed309c18..a8145c37c20fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.types._ object NamedExpression { @@ -58,9 +59,9 @@ abstract class Attribute extends NamedExpression { def withQualifiers(newQualifiers: Seq[String]): Attribute - def references = Set(this) def toAttribute = this def newInstance: Attribute + override def references = Set(this) } /** @@ -77,15 +78,15 @@ case class Alias(child: Expression, name: String) (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil) extends NamedExpression with trees.UnaryNode[Expression] { - type EvaluatedType = Any + override type EvaluatedType = Any - override def apply(input: Row) = child.apply(input) + override def eval(input: Row) = child.eval(input) - def dataType = child.dataType - def nullable = child.nullable - def references = child.references + override def dataType = child.dataType + override def nullable = child.nullable + override def references = child.references - def toAttribute = { + override def toAttribute = { if (resolved) { AttributeReference(name, child.dataType, child.nullable)(exprId, qualifiers) } else { @@ -127,7 +128,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea h } - def newInstance = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers) + override def newInstance = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers) /** * Returns a copy of this [[AttributeReference]] with changed nullability. @@ -143,7 +144,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea /** * Returns a copy of this [[AttributeReference]] with new qualifiers. */ - def withQualifiers(newQualifiers: Seq[String]) = { + override def withQualifiers(newQualifiers: Seq[String]) = { if (newQualifiers == qualifiers) { this } else { @@ -151,5 +152,9 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea } } + // Unresolved attributes are transient at compile time and don't get evaluated during execution. + override def eval(input: Row = null): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + override def toString: String = s"$name#${exprId.id}$typeSuffix" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 5a47768dcb4a1..ce6d99c911ab3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -41,11 +41,11 @@ case class Coalesce(children: Seq[Expression]) extends Expression { throw new UnresolvedException(this, "Coalesce cannot have children of different types.") } - override def apply(input: Row): Any = { + override def eval(input: Row): Any = { var i = 0 var result: Any = null while(i < children.size && result == null) { - result = children(i).apply(input) + result = children(i).eval(input) i += 1 } result @@ -57,8 +57,8 @@ case class IsNull(child: Expression) extends Predicate with trees.UnaryNode[Expr override def foldable = child.foldable def nullable = false - override def apply(input: Row): Any = { - child.apply(input) == null + override def eval(input: Row): Any = { + child.eval(input) == null } } @@ -68,7 +68,7 @@ case class IsNotNull(child: Expression) extends Predicate with trees.UnaryNode[E def nullable = false override def toString = s"IS NOT NULL $child" - override def apply(input: Row): Any = { - child.apply(input) != null + override def eval(input: Row): Any = { + child.eval(input) != null } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 02fedd16b8d4b..da5b2cf5b0362 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -18,12 +18,13 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.types.{BooleanType, StringType} +import org.apache.spark.sql.catalyst.types.{BooleanType, StringType, TimestampType} object InterpretedPredicate { def apply(expression: Expression): (Row => Boolean) = { - (r: Row) => expression.apply(r).asInstanceOf[Boolean] + (r: Row) => expression.eval(r).asInstanceOf[Boolean] } } @@ -53,8 +54,8 @@ case class Not(child: Expression) extends Predicate with trees.UnaryNode[Express def nullable = child.nullable override def toString = s"NOT $child" - override def apply(input: Row): Any = { - child.apply(input) match { + override def eval(input: Row): Any = { + child.eval(input) match { case null => null case b: Boolean => !b } @@ -70,18 +71,18 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate { def nullable = true // TODO: Figure out correct nullability semantics of IN. override def toString = s"$value IN ${list.mkString("(", ",", ")")}" - override def apply(input: Row): Any = { - val evaluatedValue = value.apply(input) - list.exists(e => e.apply(input) == evaluatedValue) + override def eval(input: Row): Any = { + val evaluatedValue = value.eval(input) + list.exists(e => e.eval(input) == evaluatedValue) } } case class And(left: Expression, right: Expression) extends BinaryPredicate { def symbol = "&&" - override def apply(input: Row): Any = { - val l = left.apply(input) - val r = right.apply(input) + override def eval(input: Row): Any = { + val l = left.eval(input) + val r = right.eval(input) if (l == false || r == false) { false } else if (l == null || r == null ) { @@ -95,9 +96,9 @@ case class And(left: Expression, right: Expression) extends BinaryPredicate { case class Or(left: Expression, right: Expression) extends BinaryPredicate { def symbol = "||" - override def apply(input: Row): Any = { - val l = left.apply(input) - val r = right.apply(input) + override def eval(input: Row): Any = { + val l = left.eval(input) + val r = right.eval(input) if (l == true || r == true) { true } else if (l == null || r == null) { @@ -114,79 +115,31 @@ abstract class BinaryComparison extends BinaryPredicate { case class Equals(left: Expression, right: Expression) extends BinaryComparison { def symbol = "=" - override def apply(input: Row): Any = { - val l = left.apply(input) - val r = right.apply(input) + override def eval(input: Row): Any = { + val l = left.eval(input) + val r = right.eval(input) if (l == null || r == null) null else l == r } } case class LessThan(left: Expression, right: Expression) extends BinaryComparison { def symbol = "<" - override def apply(input: Row): Any = { - if (left.dataType == StringType && right.dataType == StringType) { - val l = left.apply(input) - val r = right.apply(input) - if(l == null || r == null) { - null - } else { - l.asInstanceOf[String] < r.asInstanceOf[String] - } - } else { - n2(input, left, right, _.lt(_, _)) - } - } + override def eval(input: Row): Any = c2(input, left, right, _.lt(_, _)) } case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { def symbol = "<=" - override def apply(input: Row): Any = { - if (left.dataType == StringType && right.dataType == StringType) { - val l = left.apply(input) - val r = right.apply(input) - if(l == null || r == null) { - null - } else { - l.asInstanceOf[String] <= r.asInstanceOf[String] - } - } else { - n2(input, left, right, _.lteq(_, _)) - } - } + override def eval(input: Row): Any = c2(input, left, right, _.lteq(_, _)) } case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { def symbol = ">" - override def apply(input: Row): Any = { - if (left.dataType == StringType && right.dataType == StringType) { - val l = left.apply(input) - val r = right.apply(input) - if(l == null || r == null) { - null - } else { - l.asInstanceOf[String] > r.asInstanceOf[String] - } - } else { - n2(input, left, right, _.gt(_, _)) - } - } + override def eval(input: Row): Any = c2(input, left, right, _.gt(_, _)) } case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { def symbol = ">=" - override def apply(input: Row): Any = { - if (left.dataType == StringType && right.dataType == StringType) { - val l = left.apply(input) - val r = right.apply(input) - if(l == null || r == null) { - null - } else { - l.asInstanceOf[String] >= r.asInstanceOf[String] - } - } else { - n2(input, left, right, _.gteq(_, _)) - } - } + override def eval(input: Row): Any = c2(input, left, right, _.gteq(_, _)) } case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) @@ -206,11 +159,11 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } type EvaluatedType = Any - override def apply(input: Row): Any = { - if (predicate(input).asInstanceOf[Boolean]) { - trueValue.apply(input) + override def eval(input: Row): Any = { + if (predicate.eval(input).asInstanceOf[Boolean]) { + trueValue.eval(input) } else { - falseValue.apply(input) + falseValue.eval(input) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 42b7a9b125b7a..ddc16ce87b895 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -22,27 +22,25 @@ import java.util.regex.Pattern import org.apache.spark.sql.catalyst.types.DataType import org.apache.spark.sql.catalyst.types.StringType import org.apache.spark.sql.catalyst.types.BooleanType -import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.catalyst.errors.`package`.TreeNodeException trait StringRegexExpression { self: BinaryExpression => type EvaluatedType = Any - + def escape(v: String): String def matches(regex: Pattern, str: String): Boolean - + def nullable: Boolean = true def dataType: DataType = BooleanType - - // try cache the pattern for Literal + + // try cache the pattern for Literal private lazy val cache: Pattern = right match { case x @ Literal(value: String, StringType) => compile(value) case _ => null } - + protected def compile(str: String): Pattern = if(str == null) { null } else { @@ -51,13 +49,13 @@ trait StringRegexExpression { } protected def pattern(str: String) = if(cache == null) compile(str) else cache - - override def apply(input: Row): Any = { - val l = left.apply(input) - if(l == null) { + + override def eval(input: Row): Any = { + val l = left.eval(input) + if (l == null) { null } else { - val r = right.apply(input) + val r = right.eval(input) if(r == null) { null } else { @@ -75,11 +73,11 @@ trait StringRegexExpression { /** * Simple RegEx pattern matching function */ -case class Like(left: Expression, right: Expression) +case class Like(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - + def symbol = "LIKE" - + // replace the _ with .{1} exactly match 1 time of any character // replace the % with .*, match 0 or more times with any character override def escape(v: String) = { @@ -100,19 +98,19 @@ case class Like(left: Expression, right: Expression) sb.append(Pattern.quote(Character.toString(n))); } } - + i += 1 } - + sb.toString() } - + override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() } -case class RLike(left: Expression, right: Expression) +case class RLike(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { - + def symbol = "RLIKE" override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3dd6818029bcf..c0a09a16ac98d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -33,7 +33,56 @@ object Optimizer extends RuleExecutor[LogicalPlan] { Batch("Filter Pushdown", Once, CombineFilters, PushPredicateThroughProject, - PushPredicateThroughInnerJoin) :: Nil + PushPredicateThroughInnerJoin, + ColumnPruning) :: Nil +} + +/** + * Attempts to eliminate the reading of unneeded columns from the query plan using the following + * transformations: + * + * - Inserting Projections beneath the following operators: + * - Aggregate + * - Project <- Join + * - Collapse adjacent projections, performing alias substitution. + */ +object ColumnPruning extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => + // Project away references that are not needed to calculate the required aggregates. + a.copy(child = Project(a.references.toSeq, child)) + + case Project(projectList, Join(left, right, joinType, condition)) => + // Collect the list of off references required either above or to evaluate the condition. + val allReferences: Set[Attribute] = + projectList.flatMap(_.references).toSet ++ condition.map(_.references).getOrElse(Set.empty) + /** Applies a projection when the child is producing unnecessary attributes */ + def prunedChild(c: LogicalPlan) = + if ((allReferences.filter(c.outputSet.contains) -- c.outputSet).nonEmpty) { + Project(allReferences.filter(c.outputSet.contains).toSeq, c) + } else { + c + } + + Project(projectList, Join(prunedChild(left), prunedChild(right), joinType, condition)) + + case Project(projectList1, Project(projectList2, child)) => + // Create a map of Aliases to their values from the child projection. + // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)). + val aliasMap = projectList2.collect { + case a @ Alias(e, _) => (a.toAttribute: Expression, a) + }.toMap + + // Substitute any attributes that are produced by the child projection, so that we safely + // eliminate it. + // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...' + // TODO: Fix TransformBase to avoid the cast below. + val substitutedProjection = projectList1.map(_.transform { + case a if aliasMap.contains(a) => aliasMap(a) + }).asInstanceOf[Seq[NamedExpression]] + + Project(substitutedProjection, child) + } } /** @@ -45,7 +94,7 @@ object ConstantFolding extends Rule[LogicalPlan] { case q: LogicalPlan => q transformExpressionsDown { // Skip redundant folding of literals. case l: Literal => l - case e if e.foldable => Literal(e.apply(null), e.dataType) + case e if e.foldable => Literal(e.eval(null), e.dataType) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 9d16189deedfe..397473e178867 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -127,10 +127,10 @@ case class Aggregate( extends UnaryNode { def output = aggregateExpressions.map(_.toAttribute) - def references = child.references + def references = (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet } -case class StopAfter(limit: Expression, child: LogicalPlan) extends UnaryNode { +case class Limit(limit: Expression, child: LogicalPlan) extends UnaryNode { def output = child.output def references = limit.references } @@ -162,6 +162,7 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode { a.nullable)( a.exprId, a.qualifiers) + case other => other } def references = Set.empty diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 8893744eb2e7a..ffb3a92f8f340 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.plans.physical -import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder} +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.expressions.{Expression, Row, SortOrder} import org.apache.spark.sql.catalyst.types.IntegerType /** @@ -139,12 +140,12 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) extends Expression with Partitioning { - def children = expressions - def references = expressions.flatMap(_.references).toSet - def nullable = false - def dataType = IntegerType + override def children = expressions + override def references = expressions.flatMap(_.references).toSet + override def nullable = false + override def dataType = IntegerType - lazy val clusteringSet = expressions.toSet + private[this] lazy val clusteringSet = expressions.toSet override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true @@ -158,6 +159,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case h: HashPartitioning if h == this => true case _ => false } + + override def eval(input: Row = null): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } /** @@ -168,17 +172,20 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) * partition. * - Each partition will have a `min` and `max` row, relative to the given ordering. All rows * that are in between `min` and `max` in this `ordering` will reside in this partition. + * + * This class extends expression primarily so that transformations over expression will descend + * into its child. */ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) extends Expression with Partitioning { - def children = ordering - def references = ordering.flatMap(_.references).toSet - def nullable = false - def dataType = IntegerType + override def children = ordering + override def references = ordering.flatMap(_.references).toSet + override def nullable = false + override def dataType = IntegerType - lazy val clusteringSet = ordering.map(_.child).toSet + private[this] lazy val clusteringSet = ordering.map(_.child).toSet override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true @@ -195,4 +202,7 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case r: RangePartitioning if r == this => true case _ => false } + + override def eval(input: Row): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 7a45d1a1b8195..da34bd3a21503 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.types +import java.sql.Timestamp + import scala.reflect.runtime.universe.{typeTag, TypeTag} import org.apache.spark.sql.catalyst.expressions.Expression @@ -51,6 +53,16 @@ case object BooleanType extends NativeType { val ordering = implicitly[Ordering[JvmType]] } +case object TimestampType extends NativeType { + type JvmType = Timestamp + + @transient lazy val tag = typeTag[JvmType] + + val ordering = new Ordering[JvmType] { + def compare(x: Timestamp, y: Timestamp) = x.compareTo(y) + } +} + abstract class NumericType extends NativeType { // Unfortunately we can't get this implicitly as that breaks Spark Serialization. In order for // implicitly[Numeric[JvmType]] to be valid, we have to change JvmType from a type variable to a diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index a001d953592db..49fc4f70fdfae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst import java.io.{PrintWriter, ByteArrayOutputStream, FileInputStream, File} +import org.apache.spark.util.{Utils => SparkUtils} + package object util { /** * Returns a path to a temporary file that probably does not exist. @@ -54,7 +56,7 @@ package object util { def resourceToString( resource:String, encoding: String = "UTF-8", - classLoader: ClassLoader = this.getClass.getClassLoader) = { + classLoader: ClassLoader = SparkUtils.getSparkClassLoader) = { val inStream = classLoader.getResourceAsStream(resource) val outStream = new ByteArrayOutputStream try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala index 9ec31689b5098..4589129cd1c90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala @@ -32,18 +32,5 @@ package object sql { type Row = catalyst.expressions.Row - object Row { - /** - * This method can be used to extract fields from a [[Row]] object in a pattern match. Example: - * {{{ - * import org.apache.spark.sql._ - * - * val pairs = sql("SELECT key, value FROM src").rdd.map { - * case Row(key: Int, value: String) => - * key -> value - * } - * }}} - */ - def unapplySeq(row: Row): Some[Seq[Any]] = Some(row) - } + val Row = catalyst.expressions.Row } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala index 52a205be3e9f4..2cd0d2b0e1385 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.sql.Timestamp + import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.types._ @@ -27,7 +29,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ class ExpressionEvaluationSuite extends FunSuite { test("literals") { - assert((Literal(1) + Literal(1)).apply(null) === 2) + assert((Literal(1) + Literal(1)).eval(null) === 2) } /** @@ -60,7 +62,7 @@ class ExpressionEvaluationSuite extends FunSuite { notTrueTable.foreach { case (v, answer) => val expr = Not(Literal(v, BooleanType)) - val result = expr.apply(null) + val result = expr.eval(null) if (result != answer) fail(s"$expr should not evaluate to $result, expected: $answer") } } @@ -98,12 +100,15 @@ class ExpressionEvaluationSuite extends FunSuite { (null, false, null) :: (null, null, null) :: Nil) - def booleanLogicTest(name: String, op: (Expression, Expression) => Expression, truthTable: Seq[(Any, Any, Any)]) { + def booleanLogicTest( + name: String, + op: (Expression, Expression) => Expression, + truthTable: Seq[(Any, Any, Any)]) { test(s"3VL $name") { truthTable.foreach { case (l,r,answer) => val expr = op(Literal(l, BooleanType), Literal(r, BooleanType)) - val result = expr.apply(null) + val result = expr.eval(null) if (result != answer) fail(s"$expr should not evaluate to $result, expected: $answer") } @@ -111,7 +116,7 @@ class ExpressionEvaluationSuite extends FunSuite { } def evaluate(expression: Expression, inputRow: Row = EmptyRow): Any = { - expression.apply(inputRow) + expression.eval(inputRow) } def checkEvaluation(expression: Expression, expected: Any, inputRow: Row = EmptyRow): Unit = { @@ -139,7 +144,7 @@ class ExpressionEvaluationSuite extends FunSuite { checkEvaluation("abc" like "b%", false) checkEvaluation("abc" like "bc%", false) } - + test("LIKE Non-literal Regular Expression") { val regEx = 'a.string.at(0) checkEvaluation("abcd" like regEx, null, new GenericRow(Array[Any](null))) @@ -159,7 +164,7 @@ class ExpressionEvaluationSuite extends FunSuite { test("RLIKE literal Regular Expression") { checkEvaluation("abdef" rlike "abdef", true) checkEvaluation("abbbbc" rlike "a.*c", true) - + checkEvaluation("fofo" rlike "^fo", true) checkEvaluation("fo\no" rlike "^fo\no$", true) checkEvaluation("Bn" rlike "^Ba*n", true) @@ -191,5 +196,77 @@ class ExpressionEvaluationSuite extends FunSuite { evaluate("abbbbc" rlike regEx, new GenericRow(Array[Any]("**"))) } } + + test("data type casting") { + + val sts = "1970-01-01 00:00:01.0" + val ts = Timestamp.valueOf(sts) + + checkEvaluation("abdef" cast StringType, "abdef") + checkEvaluation("abdef" cast DecimalType, null) + checkEvaluation("abdef" cast TimestampType, null) + checkEvaluation("12.65" cast DecimalType, BigDecimal(12.65)) + + checkEvaluation(Literal(1) cast LongType, 1) + checkEvaluation(Cast(Literal(1) cast TimestampType, LongType), 1) + checkEvaluation(Cast(Literal(1.toDouble) cast TimestampType, DoubleType), 1.toDouble) + + checkEvaluation(Cast(Literal(sts) cast TimestampType, StringType), sts) + checkEvaluation(Cast(Literal(ts) cast StringType, TimestampType), ts) + + checkEvaluation(Cast("abdef" cast BinaryType, StringType), "abdef") + + checkEvaluation(Cast(Cast(Cast(Cast( + Cast("5" cast ByteType, ShortType), IntegerType), FloatType), DoubleType), LongType), 5) + checkEvaluation(Cast(Cast(Cast(Cast( + Cast("5" cast ByteType, TimestampType), DecimalType), LongType), StringType), ShortType), 5) + checkEvaluation(Cast(Cast(Cast(Cast( + Cast("5" cast TimestampType, ByteType), DecimalType), LongType), StringType), ShortType), null) + checkEvaluation(Cast(Cast(Cast(Cast( + Cast("5" cast DecimalType, ByteType), TimestampType), LongType), StringType), ShortType), 5) + checkEvaluation(Literal(true) cast IntegerType, 1) + checkEvaluation(Literal(false) cast IntegerType, 0) + checkEvaluation(Cast(Literal(1) cast BooleanType, IntegerType), 1) + checkEvaluation(Cast(Literal(0) cast BooleanType, IntegerType), 0) + checkEvaluation("23" cast DoubleType, 23) + checkEvaluation("23" cast IntegerType, 23) + checkEvaluation("23" cast FloatType, 23) + checkEvaluation("23" cast DecimalType, 23) + checkEvaluation("23" cast ByteType, 23) + checkEvaluation("23" cast ShortType, 23) + checkEvaluation("2012-12-11" cast DoubleType, null) + checkEvaluation(Literal(123) cast IntegerType, 123) + + intercept[Exception] {evaluate(Literal(1) cast BinaryType, null)} + } + + test("timestamp") { + val ts1 = new Timestamp(12) + val ts2 = new Timestamp(123) + checkEvaluation(Literal("ab") < Literal("abc"), true) + checkEvaluation(Literal(ts1) < Literal(ts2), true) + } + + test("timestamp casting") { + val millis = 15 * 1000 + 2 + val ts = new Timestamp(millis) + val ts1 = new Timestamp(15 * 1000) // a timestamp without the milliseconds part + checkEvaluation(Cast(ts, ShortType), 15) + checkEvaluation(Cast(ts, IntegerType), 15) + checkEvaluation(Cast(ts, LongType), 15) + checkEvaluation(Cast(ts, FloatType), 15.002f) + checkEvaluation(Cast(ts, DoubleType), 15.002) + checkEvaluation(Cast(Cast(ts, ShortType), TimestampType), ts1) + checkEvaluation(Cast(Cast(ts, IntegerType), TimestampType), ts1) + checkEvaluation(Cast(Cast(ts, LongType), TimestampType), ts1) + checkEvaluation(Cast(Cast(millis.toFloat / 1000, TimestampType), FloatType), + millis.toFloat / 1000) + checkEvaluation(Cast(Cast(millis.toDouble / 1000, TimestampType), DoubleType), + millis.toDouble / 1000) + checkEvaluation(Cast(Literal(BigDecimal(1)) cast TimestampType, DecimalType), 1) + + // A test for higher precision than millis + checkEvaluation(Cast(Cast(0.00000001, TimestampType), DoubleType), 0.00000001) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 2ab14f48ccc8a..20dfba847790c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.types.IntegerType +import org.apache.spark.sql.catalyst.types.{DoubleType, IntegerType} // For implicit conversions import org.apache.spark.sql.catalyst.dsl.plans._ diff --git a/sql/core/src/main/scala/org/apache/spark/rdd/PartitionLocalRDDFunctions.scala b/sql/core/src/main/scala/org/apache/spark/rdd/PartitionLocalRDDFunctions.scala deleted file mode 100644 index f1230e7526ab1..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/rdd/PartitionLocalRDDFunctions.scala +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rdd - -import scala.language.implicitConversions - -import scala.reflect._ -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.{Aggregator, InterruptibleIterator, Logging} -import org.apache.spark.util.collection.AppendOnlyMap - -/* Implicit conversions */ -import org.apache.spark.SparkContext._ - -/** - * Extra functions on RDDs that perform only local operations. These can be used when data has - * already been partitioned correctly. - */ -private[spark] class PartitionLocalRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) - extends Logging - with Serializable { - - /** - * Cogroup corresponding partitions of `this` and `other`. These two RDDs should have - * the same number of partitions. Partitions of these two RDDs are cogrouped - * according to the indexes of partitions. If we have two RDDs and - * each of them has n partitions, we will cogroup the partition i from `this` - * with the partition i from `other`. - * This function will not introduce a shuffling operation. - */ - def cogroupLocally[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = { - val cg = self.zipPartitions(other)((iter1:Iterator[(K, V)], iter2:Iterator[(K, W)]) => { - val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]] - - val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => { - if (hadVal) oldVal else Array.fill(2)(new ArrayBuffer[Any]) - } - - val getSeq = (k: K) => { - map.changeValue(k, update) - } - - iter1.foreach { kv => getSeq(kv._1)(0) += kv._2 } - iter2.foreach { kv => getSeq(kv._1)(1) += kv._2 } - - map.iterator - }).mapValues { case Seq(vs, ws) => (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]])} - - cg - } - - /** - * Group the values for each key within a partition of the RDD into a single sequence. - * This function will not introduce a shuffling operation. - */ - def groupByKeyLocally(): RDD[(K, Seq[V])] = { - def createCombiner(v: V) = ArrayBuffer(v) - def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v - val aggregator = new Aggregator[K, V, ArrayBuffer[V]](createCombiner, mergeValue, _ ++ _) - val bufs = self.mapPartitionsWithContext((context, iter) => { - new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context)) - }, preservesPartitioning = true) - bufs.asInstanceOf[RDD[(K, Seq[V])]] - } - - /** - * Join corresponding partitions of `this` and `other`. - * If we have two RDDs and each of them has n partitions, - * we will join the partition i from `this` with the partition i from `other`. - * This function will not introduce a shuffling operation. - */ - def joinLocally[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = { - cogroupLocally(other).flatMapValues { - case (vs, ws) => for (v <- vs.iterator; w <- ws.iterator) yield (v, w) - } - } -} - -private[spark] object PartitionLocalRDDFunctions { - implicit def rddToPartitionLocalRDDFunctions[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) = - new PartitionLocalRDDFunctions(rdd) -} - - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f950ea08ec57a..d3d4c56bafe41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -21,24 +21,26 @@ import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import org.apache.spark.SparkContext +import org.apache.spark.annotation.{AlphaComponent, Experimental} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.dsl import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.Optimizer -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Subquery, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.columnar.InMemoryColumnarTableScan import org.apache.spark.sql.execution._ /** - * ALPHA COMPONENT - * + * :: AlphaComponent :: * The entry point for running relational queries using Spark. Allows the creation of [[SchemaRDD]] * objects and the execution of SQL queries. * * @groupname userf Spark SQL Functions * @groupname Ungrouped Support functions for language integrated queries. */ +@AlphaComponent class SQLContext(@transient val sparkContext: SparkContext) extends Logging with dsl.ExpressionConversions @@ -62,12 +64,12 @@ class SQLContext(@transient val sparkContext: SparkContext) new this.QueryExecution { val logical = plan } /** - * EXPERIMENTAL - * + * :: Experimental :: * Allows catalyst LogicalPlans to be executed as a SchemaRDD. Note that the LogicalPlan * interface is considered internal, and thus not guranteed to be stable. As a result, using * them directly is not reccomended. */ + @Experimental implicit def logicalPlanToSparkQuery(plan: LogicalPlan): SchemaRDD = new SchemaRDD(this, plan) /** @@ -79,12 +81,12 @@ class SQLContext(@transient val sparkContext: SparkContext) new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd))) /** - * Loads a parequet file, returning the result as a [[SchemaRDD]]. + * Loads a Parquet file, returning the result as a [[SchemaRDD]]. * * @group userf */ def parquetFile(path: String): SchemaRDD = - new SchemaRDD(this, parquet.ParquetRelation("ParquetFile", path)) + new SchemaRDD(this, parquet.ParquetRelation(path)) /** @@ -111,11 +113,40 @@ class SQLContext(@transient val sparkContext: SparkContext) result } + /** Returns the specified table as a SchemaRDD */ + def table(tableName: String): SchemaRDD = + new SchemaRDD(this, catalog.lookupRelation(None, tableName)) + + /** Caches the specified table in-memory. */ + def cacheTable(tableName: String): Unit = { + val currentTable = catalog.lookupRelation(None, tableName) + val asInMemoryRelation = + InMemoryColumnarTableScan(currentTable.output, executePlan(currentTable).executedPlan) + + catalog.registerTable(None, tableName, SparkLogicalPlan(asInMemoryRelation)) + } + + /** Removes the specified table from the in-memory cache. */ + def uncacheTable(tableName: String): Unit = { + EliminateAnalysisOperators(catalog.lookupRelation(None, tableName)) match { + // This is kind of a hack to make sure that if this was just an RDD registered as a table, + // we reregister the RDD as a table. + case SparkLogicalPlan(inMem @ InMemoryColumnarTableScan(_, e: ExistingRdd)) => + inMem.cachedColumnBuffers.unpersist() + catalog.unregisterTable(None, tableName) + catalog.registerTable(None, tableName, SparkLogicalPlan(e)) + case SparkLogicalPlan(inMem: InMemoryColumnarTableScan) => + inMem.cachedColumnBuffers.unpersist() + catalog.unregisterTable(None, tableName) + case plan => throw new IllegalArgumentException(s"Table $tableName is not cached: $plan") + } + } + protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext = self.sparkContext val strategies: Seq[Strategy] = - TopK :: + TakeOrdered :: PartialAggregation :: HashJoin :: ParquetOperations :: @@ -194,6 +225,8 @@ class SQLContext(@transient val sparkContext: SparkContext) protected def stringOrError[A](f: => A): String = try f.toString catch { case e: Throwable => e.toString } + def simpleString: String = stringOrError(executedPlan) + override def toString: String = s"""== Logical Plan == |${stringOrError(analyzed)} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 770cabcb31d13..16da7fd92bffe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -17,17 +17,17 @@ package org.apache.spark.sql +import org.apache.spark.{Dependency, OneToOneDependency, Partition, TaskContext} +import org.apache.spark.annotation.{AlphaComponent, Experimental} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.types.BooleanType -import org.apache.spark.{Dependency, OneToOneDependency, Partition, TaskContext} /** - * ALPHA COMPONENT - * + * :: AlphaComponent :: * An RDD of [[Row]] objects that has an associated schema. In addition to standard RDD functions, * SchemaRDDs can be used in relational queries, as shown in the examples below. * @@ -90,25 +90,13 @@ import org.apache.spark.{Dependency, OneToOneDependency, Partition, TaskContext} * @groupprio schema -1 * @groupname Ungrouped Base RDD Functions */ +@AlphaComponent class SchemaRDD( @transient val sqlContext: SQLContext, - @transient val logicalPlan: LogicalPlan) - extends RDD[Row](sqlContext.sparkContext, Nil) { - - /** - * A lazily computed query execution workflow. All other RDD operations are passed - * through to the RDD that is produced by this workflow. - * - * We want this to be lazy because invoking the whole query optimization pipeline can be - * expensive. - */ - @transient - protected[spark] lazy val queryExecution = sqlContext.executePlan(logicalPlan) + @transient protected[spark] val logicalPlan: LogicalPlan) + extends RDD[Row](sqlContext.sparkContext, Nil) with SchemaRDDLike { - override def toString = - s"""${super.toString} - |== Query Plan == - |${queryExecution.executedPlan}""".stripMargin.trim + def baseSchemaRDD = this // ========================================================================================= // RDD functions: Copy the interal row representation so we present immutable data to users. @@ -161,17 +149,17 @@ class SchemaRDD( * * @param otherPlan the [[SchemaRDD]] that should be joined with this one. * @param joinType One of `Inner`, `LeftOuter`, `RightOuter`, or `FullOuter`. Defaults to `Inner.` - * @param condition An optional condition for the join operation. This is equivilent to the `ON` - * clause in standard SQL. In the case of `Inner` joins, specifying a - * `condition` is equivilent to adding `where` clauses after the `join`. + * @param on An optional condition for the join operation. This is equivilent to the `ON` + * clause in standard SQL. In the case of `Inner` joins, specifying a + * `condition` is equivilent to adding `where` clauses after the `join`. * * @group Query */ def join( otherPlan: SchemaRDD, joinType: JoinType = Inner, - condition: Option[Expression] = None): SchemaRDD = - new SchemaRDD(sqlContext, Join(logicalPlan, otherPlan.logicalPlan, joinType, condition)) + on: Option[Expression] = None): SchemaRDD = + new SchemaRDD(sqlContext, Join(logicalPlan, otherPlan.logicalPlan, joinType, on)) /** * Sorts the results by the given expressions. @@ -208,14 +196,14 @@ class SchemaRDD( * with the same name, for example, when peforming self-joins. * * {{{ - * val x = schemaRDD.where('a === 1).subquery('x) - * val y = schemaRDD.where('a === 2).subquery('y) + * val x = schemaRDD.where('a === 1).as('x) + * val y = schemaRDD.where('a === 2).as('y) * x.join(y).where("x.a".attr === "y.a".attr), * }}} * * @group Query */ - def subquery(alias: Symbol) = + def as(alias: Symbol) = new SchemaRDD(sqlContext, Subquery(alias.name, logicalPlan)) /** @@ -241,8 +229,7 @@ class SchemaRDD( Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)) /** - * EXPERIMENTAL - * + * :: Experimental :: * Filters tuples using a function over a `Dynamic` version of a given Row. DynamicRows use * scala's Dynamic trait to emulate an ORM of in a dynamically typed language. Since the type of * the column is not known at compile time, all attributes are converted to strings before @@ -254,18 +241,19 @@ class SchemaRDD( * * @group Query */ + @Experimental def where(dynamicUdf: (DynamicRow) => Boolean) = new SchemaRDD( sqlContext, Filter(ScalaUdf(dynamicUdf, BooleanType, Seq(WrapDynamic(logicalPlan.output))), logicalPlan)) /** - * EXPERIMENTAL - * + * :: Experimental :: * Returns a sampled version of the underlying dataset. * * @group Query */ + @Experimental def sample( fraction: Double, withReplacement: Boolean = true, @@ -273,8 +261,7 @@ class SchemaRDD( new SchemaRDD(sqlContext, Sample(fraction, withReplacement, seed, logicalPlan)) /** - * EXPERIMENTAL - * + * :: Experimental :: * Applies the given Generator, or table generating function, to this relation. * * @param generator A table generating function. The API for such functions is likely to change @@ -290,6 +277,7 @@ class SchemaRDD( * * @group Query */ + @Experimental def generate( generator: Generator, join: Boolean = false, @@ -298,8 +286,7 @@ class SchemaRDD( new SchemaRDD(sqlContext, Generate(generator, join, outer, None, logicalPlan)) /** - * EXPERIMENTAL - * + * :: Experimental :: * Adds the rows from this RDD to the specified table. Note in a standard [[SQLContext]] there is * no notion of persistent tables, and thus queries that contain this operator will fail to * optimize. When working with an extension of a SQLContext that has a persistent catalog, such @@ -307,36 +294,18 @@ class SchemaRDD( * * @group schema */ + @Experimental def insertInto(tableName: String, overwrite: Boolean = false) = new SchemaRDD( sqlContext, InsertIntoTable(UnresolvedRelation(None, tableName), Map.empty, logicalPlan, overwrite)) - /** - * Saves the contents of this `SchemaRDD` as a parquet file, preserving the schema. Files that - * are written out using this method can be read back in as a SchemaRDD using the ``function - * - * @group schema - */ - def saveAsParquetFile(path: String): Unit = { - sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd - } - - /** - * Registers this RDD as a temporary table using the given name. The lifetime of this temporary - * table is tied to the [[SQLContext]] that was used to create this SchemaRDD. - * - * @group schema - */ - def registerAsTable(tableName: String): Unit = { - sqlContext.registerRDDAsTable(this, tableName) - } - /** * Returns this RDD as a SchemaRDD. * @group schema */ def toSchemaRDD = this + /** FOR INTERNAL USE ONLY */ def analyze = sqlContext.analyzer(logicalPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala new file mode 100644 index 0000000000000..3dd9897c0d3b8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -0,0 +1,65 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.plans.logical._ + +/** + * Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java) + */ +trait SchemaRDDLike { + @transient val sqlContext: SQLContext + @transient protected[spark] val logicalPlan: LogicalPlan + + private[sql] def baseSchemaRDD: SchemaRDD + + /** + * A lazily computed query execution workflow. All other RDD operations are passed + * through to the RDD that is produced by this workflow. + * + * We want this to be lazy because invoking the whole query optimization pipeline can be + * expensive. + */ + @transient + protected[spark] lazy val queryExecution = sqlContext.executePlan(logicalPlan) + + override def toString = + s"""${super.toString} + |== Query Plan == + |${queryExecution.simpleString}""".stripMargin.trim + + /** + * Saves the contents of this `SchemaRDD` as a parquet file, preserving the schema. Files that + * are written out using this method can be read back in as a SchemaRDD using the ``function + * + * @group schema + */ + def saveAsParquetFile(path: String): Unit = { + sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd + } + + /** + * Registers this RDD as a temporary table using the given name. The lifetime of this temporary + * table is tied to the [[SQLContext]] that was used to create this SchemaRDD. + * + * @group schema + */ + def registerAsTable(tableName: String): Unit = { + sqlContext.registerRDDAsTable(baseSchemaRDD, tableName) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala new file mode 100644 index 0000000000000..573345e42c43c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -0,0 +1,100 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.api.java + +import java.beans.{Introspector, PropertyDescriptor} + +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.parquet.ParquetRelation +import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} + +/** + * The entry point for executing Spark SQL queries from a Java program. + */ +class JavaSQLContext(sparkContext: JavaSparkContext) { + + val sqlContext = new SQLContext(sparkContext.sc) + + /** + * Executes a query expressed in SQL, returning the result as a JavaSchemaRDD + */ + def sql(sqlQuery: String): JavaSchemaRDD = { + val result = new JavaSchemaRDD(sqlContext, sqlContext.parseSql(sqlQuery)) + // We force query optimization to happen right away instead of letting it happen lazily like + // when using the query DSL. This is so DDL commands behave as expected. This is only + // generates the RDD lineage for DML queries, but do not perform any execution. + result.queryExecution.toRdd + result + } + + /** + * Applies a schema to an RDD of Java Beans. + */ + def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): JavaSchemaRDD = { + // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific. + val beanInfo = Introspector.getBeanInfo(beanClass) + + val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class") + val schema = fields.map { property => + val dataType = property.getPropertyType match { + case c: Class[_] if c == classOf[java.lang.String] => StringType + case c: Class[_] if c == java.lang.Short.TYPE => ShortType + case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType + case c: Class[_] if c == java.lang.Long.TYPE => LongType + case c: Class[_] if c == java.lang.Double.TYPE => DoubleType + case c: Class[_] if c == java.lang.Byte.TYPE => ByteType + case c: Class[_] if c == java.lang.Float.TYPE => FloatType + case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType + } + + AttributeReference(property.getName, dataType, true)() + } + + val className = beanClass.getCanonicalName + val rowRdd = rdd.rdd.mapPartitions { iter => + // BeanInfo is not serializable so we must rediscover it remotely for each partition. + val localBeanInfo = Introspector.getBeanInfo(Class.forName(className)) + val extractors = + localBeanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod) + + iter.map { row => + new GenericRow(extractors.map(e => e.invoke(row)).toArray[Any]): ScalaRow + } + } + new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd))) + } + + + /** + * Loads a parquet file, returning the result as a [[JavaSchemaRDD]]. + */ + def parquetFile(path: String): JavaSchemaRDD = + new JavaSchemaRDD(sqlContext, ParquetRelation(path)) + + + /** + * Registers the given RDD as a temporary table in the catalog. Temporary tables exist only + * during the lifetime of this instance of SQLContext. + */ + def registerRDDAsTable(rdd: JavaSchemaRDD, tableName: String): Unit = { + sqlContext.registerRDDAsTable(rdd.baseSchemaRDD, tableName) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala new file mode 100644 index 0000000000000..d43d672938f51 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java + +import org.apache.spark.api.java.{JavaRDDLike, JavaRDD} +import org.apache.spark.sql.{SQLContext, SchemaRDD, SchemaRDDLike} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.rdd.RDD + +/** + * An RDD of [[Row]] objects that is returned as the result of a Spark SQL query. In addition to + * standard RDD operations, a JavaSchemaRDD can also be registered as a table in the JavaSQLContext + * that was used to create. Registering a JavaSchemaRDD allows its contents to be queried in + * future SQL statement. + * + * @groupname schema SchemaRDD Functions + * @groupprio schema -1 + * @groupname Ungrouped Base RDD Functions + */ +class JavaSchemaRDD( + @transient val sqlContext: SQLContext, + @transient protected[spark] val logicalPlan: LogicalPlan) + extends JavaRDDLike[Row, JavaRDD[Row]] + with SchemaRDDLike { + + private[sql] val baseSchemaRDD = new SchemaRDD(sqlContext, logicalPlan) + + override val classTag = scala.reflect.classTag[Row] + + override def wrapRDD(rdd: RDD[Row]): JavaRDD[Row] = JavaRDD.fromRDD(rdd) + + val rdd = baseSchemaRDD.map(new Row(_)) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala new file mode 100644 index 0000000000000..362fe769581d7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java + +import org.apache.spark.sql.catalyst.expressions.{Row => ScalaRow} + +/** + * A result row from a SparkSQL query. + */ +class Row(row: ScalaRow) extends Serializable { + + /** Returns the number of columns present in this Row. */ + def length: Int = row.length + + /** Returns the value of column `i`. */ + def get(i: Int): Any = + row(i) + + /** Returns true if value at column `i` is NULL. */ + def isNullAt(i: Int) = get(i) == null + + /** + * Returns the value of column `i` as an int. This function will throw an exception if the value + * is at `i` is not an integer, or if it is null. + */ + def getInt(i: Int): Int = + row.getInt(i) + + /** + * Returns the value of column `i` as a long. This function will throw an exception if the value + * is at `i` is not a long, or if it is null. + */ + def getLong(i: Int): Long = + row.getLong(i) + + /** + * Returns the value of column `i` as a double. This function will throw an exception if the + * value is at `i` is not a double, or if it is null. + */ + def getDouble(i: Int): Double = + row.getDouble(i) + + /** + * Returns the value of column `i` as a bool. This function will throw an exception if the value + * is at `i` is not a boolean, or if it is null. + */ + def getBoolean(i: Int): Boolean = + row.getBoolean(i) + + /** + * Returns the value of column `i` as a short. This function will throw an exception if the value + * is at `i` is not a short, or if it is null. + */ + def getShort(i: Int): Short = + row.getShort(i) + + /** + * Returns the value of column `i` as a byte. This function will throw an exception if the value + * is at `i` is not a byte, or if it is null. + */ + def getByte(i: Int): Byte = + row.getByte(i) + + /** + * Returns the value of column `i` as a float. This function will throw an exception if the value + * is at `i` is not a float, or if it is null. + */ + def getFloat(i: Int): Float = + row.getFloat(i) + + /** + * Returns the value of column `i` as a String. This function will throw an exception if the + * value is at `i` is not a String. + */ + def getString(i: Int): String = + row.getString(i) +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index e0c98ecdf8f22..3c39e1d350fa8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -21,7 +21,7 @@ import java.nio.{ByteOrder, ByteBuffer} import org.apache.spark.sql.catalyst.types.{BinaryType, NativeType, DataType} import org.apache.spark.sql.catalyst.expressions.MutableRow -import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.columnar.compression.CompressibleColumnAccessor /** * An `Iterator` like trait used to extract values from columnar byte buffer. When a value is @@ -41,134 +41,80 @@ private[sql] trait ColumnAccessor { protected def underlyingBuffer: ByteBuffer } -private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType](buffer: ByteBuffer) +private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType]( + protected val buffer: ByteBuffer, + protected val columnType: ColumnType[T, JvmType]) extends ColumnAccessor { protected def initialize() {} - def columnType: ColumnType[T, JvmType] - def hasNext = buffer.hasRemaining def extractTo(row: MutableRow, ordinal: Int) { - doExtractTo(row, ordinal) + columnType.setField(row, ordinal, extractSingle(buffer)) } - protected def doExtractTo(row: MutableRow, ordinal: Int) + def extractSingle(buffer: ByteBuffer): JvmType = columnType.extract(buffer) protected def underlyingBuffer = buffer } private[sql] abstract class NativeColumnAccessor[T <: NativeType]( - buffer: ByteBuffer, - val columnType: NativeColumnType[T]) - extends BasicColumnAccessor[T, T#JvmType](buffer) + override protected val buffer: ByteBuffer, + override protected val columnType: NativeColumnType[T]) + extends BasicColumnAccessor(buffer, columnType) with NullableColumnAccessor + with CompressibleColumnAccessor[T] private[sql] class BooleanColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, BOOLEAN) { - - override protected def doExtractTo(row: MutableRow, ordinal: Int) { - row.setBoolean(ordinal, columnType.extract(buffer)) - } -} + extends NativeColumnAccessor(buffer, BOOLEAN) private[sql] class IntColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, INT) { - - override protected def doExtractTo(row: MutableRow, ordinal: Int) { - row.setInt(ordinal, columnType.extract(buffer)) - } -} + extends NativeColumnAccessor(buffer, INT) private[sql] class ShortColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, SHORT) { - - override protected def doExtractTo(row: MutableRow, ordinal: Int) { - row.setShort(ordinal, columnType.extract(buffer)) - } -} + extends NativeColumnAccessor(buffer, SHORT) private[sql] class LongColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, LONG) { - - override protected def doExtractTo(row: MutableRow, ordinal: Int) { - row.setLong(ordinal, columnType.extract(buffer)) - } -} + extends NativeColumnAccessor(buffer, LONG) private[sql] class ByteColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, BYTE) { - - override protected def doExtractTo(row: MutableRow, ordinal: Int) { - row.setByte(ordinal, columnType.extract(buffer)) - } -} + extends NativeColumnAccessor(buffer, BYTE) private[sql] class DoubleColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, DOUBLE) { - - override protected def doExtractTo(row: MutableRow, ordinal: Int) { - row.setDouble(ordinal, columnType.extract(buffer)) - } -} + extends NativeColumnAccessor(buffer, DOUBLE) private[sql] class FloatColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, FLOAT) { - - override protected def doExtractTo(row: MutableRow, ordinal: Int) { - row.setFloat(ordinal, columnType.extract(buffer)) - } -} + extends NativeColumnAccessor(buffer, FLOAT) private[sql] class StringColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, STRING) { - - override protected def doExtractTo(row: MutableRow, ordinal: Int) { - row.setString(ordinal, columnType.extract(buffer)) - } -} + extends NativeColumnAccessor(buffer, STRING) private[sql] class BinaryColumnAccessor(buffer: ByteBuffer) - extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer) - with NullableColumnAccessor { - - def columnType = BINARY - - override protected def doExtractTo(row: MutableRow, ordinal: Int) { - row(ordinal) = columnType.extract(buffer) - } -} + extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer, BINARY) + with NullableColumnAccessor private[sql] class GenericColumnAccessor(buffer: ByteBuffer) - extends BasicColumnAccessor[DataType, Array[Byte]](buffer) - with NullableColumnAccessor { - - def columnType = GENERIC - - override protected def doExtractTo(row: MutableRow, ordinal: Int) { - val serialized = columnType.extract(buffer) - row(ordinal) = SparkSqlSerializer.deserialize[Any](serialized) - } -} + extends BasicColumnAccessor[DataType, Array[Byte]](buffer, GENERIC) + with NullableColumnAccessor private[sql] object ColumnAccessor { - def apply(b: ByteBuffer): ColumnAccessor = { - // The first 4 bytes in the buffer indicates the column type. - val buffer = b.duplicate().order(ByteOrder.nativeOrder()) - val columnTypeId = buffer.getInt() + def apply(buffer: ByteBuffer): ColumnAccessor = { + val dup = buffer.duplicate().order(ByteOrder.nativeOrder) + // The first 4 bytes in the buffer indicate the column type. + val columnTypeId = dup.getInt() columnTypeId match { - case INT.typeId => new IntColumnAccessor(buffer) - case LONG.typeId => new LongColumnAccessor(buffer) - case FLOAT.typeId => new FloatColumnAccessor(buffer) - case DOUBLE.typeId => new DoubleColumnAccessor(buffer) - case BOOLEAN.typeId => new BooleanColumnAccessor(buffer) - case BYTE.typeId => new ByteColumnAccessor(buffer) - case SHORT.typeId => new ShortColumnAccessor(buffer) - case STRING.typeId => new StringColumnAccessor(buffer) - case BINARY.typeId => new BinaryColumnAccessor(buffer) - case GENERIC.typeId => new GenericColumnAccessor(buffer) + case INT.typeId => new IntColumnAccessor(dup) + case LONG.typeId => new LongColumnAccessor(dup) + case FLOAT.typeId => new FloatColumnAccessor(dup) + case DOUBLE.typeId => new DoubleColumnAccessor(dup) + case BOOLEAN.typeId => new BooleanColumnAccessor(dup) + case BYTE.typeId => new ByteColumnAccessor(dup) + case SHORT.typeId => new ShortColumnAccessor(dup) + case STRING.typeId => new StringColumnAccessor(dup) + case BINARY.typeId => new BinaryColumnAccessor(dup) + case GENERIC.typeId => new GenericColumnAccessor(dup) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 3e622adfd3d6a..048ee66bff44b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -22,7 +22,7 @@ import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.ColumnBuilder._ -import org.apache.spark.sql.execution.SparkSqlSerializer +import org.apache.spark.sql.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder} private[sql] trait ColumnBuilder { /** @@ -30,37 +30,44 @@ private[sql] trait ColumnBuilder { */ def initialize(initialSize: Int, columnName: String = "") + /** + * Appends `row(ordinal)` to the column builder. + */ def appendFrom(row: Row, ordinal: Int) + /** + * Column statistics information + */ + def columnStats: ColumnStats[_, _] + + /** + * Returns the final columnar byte buffer. + */ def build(): ByteBuffer } -private[sql] abstract class BasicColumnBuilder[T <: DataType, JvmType] extends ColumnBuilder { +private[sql] class BasicColumnBuilder[T <: DataType, JvmType]( + val columnStats: ColumnStats[T, JvmType], + val columnType: ColumnType[T, JvmType]) + extends ColumnBuilder { - private var columnName: String = _ - protected var buffer: ByteBuffer = _ + protected var columnName: String = _ - def columnType: ColumnType[T, JvmType] + protected var buffer: ByteBuffer = _ override def initialize(initialSize: Int, columnName: String = "") = { val size = if (initialSize == 0) DEFAULT_INITIAL_BUFFER_SIZE else initialSize this.columnName = columnName - buffer = ByteBuffer.allocate(4 + 4 + size * columnType.defaultSize) + + // Reserves 4 bytes for column type ID + buffer = ByteBuffer.allocate(4 + size * columnType.defaultSize) buffer.order(ByteOrder.nativeOrder()).putInt(columnType.typeId) } - // Have to give a concrete implementation to make mixin possible override def appendFrom(row: Row, ordinal: Int) { - doAppendFrom(row, ordinal) - } - - // Concrete `ColumnBuilder`s can override this method to append values - protected def doAppendFrom(row: Row, ordinal: Int) - - // Helper method to append primitive values (to avoid boxing cost) - protected def appendValue(v: JvmType) { - buffer = ensureFreeSpace(buffer, columnType.actualSize(v)) - columnType.append(v, buffer) + val field = columnType.getField(row, ordinal) + buffer = ensureFreeSpace(buffer, columnType.actualSize(field)) + columnType.append(field, buffer) } override def build() = { @@ -69,83 +76,39 @@ private[sql] abstract class BasicColumnBuilder[T <: DataType, JvmType] extends C } } -private[sql] abstract class NativeColumnBuilder[T <: NativeType]( - val columnType: NativeColumnType[T]) - extends BasicColumnBuilder[T, T#JvmType] +private[sql] abstract class ComplexColumnBuilder[T <: DataType, JvmType]( + columnType: ColumnType[T, JvmType]) + extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType) with NullableColumnBuilder -private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(BOOLEAN) { - override def doAppendFrom(row: Row, ordinal: Int) { - appendValue(row.getBoolean(ordinal)) - } -} - -private[sql] class IntColumnBuilder extends NativeColumnBuilder(INT) { - override def doAppendFrom(row: Row, ordinal: Int) { - appendValue(row.getInt(ordinal)) - } -} +private[sql] abstract class NativeColumnBuilder[T <: NativeType]( + override val columnStats: NativeColumnStats[T], + override val columnType: NativeColumnType[T]) + extends BasicColumnBuilder[T, T#JvmType](columnStats, columnType) + with NullableColumnBuilder + with AllCompressionSchemes + with CompressibleColumnBuilder[T] -private[sql] class ShortColumnBuilder extends NativeColumnBuilder(SHORT) { - override def doAppendFrom(row: Row, ordinal: Int) { - appendValue(row.getShort(ordinal)) - } -} +private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) -private[sql] class LongColumnBuilder extends NativeColumnBuilder(LONG) { - override def doAppendFrom(row: Row, ordinal: Int) { - appendValue(row.getLong(ordinal)) - } -} +private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) -private[sql] class ByteColumnBuilder extends NativeColumnBuilder(BYTE) { - override def doAppendFrom(row: Row, ordinal: Int) { - appendValue(row.getByte(ordinal)) - } -} +private[sql] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT) -private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(DOUBLE) { - override def doAppendFrom(row: Row, ordinal: Int) { - appendValue(row.getDouble(ordinal)) - } -} +private[sql] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG) -private[sql] class FloatColumnBuilder extends NativeColumnBuilder(FLOAT) { - override def doAppendFrom(row: Row, ordinal: Int) { - appendValue(row.getFloat(ordinal)) - } -} +private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) -private[sql] class StringColumnBuilder extends NativeColumnBuilder(STRING) { - override def doAppendFrom(row: Row, ordinal: Int) { - appendValue(row.getString(ordinal)) - } -} +private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE) -private[sql] class BinaryColumnBuilder - extends BasicColumnBuilder[BinaryType.type, Array[Byte]] - with NullableColumnBuilder { +private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) - def columnType = BINARY +private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) - override def doAppendFrom(row: Row, ordinal: Int) { - appendValue(row(ordinal).asInstanceOf[Array[Byte]]) - } -} +private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(BINARY) // TODO (lian) Add support for array, struct and map -private[sql] class GenericColumnBuilder - extends BasicColumnBuilder[DataType, Array[Byte]] - with NullableColumnBuilder { - - def columnType = GENERIC - - override def doAppendFrom(row: Row, ordinal: Int) { - val serialized = SparkSqlSerializer.serialize(row(ordinal)) - buffer = ColumnBuilder.ensureFreeSpace(buffer, columnType.actualSize(serialized)) - columnType.append(serialized, buffer) - } -} +private[sql] class GenericColumnBuilder extends ComplexColumnBuilder(GENERIC) private[sql] object ColumnBuilder { val DEFAULT_INITIAL_BUFFER_SIZE = 10 * 1024 * 104 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala new file mode 100644 index 0000000000000..95602d321dc6f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.types._ + +/** + * Used to collect statistical information when building in-memory columns. + * + * NOTE: we intentionally avoid using `Ordering[T]` to compare values here because `Ordering[T]` + * brings significant performance penalty. + */ +private[sql] sealed abstract class ColumnStats[T <: DataType, JvmType] extends Serializable { + /** + * Closed lower bound of this column. + */ + def lowerBound: JvmType + + /** + * Closed upper bound of this column. + */ + def upperBound: JvmType + + /** + * Gathers statistics information from `row(ordinal)`. + */ + def gatherStats(row: Row, ordinal: Int) + + /** + * Returns `true` if `lower <= row(ordinal) <= upper`. + */ + def contains(row: Row, ordinal: Int): Boolean + + /** + * Returns `true` if `row(ordinal) < upper` holds. + */ + def isAbove(row: Row, ordinal: Int): Boolean + + /** + * Returns `true` if `lower < row(ordinal)` holds. + */ + def isBelow(row: Row, ordinal: Int): Boolean + + /** + * Returns `true` if `row(ordinal) <= upper` holds. + */ + def isAtOrAbove(row: Row, ordinal: Int): Boolean + + /** + * Returns `true` if `lower <= row(ordinal)` holds. + */ + def isAtOrBelow(row: Row, ordinal: Int): Boolean +} + +private[sql] sealed abstract class NativeColumnStats[T <: NativeType] + extends ColumnStats[T, T#JvmType] { + + type JvmType = T#JvmType + + protected var (_lower, _upper) = initialBounds + + def initialBounds: (JvmType, JvmType) + + protected def columnType: NativeColumnType[T] + + override def lowerBound: T#JvmType = _lower + + override def upperBound: T#JvmType = _upper + + override def isAtOrAbove(row: Row, ordinal: Int) = { + contains(row, ordinal) || isAbove(row, ordinal) + } + + override def isAtOrBelow(row: Row, ordinal: Int) = { + contains(row, ordinal) || isBelow(row, ordinal) + } +} + +private[sql] class NoopColumnStats[T <: DataType, JvmType] extends ColumnStats[T, JvmType] { + override def isAtOrBelow(row: Row, ordinal: Int) = true + + override def isAtOrAbove(row: Row, ordinal: Int) = true + + override def isBelow(row: Row, ordinal: Int) = true + + override def isAbove(row: Row, ordinal: Int) = true + + override def contains(row: Row, ordinal: Int) = true + + override def gatherStats(row: Row, ordinal: Int) {} + + override def upperBound = null.asInstanceOf[JvmType] + + override def lowerBound = null.asInstanceOf[JvmType] +} + +private[sql] abstract class BasicColumnStats[T <: NativeType]( + protected val columnType: NativeColumnType[T]) + extends NativeColumnStats[T] + +private[sql] class BooleanColumnStats extends BasicColumnStats(BOOLEAN) { + override def initialBounds = (true, false) + + override def isBelow(row: Row, ordinal: Int) = { + lowerBound < columnType.getField(row, ordinal) + } + + override def isAbove(row: Row, ordinal: Int) = { + columnType.getField(row, ordinal) < upperBound + } + + override def contains(row: Row, ordinal: Int) = { + val field = columnType.getField(row, ordinal) + lowerBound <= field && field <= upperBound + } + + override def gatherStats(row: Row, ordinal: Int) { + val field = columnType.getField(row, ordinal) + if (field > upperBound) _upper = field + if (field < lowerBound) _lower = field + } +} + +private[sql] class ByteColumnStats extends BasicColumnStats(BYTE) { + override def initialBounds = (Byte.MaxValue, Byte.MinValue) + + override def isBelow(row: Row, ordinal: Int) = { + lowerBound < columnType.getField(row, ordinal) + } + + override def isAbove(row: Row, ordinal: Int) = { + columnType.getField(row, ordinal) < upperBound + } + + override def contains(row: Row, ordinal: Int) = { + val field = columnType.getField(row, ordinal) + lowerBound <= field && field <= upperBound + } + + override def gatherStats(row: Row, ordinal: Int) { + val field = columnType.getField(row, ordinal) + if (field > upperBound) _upper = field + if (field < lowerBound) _lower = field + } +} + +private[sql] class ShortColumnStats extends BasicColumnStats(SHORT) { + override def initialBounds = (Short.MaxValue, Short.MinValue) + + override def isBelow(row: Row, ordinal: Int) = { + lowerBound < columnType.getField(row, ordinal) + } + + override def isAbove(row: Row, ordinal: Int) = { + columnType.getField(row, ordinal) < upperBound + } + + override def contains(row: Row, ordinal: Int) = { + val field = columnType.getField(row, ordinal) + lowerBound <= field && field <= upperBound + } + + override def gatherStats(row: Row, ordinal: Int) { + val field = columnType.getField(row, ordinal) + if (field > upperBound) _upper = field + if (field < lowerBound) _lower = field + } +} + +private[sql] class LongColumnStats extends BasicColumnStats(LONG) { + override def initialBounds = (Long.MaxValue, Long.MinValue) + + override def isBelow(row: Row, ordinal: Int) = { + lowerBound < columnType.getField(row, ordinal) + } + + override def isAbove(row: Row, ordinal: Int) = { + columnType.getField(row, ordinal) < upperBound + } + + override def contains(row: Row, ordinal: Int) = { + val field = columnType.getField(row, ordinal) + lowerBound <= field && field <= upperBound + } + + override def gatherStats(row: Row, ordinal: Int) { + val field = columnType.getField(row, ordinal) + if (field > upperBound) _upper = field + if (field < lowerBound) _lower = field + } +} + +private[sql] class DoubleColumnStats extends BasicColumnStats(DOUBLE) { + override def initialBounds = (Double.MaxValue, Double.MinValue) + + override def isBelow(row: Row, ordinal: Int) = { + lowerBound < columnType.getField(row, ordinal) + } + + override def isAbove(row: Row, ordinal: Int) = { + columnType.getField(row, ordinal) < upperBound + } + + override def contains(row: Row, ordinal: Int) = { + val field = columnType.getField(row, ordinal) + lowerBound <= field && field <= upperBound + } + + override def gatherStats(row: Row, ordinal: Int) { + val field = columnType.getField(row, ordinal) + if (field > upperBound) _upper = field + if (field < lowerBound) _lower = field + } +} + +private[sql] class FloatColumnStats extends BasicColumnStats(FLOAT) { + override def initialBounds = (Float.MaxValue, Float.MinValue) + + override def isBelow(row: Row, ordinal: Int) = { + lowerBound < columnType.getField(row, ordinal) + } + + override def isAbove(row: Row, ordinal: Int) = { + columnType.getField(row, ordinal) < upperBound + } + + override def contains(row: Row, ordinal: Int) = { + val field = columnType.getField(row, ordinal) + lowerBound <= field && field <= upperBound + } + + override def gatherStats(row: Row, ordinal: Int) { + val field = columnType.getField(row, ordinal) + if (field > upperBound) _upper = field + if (field < lowerBound) _lower = field + } +} + +private[sql] object IntColumnStats { + val UNINITIALIZED = 0 + val INITIALIZED = 1 + val ASCENDING = 2 + val DESCENDING = 3 + val UNORDERED = 4 +} + +/** + * Statistical information for `Int` columns. More information is collected since `Int` is + * frequently used. Extra information include: + * + * - Ordering state (ascending/descending/unordered), may be used to decide whether binary search + * is applicable when searching elements. + * - Maximum delta between adjacent elements, may be used to guide the `IntDelta` compression + * scheme. + * + * (This two kinds of information are not used anywhere yet and might be removed later.) + */ +private[sql] class IntColumnStats extends BasicColumnStats(INT) { + import IntColumnStats._ + + private var orderedState = UNINITIALIZED + private var lastValue: Int = _ + private var _maxDelta: Int = _ + + def isAscending = orderedState != DESCENDING && orderedState != UNORDERED + def isDescending = orderedState != ASCENDING && orderedState != UNORDERED + def isOrdered = isAscending || isDescending + def maxDelta = _maxDelta + + override def initialBounds = (Int.MaxValue, Int.MinValue) + + override def isBelow(row: Row, ordinal: Int) = { + lowerBound < columnType.getField(row, ordinal) + } + + override def isAbove(row: Row, ordinal: Int) = { + columnType.getField(row, ordinal) < upperBound + } + + override def contains(row: Row, ordinal: Int) = { + val field = columnType.getField(row, ordinal) + lowerBound <= field && field <= upperBound + } + + override def gatherStats(row: Row, ordinal: Int) { + val field = columnType.getField(row, ordinal) + + if (field > upperBound) _upper = field + if (field < lowerBound) _lower = field + + orderedState = orderedState match { + case UNINITIALIZED => + lastValue = field + INITIALIZED + + case INITIALIZED => + // If all the integers in the column are the same, ordered state is set to Ascending. + // TODO (lian) Confirm whether this is the standard behaviour. + val nextState = if (field >= lastValue) ASCENDING else DESCENDING + _maxDelta = math.abs(field - lastValue) + lastValue = field + nextState + + case ASCENDING if field < lastValue => + UNORDERED + + case DESCENDING if field > lastValue => + UNORDERED + + case state @ (ASCENDING | DESCENDING) => + _maxDelta = _maxDelta.max(field - lastValue) + lastValue = field + state + + case _ => + orderedState + } + } +} + +private[sql] class StringColumnStats extends BasicColumnStats(STRING) { + override def initialBounds = (null, null) + + override def gatherStats(row: Row, ordinal: Int) { + val field = columnType.getField(row, ordinal) + if ((upperBound eq null) || field.compareTo(upperBound) > 0) _upper = field + if ((lowerBound eq null) || field.compareTo(lowerBound) < 0) _lower = field + } + + override def contains(row: Row, ordinal: Int) = { + !(upperBound eq null) && { + val field = columnType.getField(row, ordinal) + lowerBound.compareTo(field) <= 0 && field.compareTo(upperBound) <= 0 + } + } + + override def isAbove(row: Row, ordinal: Int) = { + !(upperBound eq null) && { + val field = columnType.getField(row, ordinal) + field.compareTo(upperBound) < 0 + } + } + + override def isBelow(row: Row, ordinal: Int) = { + !(lowerBound eq null) && { + val field = columnType.getField(row, ordinal) + lowerBound.compareTo(field) < 0 + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index a452b86f0cda3..5be76890afe31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -19,7 +19,12 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.execution.SparkSqlSerializer /** * An abstract class that represents type of a column. Used to append/extract Java objects into/from @@ -50,10 +55,24 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( */ def actualSize(v: JvmType): Int = defaultSize + /** + * Returns `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs + * whenever possible. + */ + def getField(row: Row, ordinal: Int): JvmType + + /** + * Sets `row(ordinal)` to `field`. Subclasses should override this method to avoid boxing/unboxing + * costs whenever possible. + */ + def setField(row: MutableRow, ordinal: Int, value: JvmType) + /** * Creates a duplicated copy of the value. */ def clone(v: JvmType): JvmType = v + + override def toString = getClass.getSimpleName.stripSuffix("$") } private[sql] abstract class NativeColumnType[T <: NativeType]( @@ -65,7 +84,7 @@ private[sql] abstract class NativeColumnType[T <: NativeType]( /** * Scala TypeTag. Can be used to create primitive arrays and hash tables. */ - def scalaTag = dataType.tag + def scalaTag: TypeTag[dataType.JvmType] = dataType.tag } private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { @@ -76,6 +95,12 @@ private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { def extract(buffer: ByteBuffer) = { buffer.getInt() } + + override def setField(row: MutableRow, ordinal: Int, value: Int) { + row.setInt(ordinal, value) + } + + override def getField(row: Row, ordinal: Int) = row.getInt(ordinal) } private[sql] object LONG extends NativeColumnType(LongType, 1, 8) { @@ -86,6 +111,12 @@ private[sql] object LONG extends NativeColumnType(LongType, 1, 8) { override def extract(buffer: ByteBuffer) = { buffer.getLong() } + + override def setField(row: MutableRow, ordinal: Int, value: Long) { + row.setLong(ordinal, value) + } + + override def getField(row: Row, ordinal: Int) = row.getLong(ordinal) } private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) { @@ -96,6 +127,12 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) { override def extract(buffer: ByteBuffer) = { buffer.getFloat() } + + override def setField(row: MutableRow, ordinal: Int, value: Float) { + row.setFloat(ordinal, value) + } + + override def getField(row: Row, ordinal: Int) = row.getFloat(ordinal) } private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) { @@ -106,6 +143,12 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) { override def extract(buffer: ByteBuffer) = { buffer.getDouble() } + + override def setField(row: MutableRow, ordinal: Int, value: Double) { + row.setDouble(ordinal, value) + } + + override def getField(row: Row, ordinal: Int) = row.getDouble(ordinal) } private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) { @@ -116,6 +159,12 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) { override def extract(buffer: ByteBuffer) = { if (buffer.get() == 1) true else false } + + override def setField(row: MutableRow, ordinal: Int, value: Boolean) { + row.setBoolean(ordinal, value) + } + + override def getField(row: Row, ordinal: Int) = row.getBoolean(ordinal) } private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) { @@ -126,6 +175,12 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) { override def extract(buffer: ByteBuffer) = { buffer.get() } + + override def setField(row: MutableRow, ordinal: Int, value: Byte) { + row.setByte(ordinal, value) + } + + override def getField(row: Row, ordinal: Int) = row.getByte(ordinal) } private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { @@ -136,6 +191,12 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { override def extract(buffer: ByteBuffer) = { buffer.getShort() } + + override def setField(row: MutableRow, ordinal: Int, value: Short) { + row.setShort(ordinal, value) + } + + override def getField(row: Row, ordinal: Int) = row.getShort(ordinal) } private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { @@ -152,6 +213,12 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { buffer.get(stringBytes, 0, length) new String(stringBytes) } + + override def setField(row: MutableRow, ordinal: Int, value: String) { + row.setString(ordinal, value) + } + + override def getField(row: Row, ordinal: Int) = row.getString(ordinal) } private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( @@ -173,15 +240,27 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( } } -private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](8, 16) +private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](8, 16) { + override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) { + row(ordinal) = value + } + + override def getField(row: Row, ordinal: Int) = row(ordinal).asInstanceOf[Array[Byte]] +} // Used to process generic objects (all types other than those listed above). Objects should be // serialized first before appending to the column `ByteBuffer`, and is also extracted as serialized // byte array. -private[sql] object GENERIC extends ByteArrayColumnType[DataType](9, 16) +private[sql] object GENERIC extends ByteArrayColumnType[DataType](9, 16) { + override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]) { + row(ordinal) = SparkSqlSerializer.deserialize[Any](value) + } + + override def getField(row: Row, ordinal: Int) = SparkSqlSerializer.serialize(row(ordinal)) +} private[sql] object ColumnType { - implicit def dataTypeToColumnType(dataType: DataType): ColumnType[_, _] = { + def apply(dataType: DataType): ColumnType[_, _] = { dataType match { case IntegerType => INT case LongType => LONG diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/inMemoryColumnarOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala similarity index 93% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/inMemoryColumnarOperators.scala rename to sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index f853759e5a306..8a24733047423 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/inMemoryColumnarOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -21,9 +21,6 @@ import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Attribute} import org.apache.spark.sql.execution.{SparkPlan, LeafNode} import org.apache.spark.sql.Row -/* Implicit conversions */ -import org.apache.spark.sql.columnar.ColumnType._ - private[sql] case class InMemoryColumnarTableScan(attributes: Seq[Attribute], child: SparkPlan) extends LeafNode { @@ -32,8 +29,8 @@ private[sql] case class InMemoryColumnarTableScan(attributes: Seq[Attribute], ch lazy val cachedColumnBuffers = { val output = child.output val cached = child.execute().mapPartitions { iterator => - val columnBuilders = output.map { a => - ColumnBuilder(a.dataType.typeId, 0, a.name) + val columnBuilders = output.map { attribute => + ColumnBuilder(ColumnType(attribute.dataType).typeId, 0, attribute.name) }.toArray var row: Row = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala index 2970c609b928d..7d49ab07f7a53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala @@ -29,7 +29,7 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor { private var nextNullIndex: Int = _ private var pos: Int = 0 - abstract override def initialize() { + abstract override protected def initialize() { nullsBuffer = underlyingBuffer.duplicate().order(ByteOrder.nativeOrder()) nullCount = nullsBuffer.getInt() nextNullIndex = if (nullCount > 0) nullsBuffer.getInt() else -1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala index 048d1f05c7df2..2a3b6fc1e46d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala @@ -22,10 +22,18 @@ import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.Row /** - * Builds a nullable column. The byte buffer of a nullable column contains: - * - 4 bytes for the null count (number of nulls) - * - positions for each null, in ascending order - * - the non-null data (column data type, compression type, data...) + * A stackable trait used for building byte buffer for a column containing null values. Memory + * layout of the final byte buffer is: + * {{{ + * .----------------------- Column type ID (4 bytes) + * | .------------------- Null count N (4 bytes) + * | | .--------------- Null positions (4 x N bytes, empty if null count is zero) + * | | | .--------- Non-null elements + * V V V V + * +---+---+-----+---------+ + * | | | ... | ... ... | + * +---+---+-----+---------+ + * }}} */ private[sql] trait NullableColumnBuilder extends ColumnBuilder { private var nulls: ByteBuffer = _ @@ -59,19 +67,8 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder { nulls.limit(nullDataLen) nulls.rewind() - // Column type ID is moved to the front, follows the null count, then non-null data - // - // +---------+ - // | 4 bytes | Column type ID - // +---------+ - // | 4 bytes | Null count - // +---------+ - // | ... | Null positions (if null count is not zero) - // +---------+ - // | ... | Non-null part (without column type ID) - // +---------+ val buffer = ByteBuffer - .allocate(4 + nullDataLen + nonNulls.limit) + .allocate(4 + 4 + nullDataLen + nonNulls.remaining()) .order(ByteOrder.nativeOrder()) .putInt(typeId) .putInt(nullCount) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala new file mode 100644 index 0000000000000..878cb84de106f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnAccessor.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar.compression + +import java.nio.ByteBuffer + +import org.apache.spark.sql.catalyst.types.NativeType +import org.apache.spark.sql.columnar.{ColumnAccessor, NativeColumnAccessor} + +private[sql] trait CompressibleColumnAccessor[T <: NativeType] extends ColumnAccessor { + this: NativeColumnAccessor[T] => + + private var decoder: Decoder[T] = _ + + abstract override protected def initialize() = { + super.initialize() + decoder = CompressionScheme(underlyingBuffer.getInt()).decoder(buffer, columnType) + } + + abstract override def extractSingle(buffer: ByteBuffer): T#JvmType = decoder.next() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala new file mode 100644 index 0000000000000..fd3b1adf9687a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar.compression + +import java.nio.{ByteBuffer, ByteOrder} + +import org.apache.spark.sql.{Logging, Row} +import org.apache.spark.sql.catalyst.types.NativeType +import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder} + +/** + * A stackable trait that builds optionally compressed byte buffer for a column. Memory layout of + * the final byte buffer is: + * {{{ + * .--------------------------- Column type ID (4 bytes) + * | .----------------------- Null count N (4 bytes) + * | | .------------------- Null positions (4 x N bytes, empty if null count is zero) + * | | | .------------- Compression scheme ID (4 bytes) + * | | | | .--------- Compressed non-null elements + * V V V V V + * +---+---+-----+---+---------+ + * | | | ... | | ... ... | + * +---+---+-----+---+---------+ + * \-----------/ \-----------/ + * header body + * }}} + */ +private[sql] trait CompressibleColumnBuilder[T <: NativeType] + extends ColumnBuilder with Logging { + + this: NativeColumnBuilder[T] with WithCompressionSchemes => + + import CompressionScheme._ + + val compressionEncoders = schemes.filter(_.supports(columnType)).map(_.encoder[T]) + + protected def isWorthCompressing(encoder: Encoder[T]) = { + encoder.compressionRatio < 0.8 + } + + private def gatherCompressibilityStats(row: Row, ordinal: Int) { + val field = columnType.getField(row, ordinal) + + var i = 0 + while (i < compressionEncoders.length) { + compressionEncoders(i).gatherCompressibilityStats(field, columnType) + i += 1 + } + } + + abstract override def appendFrom(row: Row, ordinal: Int) { + super.appendFrom(row, ordinal) + gatherCompressibilityStats(row, ordinal) + } + + abstract override def build() = { + val rawBuffer = super.build() + val encoder: Encoder[T] = { + val candidate = compressionEncoders.minBy(_.compressionRatio) + if (isWorthCompressing(candidate)) candidate else PassThrough.encoder + } + + val headerSize = columnHeaderSize(rawBuffer) + val compressedSize = if (encoder.compressedSize == 0) { + rawBuffer.limit - headerSize + } else { + encoder.compressedSize + } + + // Reserves 4 bytes for compression scheme ID + val compressedBuffer = ByteBuffer + .allocate(headerSize + 4 + compressedSize) + .order(ByteOrder.nativeOrder) + + copyColumnHeader(rawBuffer, compressedBuffer) + + logger.info(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") + encoder.compress(rawBuffer, compressedBuffer, columnType) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala new file mode 100644 index 0000000000000..c605a8e4434e3 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar.compression + +import java.nio.ByteBuffer + +import org.apache.spark.sql.catalyst.types.NativeType +import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType} + +private[sql] trait Encoder[T <: NativeType] { + def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) {} + + def compressedSize: Int + + def uncompressedSize: Int + + def compressionRatio: Double = { + if (uncompressedSize > 0) compressedSize.toDouble / uncompressedSize else 1.0 + } + + def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]): ByteBuffer +} + +private[sql] trait Decoder[T <: NativeType] extends Iterator[T#JvmType] + +private[sql] trait CompressionScheme { + def typeId: Int + + def supports(columnType: ColumnType[_, _]): Boolean + + def encoder[T <: NativeType]: Encoder[T] + + def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] +} + +private[sql] trait WithCompressionSchemes { + def schemes: Seq[CompressionScheme] +} + +private[sql] trait AllCompressionSchemes extends WithCompressionSchemes { + override val schemes: Seq[CompressionScheme] = CompressionScheme.all +} + +private[sql] object CompressionScheme { + val all: Seq[CompressionScheme] = + Seq(PassThrough, RunLengthEncoding, DictionaryEncoding, BooleanBitSet, IntDelta, LongDelta) + + private val typeIdToScheme = all.map(scheme => scheme.typeId -> scheme).toMap + + def apply(typeId: Int): CompressionScheme = { + typeIdToScheme.getOrElse(typeId, throw new UnsupportedOperationException( + s"Unrecognized compression scheme type ID: $typeId")) + } + + def copyColumnHeader(from: ByteBuffer, to: ByteBuffer) { + // Writes column type ID + to.putInt(from.getInt()) + + // Writes null count + val nullCount = from.getInt() + to.putInt(nullCount) + + // Writes null positions + var i = 0 + while (i < nullCount) { + to.putInt(from.getInt()) + i += 1 + } + } + + def columnHeaderSize(columnBuffer: ByteBuffer): Int = { + val header = columnBuffer.duplicate() + val nullCount = header.getInt(4) + // Column type ID + null count + null positions + 4 + 4 + 4 * nullCount + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala new file mode 100644 index 0000000000000..e92cf5ac4f9df --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -0,0 +1,483 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar.compression + +import java.nio.ByteBuffer + +import scala.collection.mutable +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.runtimeMirror + +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.columnar._ +import org.apache.spark.util.Utils + +private[sql] case object PassThrough extends CompressionScheme { + override val typeId = 0 + + override def supports(columnType: ColumnType[_, _]) = true + + override def encoder[T <: NativeType] = new this.Encoder[T] + + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + new this.Decoder(buffer, columnType) + } + + class Encoder[T <: NativeType] extends compression.Encoder[T] { + override def uncompressedSize = 0 + + override def compressedSize = 0 + + override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = { + // Writes compression type ID and copies raw contents + to.putInt(PassThrough.typeId).put(from).rewind() + to + } + } + + class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + extends compression.Decoder[T] { + + override def next() = columnType.extract(buffer) + + override def hasNext = buffer.hasRemaining + } +} + +private[sql] case object RunLengthEncoding extends CompressionScheme { + override val typeId = 1 + + override def encoder[T <: NativeType] = new this.Encoder[T] + + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + new this.Decoder(buffer, columnType) + } + + override def supports(columnType: ColumnType[_, _]) = columnType match { + case INT | LONG | SHORT | BYTE | STRING | BOOLEAN => true + case _ => false + } + + class Encoder[T <: NativeType] extends compression.Encoder[T] { + private var _uncompressedSize = 0 + private var _compressedSize = 0 + + // Using `MutableRow` to store the last value to avoid boxing/unboxing cost. + private val lastValue = new GenericMutableRow(1) + private var lastRun = 0 + + override def uncompressedSize = _uncompressedSize + + override def compressedSize = _compressedSize + + override def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) { + val actualSize = columnType.actualSize(value) + _uncompressedSize += actualSize + + if (lastValue.isNullAt(0)) { + columnType.setField(lastValue, 0, value) + lastRun = 1 + _compressedSize += actualSize + 4 + } else { + if (columnType.getField(lastValue, 0) == value) { + lastRun += 1 + } else { + _compressedSize += actualSize + 4 + columnType.setField(lastValue, 0, value) + lastRun = 1 + } + } + } + + override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = { + to.putInt(RunLengthEncoding.typeId) + + if (from.hasRemaining) { + var currentValue = columnType.extract(from) + var currentRun = 1 + + while (from.hasRemaining) { + val value = columnType.extract(from) + + if (value == currentValue) { + currentRun += 1 + } else { + // Writes current run + columnType.append(currentValue, to) + to.putInt(currentRun) + + // Resets current run + currentValue = value + currentRun = 1 + } + } + + // Writes the last run + columnType.append(currentValue, to) + to.putInt(currentRun) + } + + to.rewind() + to + } + } + + class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + extends compression.Decoder[T] { + + private var run = 0 + private var valueCount = 0 + private var currentValue: T#JvmType = _ + + override def next() = { + if (valueCount == run) { + currentValue = columnType.extract(buffer) + run = buffer.getInt() + valueCount = 1 + } else { + valueCount += 1 + } + + currentValue + } + + override def hasNext = buffer.hasRemaining + } +} + +private[sql] case object DictionaryEncoding extends CompressionScheme { + override val typeId = 2 + + // 32K unique values allowed + val MAX_DICT_SIZE = Short.MaxValue + + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + new this.Decoder(buffer, columnType) + } + + override def encoder[T <: NativeType] = new this.Encoder[T] + + override def supports(columnType: ColumnType[_, _]) = columnType match { + case INT | LONG | STRING => true + case _ => false + } + + class Encoder[T <: NativeType] extends compression.Encoder[T] { + // Size of the input, uncompressed, in bytes. Note that we only count until the dictionary + // overflows. + private var _uncompressedSize = 0 + + // If the number of distinct elements is too large, we discard the use of dictionary encoding + // and set the overflow flag to true. + private var overflow = false + + // Total number of elements. + private var count = 0 + + // The reverse mapping of _dictionary, i.e. mapping encoded integer to the value itself. + private var values = new mutable.ArrayBuffer[T#JvmType](1024) + + // The dictionary that maps a value to the encoded short integer. + private val dictionary = mutable.HashMap.empty[Any, Short] + + // Size of the serialized dictionary in bytes. Initialized to 4 since we need at least an `Int` + // to store dictionary element count. + private var dictionarySize = 4 + + override def gatherCompressibilityStats(value: T#JvmType, columnType: NativeColumnType[T]) { + if (!overflow) { + val actualSize = columnType.actualSize(value) + count += 1 + _uncompressedSize += actualSize + + if (!dictionary.contains(value)) { + if (dictionary.size < MAX_DICT_SIZE) { + val clone = columnType.clone(value) + values += clone + dictionarySize += actualSize + dictionary(clone) = dictionary.size.toShort + } else { + overflow = true + values.clear() + dictionary.clear() + } + } + } + } + + override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[T]) = { + if (overflow) { + throw new IllegalStateException( + "Dictionary encoding should not be used because of dictionary overflow.") + } + + to.putInt(DictionaryEncoding.typeId) + .putInt(dictionary.size) + + var i = 0 + while (i < values.length) { + columnType.append(values(i), to) + i += 1 + } + + while (from.hasRemaining) { + to.putShort(dictionary(columnType.extract(from))) + } + + to.rewind() + to + } + + override def uncompressedSize = _uncompressedSize + + override def compressedSize = if (overflow) Int.MaxValue else dictionarySize + count * 2 + } + + class Decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) + extends compression.Decoder[T] { + + private val dictionary = { + // TODO Can we clean up this mess? Maybe move this to `DataType`? + implicit val classTag = { + val mirror = runtimeMirror(Utils.getSparkClassLoader) + ClassTag[T#JvmType](mirror.runtimeClass(columnType.scalaTag.tpe)) + } + + Array.fill(buffer.getInt()) { + columnType.extract(buffer) + } + } + + override def next() = dictionary(buffer.getShort()) + + override def hasNext = buffer.hasRemaining + } +} + +private[sql] case object BooleanBitSet extends CompressionScheme { + override val typeId = 3 + + val BITS_PER_LONG = 64 + + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + new this.Decoder(buffer).asInstanceOf[compression.Decoder[T]] + } + + override def encoder[T <: NativeType] = (new this.Encoder).asInstanceOf[compression.Encoder[T]] + + override def supports(columnType: ColumnType[_, _]) = columnType == BOOLEAN + + class Encoder extends compression.Encoder[BooleanType.type] { + private var _uncompressedSize = 0 + + override def gatherCompressibilityStats( + value: Boolean, + columnType: NativeColumnType[BooleanType.type]) { + + _uncompressedSize += BOOLEAN.defaultSize + } + + override def compress( + from: ByteBuffer, + to: ByteBuffer, + columnType: NativeColumnType[BooleanType.type]) = { + + to.putInt(BooleanBitSet.typeId) + // Total element count (1 byte per Boolean value) + .putInt(from.remaining) + + while (from.remaining >= BITS_PER_LONG) { + var word = 0: Long + var i = 0 + + while (i < BITS_PER_LONG) { + if (BOOLEAN.extract(from)) { + word |= (1: Long) << i + } + i += 1 + } + + to.putLong(word) + } + + if (from.hasRemaining) { + var word = 0: Long + var i = 0 + + while (from.hasRemaining) { + if (BOOLEAN.extract(from)) { + word |= (1: Long) << i + } + i += 1 + } + + to.putLong(word) + } + + to.rewind() + to + } + + override def uncompressedSize = _uncompressedSize + + override def compressedSize = { + val extra = if (_uncompressedSize % BITS_PER_LONG == 0) 0 else 1 + (_uncompressedSize / BITS_PER_LONG + extra) * 8 + 4 + } + } + + class Decoder(buffer: ByteBuffer) extends compression.Decoder[BooleanType.type] { + private val count = buffer.getInt() + + private var currentWord = 0: Long + + private var visited: Int = 0 + + override def next(): Boolean = { + val bit = visited % BITS_PER_LONG + + visited += 1 + if (bit == 0) { + currentWord = buffer.getLong() + } + + ((currentWord >> bit) & 1) != 0 + } + + override def hasNext: Boolean = visited < count + } +} + +private[sql] sealed abstract class IntegralDelta[I <: IntegralType] extends CompressionScheme { + override def decoder[T <: NativeType](buffer: ByteBuffer, columnType: NativeColumnType[T]) = { + new this.Decoder(buffer, columnType.asInstanceOf[NativeColumnType[I]]) + .asInstanceOf[compression.Decoder[T]] + } + + override def encoder[T <: NativeType] = (new this.Encoder).asInstanceOf[compression.Encoder[T]] + + /** + * Computes `delta = x - y`, returns `(true, delta)` if `delta` can fit into a single byte, or + * `(false, 0: Byte)` otherwise. + */ + protected def byteSizedDelta(x: I#JvmType, y: I#JvmType): (Boolean, Byte) + + /** + * Simply computes `x + delta` + */ + protected def addDelta(x: I#JvmType, delta: Byte): I#JvmType + + class Encoder extends compression.Encoder[I] { + private var _compressedSize: Int = 0 + + private var _uncompressedSize: Int = 0 + + private var prev: I#JvmType = _ + + private var initial = true + + override def gatherCompressibilityStats(value: I#JvmType, columnType: NativeColumnType[I]) { + _uncompressedSize += columnType.defaultSize + + if (initial) { + initial = false + prev = value + _compressedSize += 1 + columnType.defaultSize + } else { + val (smallEnough, _) = byteSizedDelta(value, prev) + _compressedSize += (if (smallEnough) 1 else 1 + columnType.defaultSize) + } + } + + override def compress(from: ByteBuffer, to: ByteBuffer, columnType: NativeColumnType[I]) = { + to.putInt(typeId) + + if (from.hasRemaining) { + val prev = columnType.extract(from) + + to.put(Byte.MinValue) + columnType.append(prev, to) + + while (from.hasRemaining) { + val current = columnType.extract(from) + val (smallEnough, delta) = byteSizedDelta(current, prev) + + if (smallEnough) { + to.put(delta) + } else { + to.put(Byte.MinValue) + columnType.append(current, to) + } + } + } + + to.rewind() + to + } + + override def uncompressedSize = _uncompressedSize + + override def compressedSize = _compressedSize + } + + class Decoder(buffer: ByteBuffer, columnType: NativeColumnType[I]) + extends compression.Decoder[I] { + + private var prev: I#JvmType = _ + + override def next() = { + val delta = buffer.get() + + if (delta > Byte.MinValue) { + addDelta(prev, delta) + } else { + prev = columnType.extract(buffer) + prev + } + } + + override def hasNext = buffer.hasRemaining + } +} + +private[sql] case object IntDelta extends IntegralDelta[IntegerType.type] { + override val typeId = 4 + + override def supports(columnType: ColumnType[_, _]) = columnType == INT + + override protected def addDelta(x: Int, delta: Byte) = x + delta + + override protected def byteSizedDelta(x: Int, y: Int): (Boolean, Byte) = { + val delta = x - y + if (delta < Byte.MaxValue) (true, delta.toByte) else (false, 0: Byte) + } +} + +private[sql] case object LongDelta extends IntegralDelta[LongType.type] { + override val typeId = 5 + + override def supports(columnType: ColumnType[_, _]) = columnType == LONG + + override protected def addDelta(x: Long, delta: Byte) = x + delta + + override protected def byteSizedDelta(x: Long, y: Long): (Boolean, Byte) = { + val delta = x - y + if (delta < Byte.MaxValue) (true, delta.toByte) else (false, 0: Byte) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala new file mode 100644 index 0000000000000..3a4f071eebedf --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import java.util.HashMap + +import org.apache.spark.SparkContext +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical._ + +/** + * Groups input data by `groupingExpressions` and computes the `aggregateExpressions` for each + * group. + * + * @param partial if true then aggregation is done partially on local data without shuffling to + * ensure all values where `groupingExpressions` are equal are present. + * @param groupingExpressions expressions that are evaluated to determine grouping. + * @param aggregateExpressions expressions that are computed for each group. + * @param child the input data source. + */ +case class Aggregate( + partial: Boolean, + groupingExpressions: Seq[Expression], + aggregateExpressions: Seq[NamedExpression], + child: SparkPlan)(@transient sc: SparkContext) + extends UnaryNode with NoBind { + + override def requiredChildDistribution = + if (partial) { + UnspecifiedDistribution :: Nil + } else { + if (groupingExpressions == Nil) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingExpressions) :: Nil + } + } + + override def otherCopyArgs = sc :: Nil + + // HACK: Generators don't correctly preserve their output through serializations so we grab + // out child's output attributes statically here. + private[this] val childOutput = child.output + + override def output = aggregateExpressions.map(_.toAttribute) + + /** + * An aggregate that needs to be computed for each row in a group. + * + * @param unbound Unbound version of this aggregate, used for result substitution. + * @param aggregate A bound copy of this aggregate used to create a new aggregation buffer. + * @param resultAttribute An attribute used to refer to the result of this aggregate in the final + * output. + */ + case class ComputedAggregate( + unbound: AggregateExpression, + aggregate: AggregateExpression, + resultAttribute: AttributeReference) + + /** A list of aggregates that need to be computed for each group. */ + @transient + private[this] lazy val computedAggregates = aggregateExpressions.flatMap { agg => + agg.collect { + case a: AggregateExpression => + ComputedAggregate( + a, + BindReferences.bindReference(a, childOutput).asInstanceOf[AggregateExpression], + AttributeReference(s"aggResult:$a", a.dataType, nullable = true)()) + } + }.toArray + + /** The schema of the result of all aggregate evaluations */ + @transient + private[this] lazy val computedSchema = computedAggregates.map(_.resultAttribute) + + /** Creates a new aggregate buffer for a group. */ + private[this] def newAggregateBuffer(): Array[AggregateFunction] = { + val buffer = new Array[AggregateFunction](computedAggregates.length) + var i = 0 + while (i < computedAggregates.length) { + buffer(i) = computedAggregates(i).aggregate.newInstance() + i += 1 + } + buffer + } + + /** Named attributes used to substitute grouping attributes into the final result. */ + @transient + private[this] lazy val namedGroups = groupingExpressions.map { + case ne: NamedExpression => ne -> ne.toAttribute + case e => e -> Alias(e, s"groupingExpr:$e")().toAttribute + } + + /** + * A map of substitutions that are used to insert the aggregate expressions and grouping + * expression into the final result expression. + */ + @transient + private[this] lazy val resultMap = + (computedAggregates.map { agg => agg.unbound -> agg.resultAttribute} ++ namedGroups).toMap + + /** + * Substituted version of aggregateExpressions expressions which are used to compute final + * output rows given a group and the result of all aggregate computations. + */ + @transient + private[this] lazy val resultExpressions = aggregateExpressions.map { agg => + agg.transform { + case e: Expression if resultMap.contains(e) => resultMap(e) + } + } + + override def execute() = attachTree(this, "execute") { + if (groupingExpressions.isEmpty) { + child.execute().mapPartitions { iter => + val buffer = newAggregateBuffer() + var currentRow: Row = null + while (iter.hasNext) { + currentRow = iter.next() + var i = 0 + while (i < buffer.length) { + buffer(i).update(currentRow) + i += 1 + } + } + val resultProjection = new Projection(resultExpressions, computedSchema) + val aggregateResults = new GenericMutableRow(computedAggregates.length) + + var i = 0 + while (i < buffer.length) { + aggregateResults(i) = buffer(i).eval(EmptyRow) + i += 1 + } + + Iterator(resultProjection(aggregateResults)) + } + } else { + child.execute().mapPartitions { iter => + val hashTable = new HashMap[Row, Array[AggregateFunction]] + val groupingProjection = new MutableProjection(groupingExpressions, childOutput) + + var currentRow: Row = null + while (iter.hasNext) { + currentRow = iter.next() + val currentGroup = groupingProjection(currentRow) + var currentBuffer = hashTable.get(currentGroup) + if (currentBuffer == null) { + currentBuffer = newAggregateBuffer() + hashTable.put(currentGroup.copy(), currentBuffer) + } + + var i = 0 + while (i < currentBuffer.length) { + currentBuffer(i).update(currentRow) + i += 1 + } + } + + new Iterator[Row] { + private[this] val hashTableIter = hashTable.entrySet().iterator() + private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length) + private[this] val resultProjection = + new MutableProjection(resultExpressions, computedSchema ++ namedGroups.map(_._2)) + private[this] val joinedRow = new JoinedRow + + override final def hasNext: Boolean = hashTableIter.hasNext + + override final def next(): Row = { + val currentEntry = hashTableIter.next() + val currentGroup = currentEntry.getKey + val currentBuffer = currentEntry.getValue + + var i = 0 + while (i < currentBuffer.length) { + // Evaluating an aggregate buffer returns the result. No row is required since we + // already added all rows in the group using update. + aggregateResults(i) = currentBuffer(i).eval(EmptyRow) + i += 1 + } + resultProjection(joinedRow(aggregateResults, currentGroup)) + } + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 869673b1fe978..450c142c0baa4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -76,7 +76,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una */ object AddExchange extends Rule[SparkPlan] { // TODO: Determine the number of partitions. - val numPartitions = 8 + val numPartitions = 150 def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator: SparkPlan => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index e902e6ced521d..cff4887936ae1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -36,10 +36,10 @@ case class Generate( child: SparkPlan) extends UnaryNode { - def output = + override def output = if (join) child.output ++ generator.output else generator.output - def execute() = { + override def execute() = { if (join) { child.execute().mapPartitions { iter => val nullValues = Seq.fill(generator.output.size)(Literal(null)) @@ -52,7 +52,7 @@ case class Generate( val joinedRow = new JoinedRow iter.flatMap {row => - val outputRows = generator(row) + val outputRows = generator.eval(row) if (outer && outputRows.isEmpty) { outerProjection(row) :: Nil } else { @@ -61,7 +61,7 @@ case class Generate( } } } else { - child.execute().mapPartitions(iter => iter.flatMap(generator)) + child.execute().mapPartitions(iter => iter.flatMap(row => generator.eval(row))) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index acb1ee83a72f6..daa423cb8ea1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.GenericRow import org.apache.spark.sql.catalyst.plans.{QueryPlan, logical} import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.columnar.InMemoryColumnarTableScan abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging { self: Product => @@ -69,6 +70,8 @@ case class SparkLogicalPlan(alreadyPlanned: SparkPlan) SparkLogicalPlan( alreadyPlanned match { case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd) + case InMemoryColumnarTableScan(output, child) => + InMemoryColumnarTableScan(output.map(_.newInstance), child) case _ => sys.error("Multiple instance of the same relation detected.") }).asInstanceOf[this.type] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 915f551fb2f01..c30ae5bcc02d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -25,6 +25,7 @@ import com.esotericsoftware.kryo.{Serializer, Kryo} import org.apache.spark.{SparkEnv, SparkConf} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.MutablePair +import org.apache.spark.util.Utils class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { override def newKryo(): Kryo = { @@ -32,13 +33,19 @@ class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { kryo.setRegistrationRequired(false) kryo.register(classOf[MutablePair[_, _]]) kryo.register(classOf[Array[Any]]) + // This is kinda hacky... kryo.register(classOf[scala.collection.immutable.Map$Map1], new MapSerializer) + kryo.register(classOf[scala.collection.immutable.Map$Map2], new MapSerializer) + kryo.register(classOf[scala.collection.immutable.Map$Map3], new MapSerializer) + kryo.register(classOf[scala.collection.immutable.Map$Map4], new MapSerializer) + kryo.register(classOf[scala.collection.immutable.Map[_,_]], new MapSerializer) + kryo.register(classOf[scala.collection.Map[_,_]], new MapSerializer) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow]) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow]) kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]]) kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer) kryo.setReferences(false) - kryo.setClassLoader(this.getClass.getClassLoader) + kryo.setClassLoader(Utils.getSparkClassLoader) kryo } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index e35ac0b6ca95a..fe8bd5a508820 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -158,10 +158,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case other => other } - object TopK extends Strategy { + object TakeOrdered extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.StopAfter(IntegerLiteral(limit), logical.Sort(order, child)) => - execution.TopK(limit, order, planLater(child))(sparkContext) :: Nil + case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) => + execution.TakeOrdered(limit, order, planLater(child))(sparkContext) :: Nil case _ => Nil } } @@ -171,10 +171,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // TODO: need to support writing to other types of files. Unify the below code paths. case logical.WriteToFile(path, child) => val relation = - ParquetRelation.create(path, child, sparkContext.hadoopConfiguration, None) - InsertIntoParquetTable(relation, planLater(child))(sparkContext) :: Nil + ParquetRelation.create(path, child, sparkContext.hadoopConfiguration) + InsertIntoParquetTable(relation, planLater(child), overwrite=true)(sparkContext) :: Nil case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) => - InsertIntoParquetTable(table, planLater(child))(sparkContext) :: Nil + InsertIntoParquetTable(table, planLater(child), overwrite)(sparkContext) :: Nil case PhysicalOperation(projectList, filters, relation: ParquetRelation) => // TODO: Should be pushing down filters as well. pruneFilterProject( @@ -213,8 +213,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { sparkContext.parallelize(data.map(r => new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row)) execution.ExistingRdd(output, dataAsRdd) :: Nil - case logical.StopAfter(IntegerLiteral(limit), child) => - execution.StopAfter(limit, planLater(child))(sparkContext) :: Nil + case logical.Limit(IntegerLiteral(limit), child) => + execution.Limit(limit, planLater(child))(sparkContext) :: Nil case Unions(unionChildren) => execution.Union(unionChildren.map(planLater))(sparkContext) :: Nil case logical.Generate(generator, join, outer, _, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala deleted file mode 100644 index 8515a18f18c55..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.SparkContext -import org.apache.spark.sql.catalyst.errors._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ - -/* Implicit conversions */ -import org.apache.spark.rdd.PartitionLocalRDDFunctions._ - -/** - * Groups input data by `groupingExpressions` and computes the `aggregateExpressions` for each - * group. - * - * @param partial if true then aggregation is done partially on local data without shuffling to - * ensure all values where `groupingExpressions` are equal are present. - * @param groupingExpressions expressions that are evaluated to determine grouping. - * @param aggregateExpressions expressions that are computed for each group. - * @param child the input data source. - */ -case class Aggregate( - partial: Boolean, - groupingExpressions: Seq[Expression], - aggregateExpressions: Seq[NamedExpression], - child: SparkPlan)(@transient sc: SparkContext) - extends UnaryNode { - - override def requiredChildDistribution = - if (partial) { - UnspecifiedDistribution :: Nil - } else { - if (groupingExpressions == Nil) { - AllTuples :: Nil - } else { - ClusteredDistribution(groupingExpressions) :: Nil - } - } - - override def otherCopyArgs = sc :: Nil - - def output = aggregateExpressions.map(_.toAttribute) - - /* Replace all aggregate expressions with spark functions that will compute the result. */ - def createAggregateImplementations() = aggregateExpressions.map { agg => - val impl = agg transform { - case a: AggregateExpression => a.newInstance - } - - val remainingAttributes = impl.collect { case a: Attribute => a } - // If any references exist that are not inside agg functions then the must be grouping exprs - // in this case we must rebind them to the grouping tuple. - if (remainingAttributes.nonEmpty) { - val unaliasedAggregateExpr = agg transform { case Alias(c, _) => c } - - // An exact match with a grouping expression - val exactGroupingExpr = groupingExpressions.indexOf(unaliasedAggregateExpr) match { - case -1 => None - case ordinal => Some(BoundReference(ordinal, Alias(impl, "AGGEXPR")().toAttribute)) - } - - exactGroupingExpr.getOrElse( - sys.error(s"$agg is not in grouping expressions: $groupingExpressions")) - } else { - impl - } - } - - def execute() = attachTree(this, "execute") { - // TODO: If the child of it is an [[catalyst.execution.Exchange]], - // do not evaluate the groupingExpressions again since we have evaluated it - // in the [[catalyst.execution.Exchange]]. - val grouped = child.execute().mapPartitions { iter => - val buildGrouping = new Projection(groupingExpressions) - iter.map(row => (buildGrouping(row), row.copy())) - }.groupByKeyLocally() - - val result = grouped.map { case (group, rows) => - val aggImplementations = createAggregateImplementations() - - // Pull out all the functions so we can feed each row into them. - val aggFunctions = aggImplementations.flatMap(_ collect { case f: AggregateFunction => f }) - - rows.foreach { row => - aggFunctions.foreach(_.update(row)) - } - buildRow(aggImplementations.map(_.apply(group))) - } - - // TODO: THIS BREAKS PIPELINING, DOUBLE COMPUTES THE ANSWER, AND USES TOO MUCH MEMORY... - if (groupingExpressions.isEmpty && result.count == 0) { - // When there there is no output to the Aggregate operator, we still output an empty row. - val aggImplementations = createAggregateImplementations() - sc.makeRDD(buildRow(aggImplementations.map(_.apply(null))) :: Nil) - } else { - result - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 65cb8f8becefa..ab2e62463764a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -19,65 +19,88 @@ package org.apache.spark.sql.execution import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext - +import org.apache.spark.{HashPartitioner, SparkConf, SparkContext} +import org.apache.spark.rdd.{RDD, ShuffledRDD} +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{OrderedDistribution, UnspecifiedDistribution} -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.util.MutablePair + case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { - def output = projectList.map(_.toAttribute) + override def output = projectList.map(_.toAttribute) - def execute() = child.execute().mapPartitions { iter => + override def execute() = child.execute().mapPartitions { iter => @transient val reusableProjection = new MutableProjection(projectList) iter.map(reusableProjection) } } case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { - def output = child.output + override def output = child.output - def execute() = child.execute().mapPartitions { iter => - iter.filter(condition.apply(_).asInstanceOf[Boolean]) + override def execute() = child.execute().mapPartitions { iter => + iter.filter(condition.eval(_).asInstanceOf[Boolean]) } } case class Sample(fraction: Double, withReplacement: Boolean, seed: Int, child: SparkPlan) extends UnaryNode { - def output = child.output + override def output = child.output // TODO: How to pick seed? - def execute() = child.execute().sample(withReplacement, fraction, seed) + override def execute() = child.execute().sample(withReplacement, fraction, seed) } case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends SparkPlan { // TODO: attributes output by union should be distinct for nullability purposes - def output = children.head.output - def execute() = sc.union(children.map(_.execute())) + override def output = children.head.output + override def execute() = sc.union(children.map(_.execute())) override def otherCopyArgs = sc :: Nil } -case class StopAfter(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode { +/** + * Take the first limit elements. Note that the implementation is different depending on whether + * this is a terminal operator or not. If it is terminal and is invoked using executeCollect, + * this operator uses Spark's take method on the Spark driver. If it is not terminal or is + * invoked using execute, we first take the limit on each partition, and then repartition all the + * data to a single partition to compute the global limit. + */ +case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode { + // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan: + // partition local limit -> exchange into one partition -> partition local limit again + override def otherCopyArgs = sc :: Nil - def output = child.output + override def output = child.output override def executeCollect() = child.execute().map(_.copy()).take(limit) - // TODO: Terminal split should be implemented differently from non-terminal split. - // TODO: Pick num splits based on |limit|. - def execute() = sc.makeRDD(executeCollect(), 1) + override def execute() = { + val rdd = child.execute().mapPartitions { iter => + val mutablePair = new MutablePair[Boolean, Row]() + iter.take(limit).map(row => mutablePair.update(false, row)) + } + val part = new HashPartitioner(1) + val shuffled = new ShuffledRDD[Boolean, Row, MutablePair[Boolean, Row]](rdd, part) + shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false))) + shuffled.mapPartitions(_.take(limit).map(_._2)) + } } -case class TopK(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) - (@transient sc: SparkContext) extends UnaryNode { +/** + * Take the first limit elements as defined by the sortOrder. This is logically equivalent to + * having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but + * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion. + */ +case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) + (@transient sc: SparkContext) extends UnaryNode { override def otherCopyArgs = sc :: Nil - def output = child.output + override def output = child.output @transient lazy val ordering = new RowOrdering(sortOrder) @@ -86,7 +109,7 @@ case class TopK(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. - def execute() = sc.makeRDD(executeCollect(), 1) + override def execute() = sc.makeRDD(executeCollect(), 1) } @@ -101,7 +124,7 @@ case class Sort( @transient lazy val ordering = new RowOrdering(sortOrder) - def execute() = attachTree(this, "sort") { + override def execute() = attachTree(this, "sort") { // TODO: Optimize sorting operation? child.execute() .mapPartitions( @@ -109,7 +132,7 @@ case class Sort( preservesPartitioning = true) } - def output = child.output + override def output = child.output } object ExistingRdd { @@ -130,6 +153,6 @@ object ExistingRdd { } case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { - def execute() = rdd + override def execute() = rdd } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index 4ab755c096bd8..4d7c86a3a4fc7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -17,30 +17,29 @@ package org.apache.spark.sql.parquet -import java.io.{IOException, FileNotFoundException} - -import scala.collection.JavaConversions._ +import java.io.IOException import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.mapreduce.Job -import parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} import parquet.hadoop.util.ContextUtil -import parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} +import parquet.hadoop.{ParquetOutputFormat, Footer, ParquetFileWriter, ParquetFileReader} +import parquet.hadoop.metadata.{CompressionCodecName, FileMetaData, ParquetMetadata} import parquet.io.api.{Binary, RecordConsumer} +import parquet.schema.{Type => ParquetType, PrimitiveType => ParquetPrimitiveType, MessageType, MessageTypeParser} import parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} import parquet.schema.Type.Repetition -import parquet.schema.{MessageType, MessageTypeParser} -import parquet.schema.{PrimitiveType => ParquetPrimitiveType} -import parquet.schema.{Type => ParquetType} import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row} -import org.apache.spark.sql.catalyst.plans.logical.{BaseRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} import org.apache.spark.sql.catalyst.types._ +// Implicits +import scala.collection.JavaConversions._ + /** * Relation that consists of data stored in a Parquet columnar format. * @@ -48,14 +47,14 @@ import org.apache.spark.sql.catalyst.types._ * of using this class directly. * * {{{ - * val parquetRDD = sqlContext.parquetFile("path/to/parequet.file") + * val parquetRDD = sqlContext.parquetFile("path/to/parquet.file") * }}} * - * @param tableName The name of the relation that can be used in queries. * @param path The path to the Parquet file. */ -case class ParquetRelation(tableName: String, path: String) - extends BaseRelation with MultiInstanceRelation { +private[sql] case class ParquetRelation(val path: String) + extends LeafNode with MultiInstanceRelation { + self: Product => /** Schema derived from ParquetFile */ def parquetSchema: MessageType = @@ -65,33 +64,42 @@ case class ParquetRelation(tableName: String, path: String) .getSchema /** Attributes */ - val attributes = + override val output = ParquetTypesConverter - .convertToAttributes(parquetSchema) - - /** Output */ - override val output = attributes - - // Parquet files have no concepts of keys, therefore no Partitioner - // Note: we could allow Block level access; needs to be thought through - override def isPartitioned = false + .convertToAttributes(parquetSchema) - override def newInstance = ParquetRelation(tableName, path).asInstanceOf[this.type] + override def newInstance = ParquetRelation(path).asInstanceOf[this.type] // Equals must also take into account the output attributes so that we can distinguish between // different instances of the same relation, override def equals(other: Any) = other match { case p: ParquetRelation => - p.tableName == tableName && p.path == path && p.output == output + p.path == path && p.output == output case _ => false } } -object ParquetRelation { +private[sql] object ParquetRelation { + + def enableLogForwarding() { + // Note: Logger.getLogger("parquet") has a default logger + // that appends to Console which needs to be cleared. + val parquetLogger = java.util.logging.Logger.getLogger("parquet") + parquetLogger.getHandlers.foreach(parquetLogger.removeHandler) + // TODO(witgo): Need to set the log level ? + // if(parquetLogger.getLevel != null) parquetLogger.setLevel(null) + if (!parquetLogger.getUseParentHandlers) parquetLogger.setUseParentHandlers(true) + } // The element type for the RDDs that this relation maps to. type RowType = org.apache.spark.sql.catalyst.expressions.GenericMutableRow + // The compression type + type CompressionType = parquet.hadoop.metadata.CompressionCodecName + + // The default compression + val defaultCompression = CompressionCodecName.GZIP + /** * Creates a new ParquetRelation and underlying Parquetfile for the given LogicalPlan. Note that * this is used inside [[org.apache.spark.sql.execution.SparkStrategies SparkStrategies]] to @@ -100,24 +108,39 @@ object ParquetRelation { * * @param pathString The directory the Parquetfile will be stored in. * @param child The child node that will be used for extracting the schema. - * @param conf A configuration configuration to be used. - * @param tableName The name of the resulting relation. - * @return An empty ParquetRelation inferred metadata. + * @param conf A configuration to be used. + * @return An empty ParquetRelation with inferred metadata. */ def create(pathString: String, child: LogicalPlan, - conf: Configuration, - tableName: Option[String]): ParquetRelation = { + conf: Configuration): ParquetRelation = { if (!child.resolved) { throw new UnresolvedException[LogicalPlan]( child, "Attempt to create Parquet table from unresolved child (when schema is not available)") } + createEmpty(pathString, child.output, conf) + } - val name = s"${tableName.getOrElse(child.nodeName)}_parquet" + /** + * Creates an empty ParquetRelation and underlying Parquetfile that only + * consists of the Metadata for the given schema. + * + * @param pathString The directory the Parquetfile will be stored in. + * @param attributes The schema of the relation. + * @param conf A configuration to be used. + * @return An empty ParquetRelation. + */ + def createEmpty(pathString: String, + attributes: Seq[Attribute], + conf: Configuration): ParquetRelation = { val path = checkPath(pathString, conf) - ParquetTypesConverter.writeMetaData(child.output, path, conf) - new ParquetRelation(name, path.toString) + if (conf.get(ParquetOutputFormat.COMPRESSION) == null) { + conf.set(ParquetOutputFormat.COMPRESSION, ParquetRelation.defaultCompression.name()) + } + ParquetRelation.enableLogForwarding() + ParquetTypesConverter.writeMetaData(attributes, path, conf) + new ParquetRelation(path.toString) } private def checkPath(pathStr: String, conf: Configuration): Path = { @@ -143,7 +166,7 @@ object ParquetRelation { } } -object ParquetTypesConverter { +private[parquet] object ParquetTypesConverter { def toDataType(parquetType : ParquetPrimitiveTypeName): DataType = parquetType match { // for now map binary to string type // TODO: figure out how Parquet uses strings or why we can't use them in a MessageType schema @@ -242,6 +265,7 @@ object ParquetTypesConverter { extraMetadata, "Spark") + ParquetRelation.enableLogForwarding() ParquetFileWriter.writeMetadataFile( conf, path, @@ -268,16 +292,24 @@ object ParquetTypesConverter { throw new IllegalArgumentException(s"Incorrectly formatted Parquet metadata path $origPath") } val path = origPath.makeQualified(fs) + if (!fs.getFileStatus(path).isDir) { + throw new IllegalArgumentException( + s"Expected $path for be a directory with Parquet files/metadata") + } + ParquetRelation.enableLogForwarding() val metadataPath = new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE) + // if this is a new table that was just created we will find only the metadata file if (fs.exists(metadataPath) && fs.isFile(metadataPath)) { - // TODO: improve exception handling, etc. ParquetFileReader.readFooter(conf, metadataPath) } else { - if (!fs.exists(path) || !fs.isFile(path)) { - throw new FileNotFoundException( - s"Could not find file ${path.toString} when trying to read metadata") + // there may be one or more Parquet files in the given directory + val footers = ParquetFileReader.readFooters(conf, fs.getFileStatus(path)) + // TODO: for now we assume that all footers (if there is more than one) have identical + // metadata; we may want to add a check here at some point + if (footers.size() == 0) { + throw new IllegalArgumentException(s"Could not find Parquet metadata at path $path") } - ParquetFileReader.readFooter(conf, path) + footers(0).getParquetMetadata } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index 7285f5b88b9bf..d5846baa72ada 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -24,26 +24,29 @@ import java.util.Date import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat} +import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} +import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat, FileOutputCommitter} -import parquet.hadoop.util.ContextUtil import parquet.hadoop.{ParquetInputFormat, ParquetOutputFormat} +import parquet.hadoop.util.ContextUtil import parquet.io.InvalidRecordException import parquet.schema.MessageType +import org.apache.spark.{SerializableWritable, SparkContext, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row} import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode} -import org.apache.spark.{SerializableWritable, SparkContext, TaskContext} /** * Parquet table scan operator. Imports the file that backs the given * [[ParquetRelation]] as a RDD[Row]. */ case class ParquetTableScan( - @transient output: Seq[Attribute], - @transient relation: ParquetRelation, - @transient columnPruningPred: Option[Expression])( + // note: output cannot be transient, see + // https://issues.apache.org/jira/browse/SPARK-1367 + output: Seq[Attribute], + relation: ParquetRelation, + columnPruningPred: Option[Expression])( @transient val sc: SparkContext) extends LeafNode { @@ -53,6 +56,12 @@ case class ParquetTableScan( job, classOf[org.apache.spark.sql.parquet.RowReadSupport]) val conf: Configuration = ContextUtil.getConfiguration(job) + val fileList = FileSystemHelper.listFiles(relation.path, conf) + // add all paths in the directory but skip "hidden" ones such + // as "_SUCCESS" and "_metadata" + for (path <- fileList if !path.getName.startsWith("_")) { + NewFileInputFormat.addInputPath(job, path) + } conf.set( RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA, ParquetTypesConverter.convertFromAttributes(output).toString) @@ -63,14 +72,12 @@ case class ParquetTableScan( ``FilteredRecordReader`` (via Configuration, for example). Simple filter-rows-by-column-values however should be supported. */ - sc.newAPIHadoopFile( - relation.path, - classOf[ParquetInputFormat[Row]], - classOf[Void], classOf[Row], - conf) + sc.newAPIHadoopRDD(conf, classOf[ParquetInputFormat[Row]], classOf[Void], classOf[Row]) .map(_._2) } + override def otherCopyArgs = sc :: Nil + /** * Applies a (candidate) projection. * @@ -108,15 +115,31 @@ case class ParquetTableScan( } } +/** + * Operator that acts as a sink for queries on RDDs and can be used to + * store the output inside a directory of Parquet files. This operator + * is similar to Hive's INSERT INTO TABLE operation in the sense that + * one can choose to either overwrite or append to a directory. Note + * that consecutive insertions to the same table must have compatible + * (source) schemas. + * + * WARNING: EXPERIMENTAL! InsertIntoParquetTable with overwrite=false may + * cause data corruption in the case that multiple users try to append to + * the same table simultaneously. Inserting into a table that was + * previously generated by other means (e.g., by creating an HDFS + * directory and importing Parquet files generated by other tools) may + * cause unpredicted behaviour and therefore results in a RuntimeException + * (only detected via filename pattern so will not catch all cases). + */ case class InsertIntoParquetTable( - @transient relation: ParquetRelation, - @transient child: SparkPlan)( + relation: ParquetRelation, + child: SparkPlan, + overwrite: Boolean = false)( @transient val sc: SparkContext) extends UnaryNode with SparkHadoopMapReduceUtil { /** - * Inserts all the rows in the Parquet file. Note that OVERWRITE is implicit, since - * Parquet files are write-once. + * Inserts all rows into the Parquet file. */ override def execute() = { // TODO: currently we do not check whether the "schema"s are compatible @@ -135,19 +158,21 @@ case class InsertIntoParquetTable( classOf[org.apache.spark.sql.parquet.RowWriteSupport]) // TODO: move that to function in object - val conf = job.getConfiguration + val conf = ContextUtil.getConfiguration(job) conf.set(RowWriteSupport.PARQUET_ROW_SCHEMA, relation.parquetSchema.toString) val fspath = new Path(relation.path) val fs = fspath.getFileSystem(conf) - try { - fs.delete(fspath, true) - } catch { - case e: IOException => - throw new IOException( - s"Unable to clear output directory ${fspath.toString} prior" - + s" to InsertIntoParquetTable:\n${e.toString}") + if (overwrite) { + try { + fs.delete(fspath, true) + } catch { + case e: IOException => + throw new IOException( + s"Unable to clear output directory ${fspath.toString} prior" + + s" to InsertIntoParquetTable:\n${e.toString}") + } } saveAsHadoopFile(childRdd, relation.path.toString, conf) @@ -157,6 +182,8 @@ case class InsertIntoParquetTable( override def output = child.output + override def otherCopyArgs = sc :: Nil + // based on ``saveAsNewAPIHadoopFile`` in [[PairRDDFunctions]] // TODO: Maybe PairRDDFunctions should use Product2 instead of Tuple2? // .. then we could use the default one and could use [[MutablePair]] @@ -167,15 +194,21 @@ case class InsertIntoParquetTable( conf: Configuration) { val job = new Job(conf) val keyType = classOf[Void] - val outputFormatType = classOf[parquet.hadoop.ParquetOutputFormat[Row]] job.setOutputKeyClass(keyType) job.setOutputValueClass(classOf[Row]) - val wrappedConf = new SerializableWritable(job.getConfiguration) NewFileOutputFormat.setOutputPath(job, new Path(path)) + val wrappedConf = new SerializableWritable(job.getConfiguration) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = sc.newRddId() + val taskIdOffset = + if (overwrite) 1 + else { + FileSystemHelper + .findMaxTaskId(NewFileOutputFormat.getOutputPath(job).toString, job.getConfiguration) + 1 + } + def writeShard(context: TaskContext, iter: Iterator[Row]): Int = { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. @@ -184,7 +217,7 @@ case class InsertIntoParquetTable( val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, attemptNumber) val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) - val format = outputFormatType.newInstance + val format = new AppendingParquetOutputFormat(taskIdOffset) val committer = format.getOutputCommitter(hadoopContext) committer.setupTask(hadoopContext) val writer = format.getRecordWriter(hadoopContext) @@ -196,7 +229,7 @@ case class InsertIntoParquetTable( committer.commitTask(hadoopContext) return 1 } - val jobFormat = outputFormatType.newInstance + val jobFormat = new AppendingParquetOutputFormat(taskIdOffset) /* apparently we need a TaskAttemptID to construct an OutputCommitter; * however we're only going to use this local OutputCommitter for * setupJob/commitJob, so we just use a dummy "map" task. @@ -210,3 +243,55 @@ case class InsertIntoParquetTable( } } +// TODO: this will be able to append to directories it created itself, not necessarily +// to imported ones +private[parquet] class AppendingParquetOutputFormat(offset: Int) + extends parquet.hadoop.ParquetOutputFormat[Row] { + // override to accept existing directories as valid output directory + override def checkOutputSpecs(job: JobContext): Unit = {} + + // override to choose output filename so not overwrite existing ones + override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { + val taskId: TaskID = context.getTaskAttemptID.getTaskID + val partition: Int = taskId.getId + val filename = s"part-r-${partition + offset}.parquet" + val committer: FileOutputCommitter = + getOutputCommitter(context).asInstanceOf[FileOutputCommitter] + new Path(committer.getWorkPath, filename) + } +} + +private[parquet] object FileSystemHelper { + def listFiles(pathStr: String, conf: Configuration): Seq[Path] = { + val origPath = new Path(pathStr) + val fs = origPath.getFileSystem(conf) + if (fs == null) { + throw new IllegalArgumentException( + s"ParquetTableOperations: Path $origPath is incorrectly formatted") + } + val path = origPath.makeQualified(fs) + if (!fs.exists(path) || !fs.getFileStatus(path).isDir) { + throw new IllegalArgumentException( + s"ParquetTableOperations: path $path does not exist or is not a directory") + } + fs.listStatus(path).map(_.getPath) + } + + // finds the maximum taskid in the output file names at the given path + def findMaxTaskId(pathStr: String, conf: Configuration): Int = { + val files = FileSystemHelper.listFiles(pathStr, conf) + // filename pattern is part-r-.parquet + val nameP = new scala.util.matching.Regex("""part-r-(\d{1,}).parquet""", "taskid") + val hiddenFileP = new scala.util.matching.Regex("_.*") + files.map(_.getName).map { + case nameP(taskid) => taskid.toInt + case hiddenFileP() => 0 + case other: String => { + sys.error("ERROR: attempting to append to set of Parquet files and found file" + + s"that does not match name pattern: $other") + 0 + } + case _ => 0 + }.reduceLeft((a, b) => if (a < b) b else a) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index c21e400282004..84b1b4609458b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -35,7 +35,8 @@ import org.apache.spark.sql.catalyst.types._ * *@param root The root group converter for the record. */ -class RowRecordMaterializer(root: CatalystGroupConverter) extends RecordMaterializer[Row] { +private[parquet] class RowRecordMaterializer(root: CatalystGroupConverter) + extends RecordMaterializer[Row] { def this(parquetSchema: MessageType) = this(new CatalystGroupConverter(ParquetTypesConverter.convertToAttributes(parquetSchema))) @@ -48,14 +49,14 @@ class RowRecordMaterializer(root: CatalystGroupConverter) extends RecordMaterial /** * A `parquet.hadoop.api.ReadSupport` for Row objects. */ -class RowReadSupport extends ReadSupport[Row] with Logging { +private[parquet] class RowReadSupport extends ReadSupport[Row] with Logging { override def prepareForRead( conf: Configuration, stringMap: java.util.Map[String, String], fileSchema: MessageType, readContext: ReadContext): RecordMaterializer[Row] = { - log.debug(s"preparing for read with schema ${fileSchema.toString}") + log.debug(s"preparing for read with file schema $fileSchema") new RowRecordMaterializer(readContext.getRequestedSchema) } @@ -67,20 +68,20 @@ class RowReadSupport extends ReadSupport[Row] with Logging { configuration.get(RowReadSupport.PARQUET_ROW_REQUESTED_SCHEMA, fileSchema.toString) val requested_schema = MessageTypeParser.parseMessageType(requested_schema_string) - - log.debug(s"read support initialized for original schema ${requested_schema.toString}") + log.debug(s"read support initialized for requested schema $requested_schema") + ParquetRelation.enableLogForwarding() new ReadContext(requested_schema, keyValueMetaData) } } -object RowReadSupport { +private[parquet] object RowReadSupport { val PARQUET_ROW_REQUESTED_SCHEMA = "org.apache.spark.sql.parquet.row.requested_schema" } /** * A `parquet.hadoop.api.WriteSupport` for Row ojects. */ -class RowWriteSupport extends WriteSupport[Row] with Logging { +private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { def setSchema(schema: MessageType, configuration: Configuration) { // for testing this.schema = schema @@ -104,6 +105,8 @@ class RowWriteSupport extends WriteSupport[Row] with Logging { override def init(configuration: Configuration): WriteSupport.WriteContext = { schema = if (schema == null) getSchema(configuration) else schema attributes = ParquetTypesConverter.convertToAttributes(schema) + log.debug(s"write support initialized for requested schema $schema") + ParquetRelation.enableLogForwarding() new WriteSupport.WriteContext( schema, new java.util.HashMap[java.lang.String, java.lang.String]()) @@ -111,10 +114,16 @@ class RowWriteSupport extends WriteSupport[Row] with Logging { override def prepareForWrite(recordConsumer: RecordConsumer): Unit = { writer = recordConsumer + log.debug(s"preparing for write with schema $schema") } // TODO: add groups (nested fields) override def write(record: Row): Unit = { + if (attributes.size > record.size) { + throw new IndexOutOfBoundsException( + s"Trying to write more fields than contained in row (${attributes.size}>${record.size})") + } + var index = 0 writer.startMessage() while(index < attributes.size) { @@ -130,7 +139,7 @@ class RowWriteSupport extends WriteSupport[Row] with Logging { } } -object RowWriteSupport { +private[parquet] object RowWriteSupport { val PARQUET_ROW_SCHEMA: String = "org.apache.spark.sql.parquet.row.schema" } @@ -139,7 +148,7 @@ object RowWriteSupport { * * @param schema The corresponding Catalyst schema in the form of a list of attributes. */ -class CatalystGroupConverter( +private[parquet] class CatalystGroupConverter( schema: Seq[Attribute], protected[parquet] val current: ParquetRelation.RowType) extends GroupConverter { @@ -177,13 +186,12 @@ class CatalystGroupConverter( * @param parent The parent group converter. * @param fieldIndex The index inside the record. */ -class CatalystPrimitiveConverter( +private[parquet] class CatalystPrimitiveConverter( parent: CatalystGroupConverter, fieldIndex: Int) extends PrimitiveConverter { // TODO: consider refactoring these together with ParquetTypesConverter override def addBinary(value: Binary): Unit = - // TODO: fix this once a setBinary will become available in MutableRow - parent.getCurrentRecord.setByte(fieldIndex, value.getBytes.apply(0)) + parent.getCurrentRecord.update(fieldIndex, value.getBytes) override def addBoolean(value: Boolean): Unit = parent.getCurrentRecord.setBoolean(fieldIndex, value) @@ -208,10 +216,9 @@ class CatalystPrimitiveConverter( * @param parent The parent group converter. * @param fieldIndex The index inside the record. */ -class CatalystPrimitiveStringConverter( +private[parquet] class CatalystPrimitiveStringConverter( parent: CatalystGroupConverter, fieldIndex: Int) extends CatalystPrimitiveConverter(parent, fieldIndex) { override def addBinary(value: Binary): Unit = parent.getCurrentRecord.setString(fieldIndex, value.toStringUsingUTF8) } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala index 3340c3ff81f0a..728e3dd1dc02b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTestData.scala @@ -26,7 +26,7 @@ import parquet.hadoop.util.ContextUtil import parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.sql.catalyst.expressions.GenericRow -import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.util.Utils object ParquetTestData { @@ -64,13 +64,13 @@ object ParquetTestData { "mylong:Long" ) - val testFile = getTempFilePath("testParquetFile").getCanonicalFile + val testDir = Utils.createTempDir() - lazy val testData = new ParquetRelation("testData", testFile.toURI.toString) + lazy val testData = new ParquetRelation(testDir.toURI.toString) def writeFile() = { - testFile.delete - val path: Path = new Path(testFile.toURI) + testDir.delete + val path: Path = new Path(new Path(testDir.toURI), new Path("part-r-0.parquet")) val job = new Job() val configuration: Configuration = ContextUtil.getConfiguration(job) val schema: MessageType = MessageTypeParser.parseMessageType(testSchema) diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties index 7bb6789bd33a5..dffd15a61838b 100644 --- a/sql/core/src/test/resources/log4j.properties +++ b/sql/core/src/test/resources/log4j.properties @@ -45,8 +45,6 @@ log4j.logger.org.apache.hadoop.hive.metastore.RetryingHMSHandler=OFF log4j.additivity.hive.ql.metadata.Hive=false log4j.logger.hive.ql.metadata.Hive=OFF -# Parquet logging -parquet.hadoop.InternalParquetRecordReader=WARN -log4j.logger.parquet.hadoop.InternalParquetRecordReader=WARN -parquet.hadoop.ParquetInputFormat=WARN -log4j.logger.parquet.hadoop.ParquetInputFormat=WARN +# Parquet related logging +log4j.logger.parquet.hadoop=WARN +log4j.logger.org.apache.spark.sql.parquet=INFO diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala new file mode 100644 index 0000000000000..7c6a642278226 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.FunSuite +import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.execution.SparkLogicalPlan +import org.apache.spark.sql.columnar.InMemoryColumnarTableScan + +class CachedTableSuite extends QueryTest { + TestData // Load test tables. + + test("read from cached table and uncache") { + TestSQLContext.cacheTable("testData") + + checkAnswer( + TestSQLContext.table("testData"), + testData.collect().toSeq + ) + + TestSQLContext.table("testData").queryExecution.analyzed match { + case SparkLogicalPlan(_ : InMemoryColumnarTableScan) => // Found evidence of caching + case noCache => fail(s"No cache node found in plan $noCache") + } + + TestSQLContext.uncacheTable("testData") + + checkAnswer( + TestSQLContext.table("testData"), + testData.collect().toSeq + ) + + TestSQLContext.table("testData").queryExecution.analyzed match { + case cachePlan @ SparkLogicalPlan(_ : InMemoryColumnarTableScan) => + fail(s"Table still cached after uncache: $cachePlan") + case noCache => // Table uncached successfully + } + } + + test("correct error on uncache of non-cached table") { + intercept[IllegalArgumentException] { + TestSQLContext.uncacheTable("testData") + } + } + + test("SELECT Star Cached Table") { + TestSQLContext.sql("SELECT * FROM testData").registerAsTable("selectStar") + TestSQLContext.cacheTable("selectStar") + TestSQLContext.sql("SELECT * FROM selectStar") + TestSQLContext.uncacheTable("selectStar") + } + + test("Self-join cached") { + TestSQLContext.cacheTable("testData") + TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key") + TestSQLContext.uncacheTable("testData") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index 2524a37cbac13..be0f4a4c73b36 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -119,8 +119,8 @@ class DslQuerySuite extends QueryTest { } test("inner join, where, multiple matches") { - val x = testData2.where('a === 1).subquery('x) - val y = testData2.where('a === 1).subquery('y) + val x = testData2.where('a === 1).as('x) + val y = testData2.where('a === 1).as('y) checkAnswer( x.join(y).where("x.a".attr === "y.a".attr), (1,1,1,1) :: @@ -131,8 +131,8 @@ class DslQuerySuite extends QueryTest { } test("inner join, no matches") { - val x = testData2.where('a === 1).subquery('x) - val y = testData2.where('a === 2).subquery('y) + val x = testData2.where('a === 1).as('x) + val y = testData2.where('a === 2).as('y) checkAnswer( x.join(y).where("x.a".attr === "y.a".attr), Nil) @@ -140,8 +140,8 @@ class DslQuerySuite extends QueryTest { test("big inner join, 4 matches per row") { val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData) - val bigDataX = bigData.subquery('x) - val bigDataY = bigData.subquery('y) + val bigDataX = bigData.as('x) + val bigDataY = bigData.as('y) checkAnswer( bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr), @@ -181,8 +181,8 @@ class DslQuerySuite extends QueryTest { } test("full outer join") { - val left = upperCaseData.where('N <= 4).subquery('left) - val right = upperCaseData.where('N >= 3).subquery('right) + val left = upperCaseData.where('N <= 4).as('left) + val right = upperCaseData.where('N >= 3).as('right) checkAnswer( left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala new file mode 100644 index 0000000000000..1cbf973c34917 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.sql.Timestamp + +import org.scalatest.FunSuite + +import org.apache.spark.sql.test.TestSQLContext._ + +case class ReflectData( + stringField: String, + intField: Int, + longField: Long, + floatField: Float, + doubleField: Double, + shortField: Short, + byteField: Byte, + booleanField: Boolean, + decimalField: BigDecimal, + timestampField: Timestamp, + seqInt: Seq[Int]) + +case class ReflectBinary(data: Array[Byte]) + +class ScalaReflectionRelationSuite extends FunSuite { + test("query case class RDD") { + val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, + BigDecimal(1), new Timestamp(12345), Seq(1,2,3)) + val rdd = sparkContext.parallelize(data :: Nil) + rdd.registerAsTable("reflectData") + + assert(sql("SELECT * FROM reflectData").collect().head === data.productIterator.toSeq) + } + + // Equality is broken for Arrays, so we test that separately. + test("query binary data") { + val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil) + rdd.registerAsTable("reflectBinary") + + val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]] + assert(result.toSeq === Seq[Byte](1)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala new file mode 100644 index 0000000000000..def0e046a3831 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java + +import scala.beans.BeanProperty + +import org.scalatest.FunSuite + +import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.sql.test.TestSQLContext + +// Implicits +import scala.collection.JavaConversions._ + +class PersonBean extends Serializable { + @BeanProperty + var name: String = _ + + @BeanProperty + var age: Int = _ +} + +class JavaSQLSuite extends FunSuite { + val javaCtx = new JavaSparkContext(TestSQLContext.sparkContext) + val javaSqlCtx = new JavaSQLContext(javaCtx) + + test("schema from JavaBeans") { + val person = new PersonBean + person.setName("Michael") + person.setAge(29) + + val rdd = javaCtx.parallelize(person :: Nil) + val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[PersonBean]) + + schemaRDD.registerAsTable("people") + javaSqlCtx.sql("SELECT * FROM people").collect() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala new file mode 100644 index 0000000000000..78640b876d4aa --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.types._ + +class ColumnStatsSuite extends FunSuite { + testColumnStats(classOf[BooleanColumnStats], BOOLEAN) + testColumnStats(classOf[ByteColumnStats], BYTE) + testColumnStats(classOf[ShortColumnStats], SHORT) + testColumnStats(classOf[IntColumnStats], INT) + testColumnStats(classOf[LongColumnStats], LONG) + testColumnStats(classOf[FloatColumnStats], FLOAT) + testColumnStats(classOf[DoubleColumnStats], DOUBLE) + testColumnStats(classOf[StringColumnStats], STRING) + + def testColumnStats[T <: NativeType, U <: NativeColumnStats[T]]( + columnStatsClass: Class[U], + columnType: NativeColumnType[T]) { + + val columnStatsName = columnStatsClass.getSimpleName + + test(s"$columnStatsName: empty") { + val columnStats = columnStatsClass.newInstance() + expectResult(columnStats.initialBounds, "Wrong initial bounds") { + (columnStats.lowerBound, columnStats.upperBound) + } + } + + test(s"$columnStatsName: non-empty") { + import ColumnarTestUtils._ + + val columnStats = columnStatsClass.newInstance() + val rows = Seq.fill(10)(makeRandomRow(columnType)) + rows.foreach(columnStats.gatherStats(_, 0)) + + val values = rows.map(_.head.asInstanceOf[T#JvmType]) + val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]] + + expectResult(values.min(ordering), "Wrong lower bound")(columnStats.lowerBound) + expectResult(values.max(ordering), "Wrong upper bound")(columnStats.upperBound) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 2d431affbcfcc..1d3608ed2d9ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -19,46 +19,56 @@ package org.apache.spark.sql.columnar import java.nio.ByteBuffer -import scala.util.Random - import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer class ColumnTypeSuite extends FunSuite { - val columnTypes = Seq(INT, SHORT, LONG, BYTE, DOUBLE, FLOAT, STRING, BINARY, GENERIC) + val DEFAULT_BUFFER_SIZE = 512 test("defaultSize") { - val defaultSize = Seq(4, 2, 8, 1, 8, 4, 8, 16, 16) + val checks = Map( + INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, + BOOLEAN -> 1, STRING -> 8, BINARY -> 16, GENERIC -> 16) - columnTypes.zip(defaultSize).foreach { case (columnType, size) => - assert(columnType.defaultSize === size) + checks.foreach { case (columnType, expectedSize) => + expectResult(expectedSize, s"Wrong defaultSize for $columnType") { + columnType.defaultSize + } } } test("actualSize") { - val expectedSizes = Seq(4, 2, 8, 1, 8, 4, 4 + 5, 4 + 4, 4 + 11) - val actualSizes = Seq( - INT.actualSize(Int.MaxValue), - SHORT.actualSize(Short.MaxValue), - LONG.actualSize(Long.MaxValue), - BYTE.actualSize(Byte.MaxValue), - DOUBLE.actualSize(Double.MaxValue), - FLOAT.actualSize(Float.MaxValue), - STRING.actualSize("hello"), - BINARY.actualSize(new Array[Byte](4)), - GENERIC.actualSize(SparkSqlSerializer.serialize(Map(1 -> "a")))) - - expectedSizes.zip(actualSizes).foreach { case (expected, actual) => - assert(expected === actual) + def checkActualSize[T <: DataType, JvmType]( + columnType: ColumnType[T, JvmType], + value: JvmType, + expected: Int) { + + expectResult(expected, s"Wrong actualSize for $columnType") { + columnType.actualSize(value) + } } + + checkActualSize(INT, Int.MaxValue, 4) + checkActualSize(SHORT, Short.MaxValue, 2) + checkActualSize(LONG, Long.MaxValue, 8) + checkActualSize(BYTE, Byte.MaxValue, 1) + checkActualSize(DOUBLE, Double.MaxValue, 8) + checkActualSize(FLOAT, Float.MaxValue, 4) + checkActualSize(BOOLEAN, true, 1) + checkActualSize(STRING, "hello", 4 + 5) + + val binary = Array.fill[Byte](4)(0: Byte) + checkActualSize(BINARY, binary, 4 + 4) + + val generic = Map(1 -> "a") + checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 11) } - testNumericColumnType[BooleanType.type, Boolean]( + testNativeColumnType[BooleanType.type]( BOOLEAN, - Array.fill(4)(Random.nextBoolean()), - ByteBuffer.allocate(32), (buffer: ByteBuffer, v: Boolean) => { buffer.put((if (v) 1 else 0).toByte) }, @@ -66,105 +76,42 @@ class ColumnTypeSuite extends FunSuite { buffer.get() == 1 }) - testNumericColumnType[IntegerType.type, Int]( - INT, - Array.fill(4)(Random.nextInt()), - ByteBuffer.allocate(32), - (_: ByteBuffer).putInt(_), - (_: ByteBuffer).getInt) - - testNumericColumnType[ShortType.type, Short]( - SHORT, - Array.fill(4)(Random.nextInt(Short.MaxValue).asInstanceOf[Short]), - ByteBuffer.allocate(32), - (_: ByteBuffer).putShort(_), - (_: ByteBuffer).getShort) - - testNumericColumnType[LongType.type, Long]( - LONG, - Array.fill(4)(Random.nextLong()), - ByteBuffer.allocate(64), - (_: ByteBuffer).putLong(_), - (_: ByteBuffer).getLong) - - testNumericColumnType[ByteType.type, Byte]( - BYTE, - Array.fill(4)(Random.nextInt(Byte.MaxValue).asInstanceOf[Byte]), - ByteBuffer.allocate(64), - (_: ByteBuffer).put(_), - (_: ByteBuffer).get) - - testNumericColumnType[DoubleType.type, Double]( - DOUBLE, - Array.fill(4)(Random.nextDouble()), - ByteBuffer.allocate(64), - (_: ByteBuffer).putDouble(_), - (_: ByteBuffer).getDouble) - - testNumericColumnType[FloatType.type, Float]( - FLOAT, - Array.fill(4)(Random.nextFloat()), - ByteBuffer.allocate(64), - (_: ByteBuffer).putFloat(_), - (_: ByteBuffer).getFloat) - - test("STRING") { - val buffer = ByteBuffer.allocate(128) - val seq = Array("hello", "world", "spark", "sql") - - seq.map(_.getBytes).foreach { bytes: Array[Byte] => - buffer.putInt(bytes.length).put(bytes) - } + testNativeColumnType[IntegerType.type](INT, _.putInt(_), _.getInt) - buffer.rewind() - seq.foreach { s => - assert(s === STRING.extract(buffer)) - } + testNativeColumnType[ShortType.type](SHORT, _.putShort(_), _.getShort) - buffer.rewind() - seq.foreach(STRING.append(_, buffer)) + testNativeColumnType[LongType.type](LONG, _.putLong(_), _.getLong) - buffer.rewind() - seq.foreach { s => - val length = buffer.getInt - assert(length === s.getBytes.length) + testNativeColumnType[ByteType.type](BYTE, _.put(_), _.get) + + testNativeColumnType[DoubleType.type](DOUBLE, _.putDouble(_), _.getDouble) + + testNativeColumnType[FloatType.type](FLOAT, _.putFloat(_), _.getFloat) + testNativeColumnType[StringType.type]( + STRING, + (buffer: ByteBuffer, string: String) => { + val bytes = string.getBytes() + buffer.putInt(bytes.length).put(string.getBytes) + }, + (buffer: ByteBuffer) => { + val length = buffer.getInt() val bytes = new Array[Byte](length) buffer.get(bytes, 0, length) - assert(s === new String(bytes)) - } - } - - test("BINARY") { - val buffer = ByteBuffer.allocate(128) - val seq = Array.fill(4) { - val bytes = new Array[Byte](4) - Random.nextBytes(bytes) - bytes - } + new String(bytes) + }) - seq.foreach { bytes => + testColumnType[BinaryType.type, Array[Byte]]( + BINARY, + (buffer: ByteBuffer, bytes: Array[Byte]) => { buffer.putInt(bytes.length).put(bytes) - } - - buffer.rewind() - seq.foreach { b => - assert(b === BINARY.extract(buffer)) - } - - buffer.rewind() - seq.foreach(BINARY.append(_, buffer)) - - buffer.rewind() - seq.foreach { b => - val length = buffer.getInt - assert(length === b.length) - + }, + (buffer: ByteBuffer) => { + val length = buffer.getInt() val bytes = new Array[Byte](length) buffer.get(bytes, 0, length) - assert(b === bytes) - } - } + bytes + }) test("GENERIC") { val buffer = ByteBuffer.allocate(512) @@ -177,43 +124,58 @@ class ColumnTypeSuite extends FunSuite { val length = buffer.getInt() assert(length === serializedObj.length) - val bytes = new Array[Byte](length) - buffer.get(bytes, 0, length) - assert(obj === SparkSqlSerializer.deserialize(bytes)) + expectResult(obj, "Deserialized object didn't equal to the original object") { + val bytes = new Array[Byte](length) + buffer.get(bytes, 0, length) + SparkSqlSerializer.deserialize(bytes) + } buffer.rewind() buffer.putInt(serializedObj.length).put(serializedObj) - buffer.rewind() - assert(obj === SparkSqlSerializer.deserialize(GENERIC.extract(buffer))) + expectResult(obj, "Deserialized object didn't equal to the original object") { + buffer.rewind() + SparkSqlSerializer.deserialize(GENERIC.extract(buffer)) + } + } + + def testNativeColumnType[T <: NativeType]( + columnType: NativeColumnType[T], + putter: (ByteBuffer, T#JvmType) => Unit, + getter: (ByteBuffer) => T#JvmType) { + + testColumnType[T, T#JvmType](columnType, putter, getter) } - def testNumericColumnType[T <: DataType, JvmType]( + def testColumnType[T <: DataType, JvmType]( columnType: ColumnType[T, JvmType], - seq: Seq[JvmType], - buffer: ByteBuffer, putter: (ByteBuffer, JvmType) => Unit, getter: (ByteBuffer) => JvmType) { - val columnTypeName = columnType.getClass.getSimpleName.stripSuffix("$") + val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE) + val seq = (0 until 4).map(_ => makeRandomValue(columnType)) - test(s"$columnTypeName.extract") { + test(s"$columnType.extract") { buffer.rewind() seq.foreach(putter(buffer, _)) buffer.rewind() - seq.foreach { i => - assert(i === columnType.extract(buffer)) + seq.foreach { expected => + assert( + expected === columnType.extract(buffer), + "Extracted value didn't equal to the original one") } } - test(s"$columnTypeName.append") { + test(s"$columnType.append") { buffer.rewind() seq.foreach(columnType.append(_, buffer)) buffer.rewind() - seq.foreach { i => - assert(i === getter(buffer)) + seq.foreach { expected => + assert( + expected === getter(buffer), + "Extracted value didn't equal to the original one") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarQuerySuite.scala index 928851a385d41..2ed4cf2170f9d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarQuerySuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.columnar +import org.apache.spark.sql.{QueryTest, TestData} import org.apache.spark.sql.execution.SparkLogicalPlan import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{TestData, DslQuerySuite} -class ColumnarQuerySuite extends DslQuerySuite { +class ColumnarQuerySuite extends QueryTest { import TestData._ import TestSQLContext._ @@ -31,4 +31,12 @@ class ColumnarQuerySuite extends DslQuerySuite { checkAnswer(scan, testData.collect().toSeq) } + + test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { + val plan = TestSQLContext.executePlan(testData.logicalPlan).executedPlan + val scan = SparkLogicalPlan(InMemoryColumnarTableScan(plan.output, plan)) + + checkAnswer(scan, testData.collect().toSeq) + checkAnswer(scan, testData.collect().toSeq) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestData.scala deleted file mode 100644 index ddcdede8d1a4a..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestData.scala +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.columnar - -import scala.util.Random - -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow - -// TODO Enrich test data -object ColumnarTestData { - object GenericMutableRow { - def apply(values: Any*) = { - val row = new GenericMutableRow(values.length) - row.indices.foreach { i => - row(i) = values(i) - } - row - } - } - - def randomBytes(length: Int) = { - val bytes = new Array[Byte](length) - Random.nextBytes(bytes) - bytes - } - - val nonNullRandomRow = GenericMutableRow( - Random.nextInt(), - Random.nextLong(), - Random.nextFloat(), - Random.nextDouble(), - Random.nextBoolean(), - Random.nextInt(Byte.MaxValue).asInstanceOf[Byte], - Random.nextInt(Short.MaxValue).asInstanceOf[Short], - Random.nextString(Random.nextInt(64)), - randomBytes(Random.nextInt(64)), - Map(Random.nextInt() -> Random.nextString(4))) - - val nullRow = GenericMutableRow(Seq.fill(10)(null): _*) -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala new file mode 100644 index 0000000000000..04bdc43d95328 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar + +import scala.collection.immutable.HashSet +import scala.util.Random + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.types.{DataType, NativeType} + +object ColumnarTestUtils { + def makeNullRow(length: Int) = { + val row = new GenericMutableRow(length) + (0 until length).foreach(row.setNullAt) + row + } + + def makeRandomValue[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]): JvmType = { + def randomBytes(length: Int) = { + val bytes = new Array[Byte](length) + Random.nextBytes(bytes) + bytes + } + + (columnType match { + case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte + case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort + case INT => Random.nextInt() + case LONG => Random.nextLong() + case FLOAT => Random.nextFloat() + case DOUBLE => Random.nextDouble() + case STRING => Random.nextString(Random.nextInt(32)) + case BOOLEAN => Random.nextBoolean() + case BINARY => randomBytes(Random.nextInt(32)) + case _ => + // Using a random one-element map instead of an arbitrary object + Map(Random.nextInt() -> Random.nextString(Random.nextInt(32))) + }).asInstanceOf[JvmType] + } + + def makeRandomValues( + head: ColumnType[_ <: DataType, _], + tail: ColumnType[_ <: DataType, _]*): Seq[Any] = makeRandomValues(Seq(head) ++ tail) + + def makeRandomValues(columnTypes: Seq[ColumnType[_ <: DataType, _]]): Seq[Any] = { + columnTypes.map(makeRandomValue(_)) + } + + def makeUniqueRandomValues[T <: DataType, JvmType]( + columnType: ColumnType[T, JvmType], + count: Int): Seq[JvmType] = { + + Iterator.iterate(HashSet.empty[JvmType]) { set => + set + Iterator.continually(makeRandomValue(columnType)).filterNot(set.contains).next() + }.drop(count).next().toSeq + } + + def makeRandomRow( + head: ColumnType[_ <: DataType, _], + tail: ColumnType[_ <: DataType, _]*): Row = makeRandomRow(Seq(head) ++ tail) + + def makeRandomRow(columnTypes: Seq[ColumnType[_ <: DataType, _]]): Row = { + val row = new GenericMutableRow(columnTypes.length) + makeRandomValues(columnTypes).zipWithIndex.foreach { case (value, index) => + row(index) = value + } + row + } + + def makeUniqueValuesAndSingleValueRows[T <: NativeType]( + columnType: NativeColumnType[T], + count: Int) = { + + val values = makeUniqueRandomValues(columnType, count) + val rows = values.map { value => + val row = new GenericMutableRow(1) + row(0) = value + row + } + + (values, rows) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index d413d483f4e7e..4a21eb6201a69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -17,12 +17,29 @@ package org.apache.spark.sql.columnar +import java.nio.ByteBuffer + import org.scalatest.FunSuite -import org.apache.spark.sql.catalyst.types.DataType + import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.types.DataType + +class TestNullableColumnAccessor[T <: DataType, JvmType]( + buffer: ByteBuffer, + columnType: ColumnType[T, JvmType]) + extends BasicColumnAccessor(buffer, columnType) + with NullableColumnAccessor + +object TestNullableColumnAccessor { + def apply[T <: DataType, JvmType](buffer: ByteBuffer, columnType: ColumnType[T, JvmType]) = { + // Skips the column type ID + buffer.getInt() + new TestNullableColumnAccessor(buffer, columnType) + } +} class NullableColumnAccessorSuite extends FunSuite { - import ColumnarTestData._ + import ColumnarTestUtils._ Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC).foreach { testNullableColumnAccessor(_) @@ -30,30 +47,32 @@ class NullableColumnAccessorSuite extends FunSuite { def testNullableColumnAccessor[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") + val nullRow = makeNullRow(1) - test(s"$typeName accessor: empty column") { - val builder = ColumnBuilder(columnType.typeId, 4) - val accessor = ColumnAccessor(builder.build()) + test(s"Nullable $typeName column accessor: empty column") { + val builder = TestNullableColumnBuilder(columnType) + val accessor = TestNullableColumnAccessor(builder.build(), columnType) assert(!accessor.hasNext) } - test(s"$typeName accessor: access null values") { - val builder = ColumnBuilder(columnType.typeId, 4) + test(s"Nullable $typeName column accessor: access null values") { + val builder = TestNullableColumnBuilder(columnType) + val randomRow = makeRandomRow(columnType) (0 until 4).foreach { _ => - builder.appendFrom(nonNullRandomRow, columnType.typeId) - builder.appendFrom(nullRow, columnType.typeId) + builder.appendFrom(randomRow, 0) + builder.appendFrom(nullRow, 0) } - val accessor = ColumnAccessor(builder.build()) + val accessor = TestNullableColumnAccessor(builder.build(), columnType) val row = new GenericMutableRow(1) (0 until 4).foreach { _ => accessor.extractTo(row, 0) - assert(row(0) === nonNullRandomRow(columnType.typeId)) + assert(row(0) === randomRow(0)) accessor.extractTo(row, 0) - assert(row(0) === null) + assert(row.isNullAt(0)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index 5222a47e1ab87..d9d1e1bfddb75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -19,63 +19,71 @@ package org.apache.spark.sql.columnar import org.scalatest.FunSuite -import org.apache.spark.sql.catalyst.types.DataType +import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.execution.SparkSqlSerializer +class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) + extends BasicColumnBuilder[T, JvmType](new NoopColumnStats[T, JvmType], columnType) + with NullableColumnBuilder + +object TestNullableColumnBuilder { + def apply[T <: DataType, JvmType](columnType: ColumnType[T, JvmType], initialSize: Int = 0) = { + val builder = new TestNullableColumnBuilder(columnType) + builder.initialize(initialSize) + builder + } +} + class NullableColumnBuilderSuite extends FunSuite { - import ColumnarTestData._ + import ColumnarTestUtils._ Seq(INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, BINARY, GENERIC).foreach { testNullableColumnBuilder(_) } def testNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) { - val columnBuilder = ColumnBuilder(columnType.typeId) val typeName = columnType.getClass.getSimpleName.stripSuffix("$") test(s"$typeName column builder: empty column") { - columnBuilder.initialize(4) - + val columnBuilder = TestNullableColumnBuilder(columnType) val buffer = columnBuilder.build() - // For column type ID - assert(buffer.getInt() === columnType.typeId) - // For null count - assert(buffer.getInt === 0) + expectResult(columnType.typeId, "Wrong column type ID")(buffer.getInt()) + expectResult(0, "Wrong null count")(buffer.getInt()) assert(!buffer.hasRemaining) } test(s"$typeName column builder: buffer size auto growth") { - columnBuilder.initialize(4) + val columnBuilder = TestNullableColumnBuilder(columnType) + val randomRow = makeRandomRow(columnType) - (0 until 4) foreach { _ => - columnBuilder.appendFrom(nonNullRandomRow, columnType.typeId) + (0 until 4).foreach { _ => + columnBuilder.appendFrom(randomRow, 0) } val buffer = columnBuilder.build() - // For column type ID - assert(buffer.getInt() === columnType.typeId) - // For null count - assert(buffer.getInt() === 0) + expectResult(columnType.typeId, "Wrong column type ID")(buffer.getInt()) + expectResult(0, "Wrong null count")(buffer.getInt()) } test(s"$typeName column builder: null values") { - columnBuilder.initialize(4) + val columnBuilder = TestNullableColumnBuilder(columnType) + val randomRow = makeRandomRow(columnType) + val nullRow = makeNullRow(1) - (0 until 4) foreach { _ => - columnBuilder.appendFrom(nonNullRandomRow, columnType.typeId) - columnBuilder.appendFrom(nullRow, columnType.typeId) + (0 until 4).foreach { _ => + columnBuilder.appendFrom(randomRow, 0) + columnBuilder.appendFrom(nullRow, 0) } val buffer = columnBuilder.build() - // For column type ID - assert(buffer.getInt() === columnType.typeId) - // For null count - assert(buffer.getInt() === 4) + expectResult(columnType.typeId, "Wrong column type ID")(buffer.getInt()) + expectResult(4, "Wrong null count")(buffer.getInt()) + // For null positions - (1 to 7 by 2).foreach(i => assert(buffer.getInt() === i)) + (1 to 7 by 2).foreach(expectResult(_, "Wrong null position")(buffer.getInt())) // For non-null values (0 until 4).foreach { _ => @@ -84,7 +92,8 @@ class NullableColumnBuilderSuite extends FunSuite { } else { columnType.extract(buffer) } - assert(actual === nonNullRandomRow(columnType.typeId)) + + assert(actual === randomRow(0), "Extracted value didn't equal to the original one") } assert(!buffer.hasRemaining) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala new file mode 100644 index 0000000000000..a754f98f7fbf1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/BooleanBitSetSuite.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar.compression + +import org.scalatest.FunSuite + +import org.apache.spark.sql.Row +import org.apache.spark.sql.columnar.{BOOLEAN, BooleanColumnStats} +import org.apache.spark.sql.columnar.ColumnarTestUtils._ + +class BooleanBitSetSuite extends FunSuite { + import BooleanBitSet._ + + def skeleton(count: Int) { + // ------------- + // Tests encoder + // ------------- + + val builder = TestCompressibleColumnBuilder(new BooleanColumnStats, BOOLEAN, BooleanBitSet) + val rows = Seq.fill[Row](count)(makeRandomRow(BOOLEAN)) + val values = rows.map(_.head) + + rows.foreach(builder.appendFrom(_, 0)) + val buffer = builder.build() + + // Column type ID + null count + null positions + val headerSize = CompressionScheme.columnHeaderSize(buffer) + + // Compression scheme ID + element count + bitset words + val compressedSize = 4 + 4 + { + val extra = if (count % BITS_PER_LONG == 0) 0 else 1 + (count / BITS_PER_LONG + extra) * 8 + } + + // 4 extra bytes for compression scheme type ID + expectResult(headerSize + compressedSize, "Wrong buffer capacity")(buffer.capacity) + + // Skips column header + buffer.position(headerSize) + expectResult(BooleanBitSet.typeId, "Wrong compression scheme ID")(buffer.getInt()) + expectResult(count, "Wrong element count")(buffer.getInt()) + + var word = 0: Long + for (i <- 0 until count) { + val bit = i % BITS_PER_LONG + word = if (bit == 0) buffer.getLong() else word + expectResult(values(i), s"Wrong value in compressed buffer, index=$i") { + (word & ((1: Long) << bit)) != 0 + } + } + + // ------------- + // Tests decoder + // ------------- + + // Rewinds, skips column header and 4 more bytes for compression scheme ID + buffer.rewind().position(headerSize + 4) + + val decoder = BooleanBitSet.decoder(buffer, BOOLEAN) + values.foreach(expectResult(_, "Wrong decoded value")(decoder.next())) + assert(!decoder.hasNext) + } + + test(s"$BooleanBitSet: empty") { + skeleton(0) + } + + test(s"$BooleanBitSet: less than 1 word") { + skeleton(BITS_PER_LONG - 1) + } + + test(s"$BooleanBitSet: exactly 1 word") { + skeleton(BITS_PER_LONG) + } + + test(s"$BooleanBitSet: multiple whole words") { + skeleton(BITS_PER_LONG * 2) + } + + test(s"$BooleanBitSet: multiple words and 1 more bit") { + skeleton(BITS_PER_LONG * 2 + 1) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala new file mode 100644 index 0000000000000..eab27987e08ea --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar.compression + +import java.nio.ByteBuffer + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.types.NativeType +import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.columnar.ColumnarTestUtils._ + +class DictionaryEncodingSuite extends FunSuite { + testDictionaryEncoding(new IntColumnStats, INT) + testDictionaryEncoding(new LongColumnStats, LONG) + testDictionaryEncoding(new StringColumnStats, STRING) + + def testDictionaryEncoding[T <: NativeType]( + columnStats: NativeColumnStats[T], + columnType: NativeColumnType[T]) { + + val typeName = columnType.getClass.getSimpleName.stripSuffix("$") + + def buildDictionary(buffer: ByteBuffer) = { + (0 until buffer.getInt()).map(columnType.extract(buffer) -> _.toShort).toMap + } + + def stableDistinct(seq: Seq[Int]): Seq[Int] = if (seq.isEmpty) { + Seq.empty + } else { + seq.head +: seq.tail.filterNot(_ == seq.head) + } + + def skeleton(uniqueValueCount: Int, inputSeq: Seq[Int]) { + // ------------- + // Tests encoder + // ------------- + + val builder = TestCompressibleColumnBuilder(columnStats, columnType, DictionaryEncoding) + val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, uniqueValueCount) + val dictValues = stableDistinct(inputSeq) + + inputSeq.foreach(i => builder.appendFrom(rows(i), 0)) + + if (dictValues.length > DictionaryEncoding.MAX_DICT_SIZE) { + withClue("Dictionary overflowed, compression should fail") { + intercept[Throwable] { + builder.build() + } + } + } else { + val buffer = builder.build() + val headerSize = CompressionScheme.columnHeaderSize(buffer) + // 4 extra bytes for dictionary size + val dictionarySize = 4 + values.map(columnType.actualSize).sum + // 2 bytes for each `Short` + val compressedSize = 4 + dictionarySize + 2 * inputSeq.length + // 4 extra bytes for compression scheme type ID + expectResult(headerSize + compressedSize, "Wrong buffer capacity")(buffer.capacity) + + // Skips column header + buffer.position(headerSize) + expectResult(DictionaryEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + val dictionary = buildDictionary(buffer).toMap + + dictValues.foreach { i => + expectResult(i, "Wrong dictionary entry") { + dictionary(values(i)) + } + } + + inputSeq.foreach { i => + expectResult(i.toShort, "Wrong column element value")(buffer.getShort()) + } + + // ------------- + // Tests decoder + // ------------- + + // Rewinds, skips column header and 4 more bytes for compression scheme ID + buffer.rewind().position(headerSize + 4) + + val decoder = DictionaryEncoding.decoder(buffer, columnType) + + inputSeq.foreach { i => + expectResult(values(i), "Wrong decoded value")(decoder.next()) + } + + assert(!decoder.hasNext) + } + } + + test(s"$DictionaryEncoding with $typeName: empty") { + skeleton(0, Seq.empty) + } + + test(s"$DictionaryEncoding with $typeName: simple case") { + skeleton(2, Seq(0, 1, 0, 1)) + } + + test(s"$DictionaryEncoding with $typeName: dictionary overflow") { + skeleton(DictionaryEncoding.MAX_DICT_SIZE + 1, 0 to DictionaryEncoding.MAX_DICT_SIZE) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala new file mode 100644 index 0000000000000..1390e5eef6106 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/IntegralDeltaSuite.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar.compression + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow +import org.apache.spark.sql.catalyst.types.IntegralType +import org.apache.spark.sql.columnar._ + +class IntegralDeltaSuite extends FunSuite { + testIntegralDelta(new IntColumnStats, INT, IntDelta) + testIntegralDelta(new LongColumnStats, LONG, LongDelta) + + def testIntegralDelta[I <: IntegralType]( + columnStats: NativeColumnStats[I], + columnType: NativeColumnType[I], + scheme: IntegralDelta[I]) { + + def skeleton(input: Seq[I#JvmType]) { + // ------------- + // Tests encoder + // ------------- + + val builder = TestCompressibleColumnBuilder(columnStats, columnType, scheme) + val deltas = if (input.isEmpty) { + Seq.empty[Long] + } else { + (input.tail, input.init).zipped.map { + case (x: Int, y: Int) => (x - y).toLong + case (x: Long, y: Long) => x - y + } + } + + input.map { value => + val row = new GenericMutableRow(1) + columnType.setField(row, 0, value) + builder.appendFrom(row, 0) + } + + val buffer = builder.build() + // Column type ID + null count + null positions + val headerSize = CompressionScheme.columnHeaderSize(buffer) + + // Compression scheme ID + compressed contents + val compressedSize = 4 + (if (deltas.isEmpty) { + 0 + } else { + val oneBoolean = columnType.defaultSize + 1 + oneBoolean + deltas.map { + d => if (math.abs(d) < Byte.MaxValue) 1 else 1 + oneBoolean + }.sum + }) + + // 4 extra bytes for compression scheme type ID + expectResult(headerSize + compressedSize, "Wrong buffer capacity")(buffer.capacity) + + buffer.position(headerSize) + expectResult(scheme.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + if (input.nonEmpty) { + expectResult(Byte.MinValue, "The first byte should be an escaping mark")(buffer.get()) + expectResult(input.head, "The first value is wrong")(columnType.extract(buffer)) + + (input.tail, deltas).zipped.foreach { (value, delta) => + if (delta < Byte.MaxValue) { + expectResult(delta, "Wrong delta")(buffer.get()) + } else { + expectResult(Byte.MinValue, "Expecting escaping mark here")(buffer.get()) + expectResult(value, "Wrong value")(columnType.extract(buffer)) + } + } + } + + // ------------- + // Tests decoder + // ------------- + + // Rewinds, skips column header and 4 more bytes for compression scheme ID + buffer.rewind().position(headerSize + 4) + + val decoder = scheme.decoder(buffer, columnType) + input.foreach(expectResult(_, "Wrong decoded value")(decoder.next())) + assert(!decoder.hasNext) + } + + test(s"$scheme: empty column") { + skeleton(Seq.empty) + } + + test(s"$scheme: simple case") { + val input = columnType match { + case INT => Seq(1: Int, 2: Int, 130: Int) + case LONG => Seq(1: Long, 2: Long, 130: Long) + } + + skeleton(input.map(_.asInstanceOf[I#JvmType])) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala new file mode 100644 index 0000000000000..89f9b60a4397b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar.compression + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.types.NativeType +import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.columnar.ColumnarTestUtils._ + +class RunLengthEncodingSuite extends FunSuite { + testRunLengthEncoding(new BooleanColumnStats, BOOLEAN) + testRunLengthEncoding(new ByteColumnStats, BYTE) + testRunLengthEncoding(new ShortColumnStats, SHORT) + testRunLengthEncoding(new IntColumnStats, INT) + testRunLengthEncoding(new LongColumnStats, LONG) + testRunLengthEncoding(new StringColumnStats, STRING) + + def testRunLengthEncoding[T <: NativeType]( + columnStats: NativeColumnStats[T], + columnType: NativeColumnType[T]) { + + val typeName = columnType.getClass.getSimpleName.stripSuffix("$") + + def skeleton(uniqueValueCount: Int, inputRuns: Seq[(Int, Int)]) { + // ------------- + // Tests encoder + // ------------- + + val builder = TestCompressibleColumnBuilder(columnStats, columnType, RunLengthEncoding) + val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, uniqueValueCount) + val inputSeq = inputRuns.flatMap { case (index, run) => + Seq.fill(run)(index) + } + + inputSeq.foreach(i => builder.appendFrom(rows(i), 0)) + val buffer = builder.build() + + // Column type ID + null count + null positions + val headerSize = CompressionScheme.columnHeaderSize(buffer) + + // Compression scheme ID + compressed contents + val compressedSize = 4 + inputRuns.map { case (index, _) => + // 4 extra bytes each run for run length + columnType.actualSize(values(index)) + 4 + }.sum + + // 4 extra bytes for compression scheme type ID + expectResult(headerSize + compressedSize, "Wrong buffer capacity")(buffer.capacity) + + // Skips column header + buffer.position(headerSize) + expectResult(RunLengthEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + inputRuns.foreach { case (index, run) => + expectResult(values(index), "Wrong column element value")(columnType.extract(buffer)) + expectResult(run, "Wrong run length")(buffer.getInt()) + } + + // ------------- + // Tests decoder + // ------------- + + // Rewinds, skips column header and 4 more bytes for compression scheme ID + buffer.rewind().position(headerSize + 4) + + val decoder = RunLengthEncoding.decoder(buffer, columnType) + + inputSeq.foreach { i => + expectResult(values(i), "Wrong decoded value")(decoder.next()) + } + + assert(!decoder.hasNext) + } + + test(s"$RunLengthEncoding with $typeName: empty column") { + skeleton(0, Seq.empty) + } + + test(s"$RunLengthEncoding with $typeName: simple case") { + skeleton(2, Seq(0 -> 2, 1 ->2)) + } + + test(s"$RunLengthEncoding with $typeName: run length == 1") { + skeleton(2, Seq(0 -> 1, 1 ->1)) + } + + test(s"$RunLengthEncoding with $typeName: single long run") { + skeleton(1, Seq(0 -> 1000)) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala new file mode 100644 index 0000000000000..81bf5e99d19b9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/TestCompressibleColumnBuilder.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar.compression + +import org.apache.spark.sql.catalyst.types.NativeType +import org.apache.spark.sql.columnar._ + +class TestCompressibleColumnBuilder[T <: NativeType]( + override val columnStats: NativeColumnStats[T], + override val columnType: NativeColumnType[T], + override val schemes: Seq[CompressionScheme]) + extends NativeColumnBuilder(columnStats, columnType) + with NullableColumnBuilder + with CompressibleColumnBuilder[T] { + + override protected def isWorthCompressing(encoder: Encoder[T]) = true +} + +object TestCompressibleColumnBuilder { + def apply[T <: NativeType]( + columnStats: NativeColumnStats[T], + columnType: NativeColumnType[T], + scheme: CompressionScheme) = { + + val builder = new TestCompressibleColumnBuilder(columnStats, columnType, Seq(scheme)) + builder.initialize(0) + builder + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala index ca5c8b8eb63dc..e55648b8ed15a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala @@ -39,9 +39,9 @@ case class ExampleTGF(input: Seq[Attribute] = Seq('name, 'age)) extends Generato val Seq(nameAttr, ageAttr) = input - override def apply(input: Row): TraversableOnce[Row] = { - val name = nameAttr.apply(input) - val age = ageAttr.apply(input).asInstanceOf[Int] + override def eval(input: Row): TraversableOnce[Row] = { + val name = nameAttr.eval(input) + val age = ageAttr.eval(input).asInstanceOf[Int] Iterator( new GenericRow(Array[Any](s"$name is $age years old")), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index ea1733b3614e5..fc68d6c5620d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -19,32 +19,45 @@ package org.apache.spark.sql.parquet import org.scalatest.{BeforeAndAfterAll, FunSuite} -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.{Path, FileSystem} import org.apache.hadoop.mapreduce.Job + import parquet.hadoop.ParquetFileWriter -import parquet.hadoop.util.ContextUtil import parquet.schema.MessageTypeParser +import parquet.hadoop.util.ContextUtil -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row} import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.util.Utils +import org.apache.spark.sql.catalyst.types.{StringType, IntegerType, DataType} +import org.apache.spark.sql.{parquet, SchemaRDD} +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import scala.Tuple2 // Implicits import org.apache.spark.sql.test.TestSQLContext._ +case class TestRDDEntry(key: Int, value: String) + class ParquetQuerySuite extends FunSuite with BeforeAndAfterAll { + + var testRDD: SchemaRDD = null + override def beforeAll() { ParquetTestData.writeFile() + testRDD = parquetFile(ParquetTestData.testDir.toString) + testRDD.registerAsTable("testsource") } override def afterAll() { - ParquetTestData.testFile.delete() + Utils.deleteRecursively(ParquetTestData.testDir) + // here we should also unregister the table?? } test("self-join parquet files") { - val x = ParquetTestData.testData.subquery('x) - val y = ParquetTestData.testData.subquery('y) + val x = ParquetTestData.testData.as('x) + val y = ParquetTestData.testData.as('y) val query = x.join(y).where("x.myint".attr === "y.myint".attr) // Check to make sure that the attributes from either side of the join have unique expression @@ -55,11 +68,18 @@ class ParquetQuerySuite extends FunSuite with BeforeAndAfterAll { case Seq(_, _) => // All good } - // TODO: We can't run this query as it NPEs + val result = query.collect() + assert(result.size === 9, "self-join result has incorrect size") + assert(result(0).size === 12, "result row has incorrect size") + result.zipWithIndex.foreach { + case (row, index) => row.zipWithIndex.foreach { + case (field, column) => assert(field != null, s"self-join contains null value in row $index field $column") + } + } } test("Import of simple Parquet file") { - val result = getRDD(ParquetTestData.testData).collect() + val result = parquetFile(ParquetTestData.testDir.toString).collect() assert(result.size === 15) result.zipWithIndex.foreach { case (row, index) => { @@ -125,20 +145,82 @@ class ParquetQuerySuite extends FunSuite with BeforeAndAfterAll { fs.delete(path, true) } + test("Creating case class RDD table") { + TestSQLContext.sparkContext.parallelize((1 to 100)) + .map(i => TestRDDEntry(i, s"val_$i")) + .registerAsTable("tmp") + val rdd = sql("SELECT * FROM tmp").collect().sortBy(_.getInt(0)) + var counter = 1 + rdd.foreach { + // '===' does not like string comparison? + row: Row => { + assert(row.getString(1).equals(s"val_$counter"), s"row $counter value ${row.getString(1)} does not match val_$counter") + counter = counter + 1 + } + } + } + + test("Saving case class RDD table to file and reading it back in") { + val file = getTempFilePath("parquet") + val path = file.toString + val rdd = TestSQLContext.sparkContext.parallelize((1 to 100)) + .map(i => TestRDDEntry(i, s"val_$i")) + rdd.saveAsParquetFile(path) + val readFile = parquetFile(path) + readFile.registerAsTable("tmpx") + val rdd_copy = sql("SELECT * FROM tmpx").collect() + val rdd_orig = rdd.collect() + for(i <- 0 to 99) { + assert(rdd_copy(i).apply(0) === rdd_orig(i).key, s"key error in line $i") + assert(rdd_copy(i).apply(1) === rdd_orig(i).value, s"value in line $i") + } + Utils.deleteRecursively(file) + assert(true) + } + + test("insert (overwrite) via Scala API (new SchemaRDD)") { + val dirname = Utils.createTempDir() + val source_rdd = TestSQLContext.sparkContext.parallelize((1 to 100)) + .map(i => TestRDDEntry(i, s"val_$i")) + source_rdd.registerAsTable("source") + val dest_rdd = createParquetFile(dirname.toString, ("key", IntegerType), ("value", StringType)) + dest_rdd.registerAsTable("dest") + sql("INSERT OVERWRITE INTO dest SELECT * FROM source").collect() + val rdd_copy1 = sql("SELECT * FROM dest").collect() + assert(rdd_copy1.size === 100) + assert(rdd_copy1(0).apply(0) === 1) + assert(rdd_copy1(0).apply(1) === "val_1") + sql("INSERT INTO dest SELECT * FROM source").collect() + val rdd_copy2 = sql("SELECT * FROM dest").collect() + assert(rdd_copy2.size === 200) + Utils.deleteRecursively(dirname) + } + + test("insert (appending) to same table via Scala API") { + sql("INSERT INTO testsource SELECT * FROM testsource").collect() + val double_rdd = sql("SELECT * FROM testsource").collect() + assert(double_rdd != null) + assert(double_rdd.size === 30) + for(i <- (0 to 14)) { + assert(double_rdd(i) === double_rdd(i+15), s"error: lines $i and ${i+15} to not match") + } + // let's restore the original test data + Utils.deleteRecursively(ParquetTestData.testDir) + ParquetTestData.writeFile() + } + /** - * Computes the given [[ParquetRelation]] and returns its RDD. + * Creates an empty SchemaRDD backed by a ParquetRelation. * - * @param parquetRelation The Parquet relation. - * @return An RDD of Rows. + * TODO: since this is so experimental it is better to have it here and not + * in SQLContext. Also note that when creating new AttributeReferences + * one needs to take care not to create duplicate Attribute ID's. */ - private def getRDD(parquetRelation: ParquetRelation): RDD[Row] = { - val scanner = new ParquetTableScan( - parquetRelation.output, - parquetRelation, - None)(TestSQLContext.sparkContext) - scanner - .execute - .map(_.copy()) + private def createParquetFile(path: String, schema: (Tuple2[String, DataType])*): SchemaRDD = { + val attributes = schema.map(t => new AttributeReference(t._1, t._2)()) + new SchemaRDD( + TestSQLContext, + parquet.ParquetRelation.createEmpty(path, attributes, sparkContext.hadoopConfiguration)) } } diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 63f592cb4b441..a662da76ce25a 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -63,6 +63,10 @@ hive-exec ${hive.version} + + org.codehaus.jackson + jackson-mapper-asl + org.apache.hive hive-serde @@ -87,6 +91,30 @@ org.scalatest scalatest-maven-plugin + + + + org.apache.maven.plugins + maven-dependency-plugin + 2.4 + + + copy-dependencies + package + + copy-dependencies + + + + ${basedir}/../../lib_managed/jars + false + false + true + org.datanucleus + + + + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 197b557cba5f4..353458432b210 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -67,10 +67,24 @@ class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) { class HiveContext(sc: SparkContext) extends SQLContext(sc) { self => - override def parseSql(sql: String): LogicalPlan = HiveQl.parseSql(sql) - override def executePlan(plan: LogicalPlan): this.QueryExecution = + override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = new this.QueryExecution { val logical = plan } + /** + * Executes a query expressed in HiveQL using Spark, returning the result as a SchemaRDD. + */ + def hiveql(hqlQuery: String): SchemaRDD = { + val result = new SchemaRDD(this, HiveQl.parseSql(hqlQuery)) + // We force query optimization to happen right away instead of letting it happen lazily like + // when using the query DSL. This is so DDL commands behave as expected. This is only + // generates the RDD lineage for DML queries, but do not perform any execution. + result.queryExecution.toRdd + result + } + + /** An alias for `hiveql`. */ + def hql(hqlQuery: String): SchemaRDD = hiveql(hqlQuery) + // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. @transient protected val outputBuffer = new java.io.OutputStream { @@ -108,7 +122,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /* A catalyst metadata catalog that points to the Hive Metastore. */ @transient - override lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog { + override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog { override def lookupRelation( databaseName: Option[String], tableName: String, @@ -120,7 +134,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { /* An analyzer that uses the Hive metastore. */ @transient - override lazy val analyzer = new Analyzer(catalog, HiveFunctionRegistry, caseSensitive = false) + override protected[sql] lazy val analyzer = + new Analyzer(catalog, HiveFunctionRegistry, caseSensitive = false) /** * Runs the specified SQL query using Hive. @@ -188,7 +203,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val hiveContext = self override val strategies: Seq[Strategy] = Seq( - TopK, + TakeOrdered, ParquetOperations, HiveTableScans, DataSinks, @@ -202,14 +217,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } @transient - override val planner = hivePlanner + override protected[sql] val planner = hivePlanner @transient protected lazy val emptyResult = sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1) /** Extends QueryExecution with hive specific features. */ - abstract class QueryExecution extends super.QueryExecution { + protected[sql] abstract class QueryExecution extends super.QueryExecution { // TODO: Create mixin for the analyzer instead of overriding things here. override lazy val optimizedPlan = optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed))) @@ -282,5 +297,11 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val asString = result.map(_.zip(types).map(toHiveString)).map(_.mkString("\t")).toSeq asString } + + override def simpleString: String = + logical match { + case _: NativeCommand => "" + case _ => executedPlan.toString + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 4f8353666a12b..fc053c56c052d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -141,6 +141,15 @@ class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with Logging { */ override def registerTable( databaseName: Option[String], tableName: String, plan: LogicalPlan): Unit = ??? + + /** + * UNIMPLEMENTED: It needs to be decided how we will persist in-memory tables to the metastore. + * For now, if this functionality is desired mix in the in-memory [[OverrideCatalog]]. + */ + override def unregisterTable( + databaseName: Option[String], tableName: String): Unit = ??? + + override def unregisterAllTables() = {} } object HiveMetastoreTypes extends RegexParsers { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 490a592a588d0..4dac25b3f60e4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -300,14 +300,17 @@ object HiveQl { } protected def nodeToDataType(node: Node): DataType = node match { - case Token("TOK_BIGINT", Nil) => IntegerType + case Token("TOK_DECIMAL", Nil) => DecimalType + case Token("TOK_BIGINT", Nil) => LongType case Token("TOK_INT", Nil) => IntegerType - case Token("TOK_TINYINT", Nil) => IntegerType - case Token("TOK_SMALLINT", Nil) => IntegerType + case Token("TOK_TINYINT", Nil) => ByteType + case Token("TOK_SMALLINT", Nil) => ShortType case Token("TOK_BOOLEAN", Nil) => BooleanType case Token("TOK_STRING", Nil) => StringType case Token("TOK_FLOAT", Nil) => FloatType - case Token("TOK_DOUBLE", Nil) => FloatType + case Token("TOK_DOUBLE", Nil) => DoubleType + case Token("TOK_TIMESTAMP", Nil) => TimestampType + case Token("TOK_BINARY", Nil) => BinaryType case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType)) case Token("TOK_STRUCT", Token("TOK_TABCOLLIST", fields) :: Nil) => @@ -529,7 +532,7 @@ object HiveQl { val withLimit = limitClause.map(l => nodeToExpr(l.getChildren.head)) - .map(StopAfter(_, withSort)) + .map(Limit(_, withSort)) .getOrElse(withSort) // TOK_INSERT_INTO means to add files to the table. @@ -602,7 +605,7 @@ object HiveQl { case Token("TOK_TABLESPLITSAMPLE", Token("TOK_ROWCOUNT", Nil) :: Token(count, Nil) :: Nil) => - StopAfter(Literal(count.toInt), relation) + Limit(Literal(count.toInt), relation) case Token("TOK_TABLESPLITSAMPLE", Token("TOK_PERCENT", Nil) :: Token(fraction, Nil) :: Nil) => @@ -829,6 +832,8 @@ object HiveQl { Cast(nodeToExpr(arg), BooleanType) case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) => Cast(nodeToExpr(arg), DecimalType) + case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) => + Cast(nodeToExpr(arg), TimestampType) /* Arithmetic */ case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index bc3447b9d802d..2fea9702954d7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -110,10 +110,10 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) { val describedTable = "DESCRIBE (\\w+)".r - class SqlQueryExecution(sql: String) extends this.QueryExecution { - lazy val logical = HiveQl.parseSql(sql) - def hiveExec() = runSqlHive(sql) - override def toString = sql + "\n" + super.toString + protected[hive] class HiveQLQueryExecution(hql: String) extends this.QueryExecution { + lazy val logical = HiveQl.parseSql(hql) + def hiveExec() = runSqlHive(hql) + override def toString = hql + "\n" + super.toString } /** @@ -140,8 +140,8 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) { case class TestTable(name: String, commands: (()=>Unit)*) - implicit class SqlCmd(sql: String) { - def cmd = () => new SqlQueryExecution(sql).stringResult(): Unit + protected[hive] implicit class SqlCmd(sql: String) { + def cmd = () => new HiveQLQueryExecution(sql).stringResult(): Unit } /** @@ -313,6 +313,8 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) { catalog.client.dropDatabase(db, true, false, true) } + catalog.unregisterAllTables() + FunctionRegistry.getFunctionNames.filterNot(originalUdfs.contains(_)).foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala new file mode 100644 index 0000000000000..6df76fa825101 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/api/java/JavaHiveContext.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.api.java + +import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.sql.api.java.{JavaSQLContext, JavaSchemaRDD} +import org.apache.spark.sql.hive.{HiveContext, HiveQl} + +/** + * The entry point for executing Spark SQL queries from a Java program. + */ +class JavaHiveContext(sparkContext: JavaSparkContext) extends JavaSQLContext(sparkContext) { + + override val sqlContext = new HiveContext(sparkContext) + + /** + * Executes a query expressed in HiveQL, returning the result as a JavaSchemaRDD. + */ + def hql(hqlQuery: String): JavaSchemaRDD = { + val result = new JavaSchemaRDD(sqlContext, HiveQl.parseSql(hqlQuery)) + // We force query optimization to happen right away instead of letting it happen lazily like + // when using the query DSL. This is so DDL commands behave as expected. This is only + // generates the RDD lineage for DML queries, but do not perform any execution. + result.queryExecution.toRdd + result + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala index e2d9d8de2572a..821fb22112f87 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala @@ -106,7 +106,7 @@ case class HiveTableScan( } private def castFromString(value: String, dataType: DataType) = { - Cast(Literal(value), dataType).apply(null) + Cast(Literal(value), dataType).eval(null) } @transient @@ -134,7 +134,7 @@ case class HiveTableScan( // Only partitioned values are needed here, since the predicate has already been bound to // partition key attribute references. val row = new GenericRow(castedValues.toArray) - shouldKeep.apply(row).asInstanceOf[Boolean] + shouldKeep.eval(row).asInstanceOf[Boolean] } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 44901db3f963b..f9b437d435eba 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -190,8 +190,8 @@ case class HiveSimpleUdf(name: String, children: Seq[Expression]) extends HiveUd } // TODO: Finish input output types. - override def apply(input: Row): Any = { - val evaluatedChildren = children.map(_.apply(input)) + override def eval(input: Row): Any = { + val evaluatedChildren = children.map(_.eval(input)) // Wrap the function arguments in the expected types. val args = evaluatedChildren.zip(wrappers).map { case (arg, wrapper) => wrapper(arg) @@ -216,12 +216,12 @@ case class HiveGenericUdf( val dataType: DataType = inspectorToDataType(returnInspector) - override def apply(input: Row): Any = { + override def eval(input: Row): Any = { returnInspector // Make sure initialized. val args = children.map { v => new DeferredObject { override def prepare(i: Int) = {} - override def get(): AnyRef = wrap(v.apply(input)) + override def get(): AnyRef = wrap(v.eval(input)) } }.toArray unwrap(function.evaluate(args)) @@ -337,13 +337,16 @@ case class HiveGenericUdaf( type UDFType = AbstractGenericUDAFResolver + @transient protected lazy val resolver: AbstractGenericUDAFResolver = createFunction(name) + @transient protected lazy val objectInspector = { resolver.getEvaluator(children.map(_.dataType.toTypeInfo).toArray) .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) } + @transient protected lazy val inspectors = children.map(_.dataType).map(toInspector) def dataType: DataType = inspectorToDataType(objectInspector) @@ -403,7 +406,7 @@ case class HiveGenericUdtf( } } - override def apply(input: Row): TraversableOnce[Row] = { + override def eval(input: Row): TraversableOnce[Row] = { outputInspectors // Make sure initialized. val inputProjection = new Projection(children) @@ -457,7 +460,7 @@ case class HiveUdafFunction( private val buffer = function.getNewAggregationBuffer.asInstanceOf[GenericUDAFEvaluator.AbstractAggregationBuffer] - override def apply(input: Row): Any = unwrapData(function.evaluate(buffer), returnInspector) + override def eval(input: Row): Any = unwrapData(function.evaluate(buffer), returnInspector) @transient val inputProjection = new Projection(exprs) diff --git a/sql/hive/src/test/resources/golden/alias.*-0-7bdb861d11e895aaea545810cdac316d b/sql/hive/src/test/resources/golden/alias.*-0-7bdb861d11e895aaea545810cdac316d deleted file mode 100644 index 5f4de85940513..0000000000000 --- a/sql/hive/src/test/resources/golden/alias.*-0-7bdb861d11e895aaea545810cdac316d +++ /dev/null @@ -1 +0,0 @@ -0 val_0 \ No newline at end of file diff --git a/sql/hive/src/test/resources/golden/alias.star-0-7bdb861d11e895aaea545810cdac316d b/sql/hive/src/test/resources/golden/alias.star-0-7bdb861d11e895aaea545810cdac316d new file mode 100644 index 0000000000000..016f64cc26f2a --- /dev/null +++ b/sql/hive/src/test/resources/golden/alias.star-0-7bdb861d11e895aaea545810cdac316d @@ -0,0 +1 @@ +0 val_0 diff --git a/sql/hive/src/test/resources/golden/insert1-0-7faa9807151781e4207103aa568e321c b/sql/hive/src/test/resources/golden/insert1-0-7faa9807151781e4207103aa568e321c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1-1-91d7b05c9024bff60b55f415cbeacc8b b/sql/hive/src/test/resources/golden/insert1-1-91d7b05c9024bff60b55f415cbeacc8b new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1-10-64f83491a8fe675ef3a4a9a474ac0439 b/sql/hive/src/test/resources/golden/insert1-10-64f83491a8fe675ef3a4a9a474ac0439 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1-11-6f2797b6f81943d3b53b8d247ae8512b b/sql/hive/src/test/resources/golden/insert1-11-6f2797b6f81943d3b53b8d247ae8512b new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1-12-7a3c0a3f06484c912b9e951d8a2d8ac6 b/sql/hive/src/test/resources/golden/insert1-12-7a3c0a3f06484c912b9e951d8a2d8ac6 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1-13-42b03f938894fdafc7fff640711a9b2f b/sql/hive/src/test/resources/golden/insert1-13-42b03f938894fdafc7fff640711a9b2f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1-14-e021dfb28597811870c03b3242972927 b/sql/hive/src/test/resources/golden/insert1-14-e021dfb28597811870c03b3242972927 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1-15-c7fca497a4580b54a0a13b3b72da5d7c b/sql/hive/src/test/resources/golden/insert1-15-c7fca497a4580b54a0a13b3b72da5d7c new file mode 100644 index 0000000000000..5be49cad9a8ba --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert1-15-c7fca497a4580b54a0a13b3b72da5d7c @@ -0,0 +1,2 @@ +db2_insert1 +db2_insert2 diff --git a/sql/hive/src/test/resources/golden/insert1-16-7a9e67189d3d4151f23b12c22bde06b5 b/sql/hive/src/test/resources/golden/insert1-16-7a9e67189d3d4151f23b12c22bde06b5 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1-17-5528e36b3b0f5b14313898cc45f9c23a b/sql/hive/src/test/resources/golden/insert1-17-5528e36b3b0f5b14313898cc45f9c23a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1-18-16d78fba2d86277bc2f804037cc0a8b4 b/sql/hive/src/test/resources/golden/insert1-18-16d78fba2d86277bc2f804037cc0a8b4 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1-19-62518ff6810db9cdd8926702192a206b b/sql/hive/src/test/resources/golden/insert1-19-62518ff6810db9cdd8926702192a206b new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1-2-3f1de4475930285c3fdbe3a5ccd4e868 b/sql/hive/src/test/resources/golden/insert1-2-3f1de4475930285c3fdbe3a5ccd4e868 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1-20-f4dc51ad64bb8662d066a8b9003da3d4 b/sql/hive/src/test/resources/golden/insert1-20-f4dc51ad64bb8662d066a8b9003da3d4 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1-21-bb7624250ab556f2d40bfb8d419be487 b/sql/hive/src/test/resources/golden/insert1-21-bb7624250ab556f2d40bfb8d419be487 new file mode 100644 index 0000000000000..1e3637ebc6af2 --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert1-21-bb7624250ab556f2d40bfb8d419be487 @@ -0,0 +1,2 @@ +db1_insert1 +db1_insert2 diff --git a/sql/hive/src/test/resources/golden/insert1-3-89f8a028e32fae213b575b4df4e26e9c b/sql/hive/src/test/resources/golden/insert1-3-89f8a028e32fae213b575b4df4e26e9c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1-4-c7a68c0884785d0f5e62b287eb305d64 b/sql/hive/src/test/resources/golden/insert1-4-c7a68c0884785d0f5e62b287eb305d64 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1-5-cb87ee12092fdf05daed82485c32a285 b/sql/hive/src/test/resources/golden/insert1-5-cb87ee12092fdf05daed82485c32a285 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1-6-b97ba93a2c9ae671ecfc4fa95c024dda b/sql/hive/src/test/resources/golden/insert1-6-b97ba93a2c9ae671ecfc4fa95c024dda new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1-7-a2cd0615b9e79befd9c1842516150a61 b/sql/hive/src/test/resources/golden/insert1-7-a2cd0615b9e79befd9c1842516150a61 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1-8-5942e331621fe522fc297844046d2370 b/sql/hive/src/test/resources/golden/insert1-8-5942e331621fe522fc297844046d2370 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert1-9-5c5132707d7a4fb6e6a3de1a6719721a b/sql/hive/src/test/resources/golden/insert1-9-5c5132707d7a4fb6e6a3de1a6719721a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-0-5528e36b3b0f5b14313898cc45f9c23a b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-0-5528e36b3b0f5b14313898cc45f9c23a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-1-deb504f4f70fd7db975950c3c47959ee b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-1-deb504f4f70fd7db975950c3c47959ee new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-10-fda2e4be738186c0938f92d5072df55a b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-10-fda2e4be738186c0938f92d5072df55a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-11-9fb177236623d1b62acff28507033436 b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-11-9fb177236623d1b62acff28507033436 new file mode 100644 index 0000000000000..01f2b7063f91b --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-11-9fb177236623d1b62acff28507033436 @@ -0,0 +1,5 @@ +98 val_98 +98 val_98 +98 val_98 +97 val_97 +97 val_97 diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-12-99d5ad32bb81640cb284312841b60000 b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-12-99d5ad32bb81640cb284312841b60000 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-13-9dda06e1aae1860bd19eee97703a8217 b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-13-9dda06e1aae1860bd19eee97703a8217 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-14-19daabdd4c0d403c8781967248d09c53 b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-14-19daabdd4c0d403c8781967248d09c53 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-15-812006e1f11e005e5029866d1cf004f6 b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-15-812006e1f11e005e5029866d1cf004f6 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-2-bd042746328158822a25d711ffed18dd b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-2-bd042746328158822a25d711ffed18dd new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-3-b7aaedd7d624af4e48637ff1acabe485 b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-3-b7aaedd7d624af4e48637ff1acabe485 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-4-dece2650bf0615e566cd6c84181ce026 b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-4-dece2650bf0615e566cd6c84181ce026 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-5-1eb5c694e5a02aa292e24a0849350108 b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-5-1eb5c694e5a02aa292e24a0849350108 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-6-ab49e0665a80a6b34dadc96f1d18ce26 b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-6-ab49e0665a80a6b34dadc96f1d18ce26 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-7-fda2e4be738186c0938f92d5072df55a b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-7-fda2e4be738186c0938f92d5072df55a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-8-9fb177236623d1b62acff28507033436 b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-8-9fb177236623d1b62acff28507033436 new file mode 100644 index 0000000000000..01f2b7063f91b --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-8-9fb177236623d1b62acff28507033436 @@ -0,0 +1,5 @@ +98 val_98 +98 val_98 +98 val_98 +97 val_97 +97 val_97 diff --git a/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-9-ab49e0665a80a6b34dadc96f1d18ce26 b/sql/hive/src/test/resources/golden/insert2_overwrite_partitions-9-ab49e0665a80a6b34dadc96f1d18ce26 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/load_binary_data-0-491edd0c42ceb79e799ba50555bc8c15 b/sql/hive/src/test/resources/golden/load_binary_data-0-491edd0c42ceb79e799ba50555bc8c15 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/load_binary_data-1-5d72f8449b69df3c08e3f444f09428bc b/sql/hive/src/test/resources/golden/load_binary_data-1-5d72f8449b69df3c08e3f444f09428bc new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/load_binary_data-2-242b1655c7e7325ee9f26552ea8fc25 b/sql/hive/src/test/resources/golden/load_binary_data-2-242b1655c7e7325ee9f26552ea8fc25 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/load_binary_data-3-2a72df8d3e398d0963ef91162ce7d268 b/sql/hive/src/test/resources/golden/load_binary_data-3-2a72df8d3e398d0963ef91162ce7d268 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/read from cached table-0-ce3797dc14a603cba2a5e58c8612de5b b/sql/hive/src/test/resources/golden/read from cached table-0-ce3797dc14a603cba2a5e58c8612de5b new file mode 100644 index 0000000000000..60878ffb77064 --- /dev/null +++ b/sql/hive/src/test/resources/golden/read from cached table-0-ce3797dc14a603cba2a5e58c8612de5b @@ -0,0 +1 @@ +238 val_238 diff --git a/sql/hive/src/test/resources/golden/read from uncached table-0-ce3797dc14a603cba2a5e58c8612de5b b/sql/hive/src/test/resources/golden/read from uncached table-0-ce3797dc14a603cba2a5e58c8612de5b new file mode 100644 index 0000000000000..60878ffb77064 --- /dev/null +++ b/sql/hive/src/test/resources/golden/read from uncached table-0-ce3797dc14a603cba2a5e58c8612de5b @@ -0,0 +1 @@ +238 val_238 diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-0-86a409d8b868dc5f1a3bd1e04c2bc28c b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-0-86a409d8b868dc5f1a3bd1e04c2bc28c new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-0-86a409d8b868dc5f1a3bd1e04c2bc28c @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-1-2b1df88619e34f221d39598b5cd73283 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-1-2b1df88619e34f221d39598b5cd73283 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-1-2b1df88619e34f221d39598b5cd73283 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-10-60eadbb52f8857830a3034952c631ace b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-10-60eadbb52f8857830a3034952c631ace new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-11-dbe79f90862dc5c6cc4a4fa4b4b6c655 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-11-dbe79f90862dc5c6cc4a4fa4b4b6c655 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-12-60018cae9a0476dc6a0ab4264310edb5 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-12-60018cae9a0476dc6a0ab4264310edb5 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-2-7562d4fee13f3ba935a2e824f86a4224 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-2-7562d4fee13f3ba935a2e824f86a4224 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-2-7562d4fee13f3ba935a2e824f86a4224 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-3-bdb30a5d6887ee4fb089f8676313eafd b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-3-bdb30a5d6887ee4fb089f8676313eafd new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-3-bdb30a5d6887ee4fb089f8676313eafd @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-4-10713b30ecb3c88acdd775bf9628c38c b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-4-10713b30ecb3c88acdd775bf9628c38c new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-4-10713b30ecb3c88acdd775bf9628c38c @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-5-bab89dfffa77258e34a595e0e79986e3 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-5-bab89dfffa77258e34a595e0e79986e3 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-5-bab89dfffa77258e34a595e0e79986e3 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-6-6f53d5613262d393d82d159ec5dc16dc b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-6-6f53d5613262d393d82d159ec5dc16dc new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-6-6f53d5613262d393d82d159ec5dc16dc @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-7-ad4ddb5c5d6b994f4dba35f6162b6a9f b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-7-ad4ddb5c5d6b994f4dba35f6162b6a9f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-8-f9dd797f1c90e2108cfee585f443c132 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-8-f9dd797f1c90e2108cfee585f443c132 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook-9-22fdd8380f2652de2492b34a425d46d7 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook-9-22fdd8380f2652de2492b34a425d46d7 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-0-7a9e67189d3d4151f23b12c22bde06b5 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-0-7a9e67189d3d4151f23b12c22bde06b5 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-1-86a409d8b868dc5f1a3bd1e04c2bc28c b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-1-86a409d8b868dc5f1a3bd1e04c2bc28c new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-1-86a409d8b868dc5f1a3bd1e04c2bc28c @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-10-22fdd8380f2652de2492b34a425d46d7 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-10-22fdd8380f2652de2492b34a425d46d7 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-11-60eadbb52f8857830a3034952c631ace b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-11-60eadbb52f8857830a3034952c631ace new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-12-dbe79f90862dc5c6cc4a4fa4b4b6c655 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-12-dbe79f90862dc5c6cc4a4fa4b4b6c655 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-13-60018cae9a0476dc6a0ab4264310edb5 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-13-60018cae9a0476dc6a0ab4264310edb5 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-2-2b1df88619e34f221d39598b5cd73283 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-2-2b1df88619e34f221d39598b5cd73283 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-2-2b1df88619e34f221d39598b5cd73283 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-3-7562d4fee13f3ba935a2e824f86a4224 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-3-7562d4fee13f3ba935a2e824f86a4224 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-3-7562d4fee13f3ba935a2e824f86a4224 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-4-bdb30a5d6887ee4fb089f8676313eafd b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-4-bdb30a5d6887ee4fb089f8676313eafd new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-4-bdb30a5d6887ee4fb089f8676313eafd @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-5-10713b30ecb3c88acdd775bf9628c38c b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-5-10713b30ecb3c88acdd775bf9628c38c new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-5-10713b30ecb3c88acdd775bf9628c38c @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-6-bab89dfffa77258e34a595e0e79986e3 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-6-bab89dfffa77258e34a595e0e79986e3 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-6-bab89dfffa77258e34a595e0e79986e3 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-7-6f53d5613262d393d82d159ec5dc16dc b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-7-6f53d5613262d393d82d159ec5dc16dc new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-7-6f53d5613262d393d82d159ec5dc16dc @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-8-7a45282169e5a15d70ae0afb9e67ec9a b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-8-7a45282169e5a15d70ae0afb9e67ec9a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-9-f9dd797f1c90e2108cfee585f443c132 b/sql/hive/src/test/resources/golden/sample_islocalmode_hook_hadoop20-9-f9dd797f1c90e2108cfee585f443c132 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/timestamp_comparison-0-48751533b44ea9e8ac3131767c2fed05 b/sql/hive/src/test/resources/golden/timestamp_comparison-0-48751533b44ea9e8ac3131767c2fed05 new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_comparison-0-48751533b44ea9e8ac3131767c2fed05 @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/timestamp_comparison-1-60557e7bd2822c89fa8b076a9d0520fc b/sql/hive/src/test/resources/golden/timestamp_comparison-1-60557e7bd2822c89fa8b076a9d0520fc new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_comparison-1-60557e7bd2822c89fa8b076a9d0520fc @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/timestamp_comparison-2-f96a9d88327951bd93f672dc2463ecd4 b/sql/hive/src/test/resources/golden/timestamp_comparison-2-f96a9d88327951bd93f672dc2463ecd4 new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_comparison-2-f96a9d88327951bd93f672dc2463ecd4 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/timestamp_comparison-3-13e17ed811165196416f777cbc162592 b/sql/hive/src/test/resources/golden/timestamp_comparison-3-13e17ed811165196416f777cbc162592 new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_comparison-3-13e17ed811165196416f777cbc162592 @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/timestamp_comparison-4-4fa8a36edbefde4427c2ab2cf30e6399 b/sql/hive/src/test/resources/golden/timestamp_comparison-4-4fa8a36edbefde4427c2ab2cf30e6399 new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_comparison-4-4fa8a36edbefde4427c2ab2cf30e6399 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/timestamp_comparison-5-7e4fb6e8ba01df422e4c67e06a0c8453 b/sql/hive/src/test/resources/golden/timestamp_comparison-5-7e4fb6e8ba01df422e4c67e06a0c8453 new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_comparison-5-7e4fb6e8ba01df422e4c67e06a0c8453 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/timestamp_comparison-6-8c8e73673a950f6b3d960b08fcea076f b/sql/hive/src/test/resources/golden/timestamp_comparison-6-8c8e73673a950f6b3d960b08fcea076f new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_comparison-6-8c8e73673a950f6b3d960b08fcea076f @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/timestamp_comparison-7-510c0a2a57dc5df8588bd13c4152f8bc b/sql/hive/src/test/resources/golden/timestamp_comparison-7-510c0a2a57dc5df8588bd13c4152f8bc new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_comparison-7-510c0a2a57dc5df8588bd13c4152f8bc @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/timestamp_comparison-8-659d5b1ae8200f13f265270e52a3dd65 b/sql/hive/src/test/resources/golden/timestamp_comparison-8-659d5b1ae8200f13f265270e52a3dd65 new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/timestamp_comparison-8-659d5b1ae8200f13f265270e52a3dd65 @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/type_cast_1-0-60ea21e6e7d054a65f959fc89acf1b3d b/sql/hive/src/test/resources/golden/type_cast_1-0-60ea21e6e7d054a65f959fc89acf1b3d new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/type_cast_1-1-53a667981ad567b2ab977f67d65c5825 b/sql/hive/src/test/resources/golden/type_cast_1-1-53a667981ad567b2ab977f67d65c5825 new file mode 100644 index 0000000000000..7ed6ff82de6bc --- /dev/null +++ b/sql/hive/src/test/resources/golden/type_cast_1-1-53a667981ad567b2ab977f67d65c5825 @@ -0,0 +1 @@ +5 diff --git a/sql/hive/src/test/resources/golden/udf_printf-0-e86d559aeb84a4cc017a103182c22bfb b/sql/hive/src/test/resources/golden/udf_printf-0-e86d559aeb84a4cc017a103182c22bfb new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_printf-1-19c61fce27310ab2590062d643f7b26e b/sql/hive/src/test/resources/golden/udf_printf-1-19c61fce27310ab2590062d643f7b26e new file mode 100644 index 0000000000000..1635ff88dd768 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_printf-1-19c61fce27310ab2590062d643f7b26e @@ -0,0 +1 @@ +printf(String format, Obj... args) - function that can format strings according to printf-style format strings diff --git a/sql/hive/src/test/resources/golden/udf_printf-2-25aa6950cae2bb781c336378f63ceaee b/sql/hive/src/test/resources/golden/udf_printf-2-25aa6950cae2bb781c336378f63ceaee new file mode 100644 index 0000000000000..62440ee68e145 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_printf-2-25aa6950cae2bb781c336378f63ceaee @@ -0,0 +1,4 @@ +printf(String format, Obj... args) - function that can format strings according to printf-style format strings +Example: + > SELECT printf("Hello World %d %s", 100, "days")FROM src LIMIT 1; + "Hello World 100 days" diff --git a/sql/hive/src/test/resources/golden/udf_printf-3-9c568a0473888396bd46507e8b330c36 b/sql/hive/src/test/resources/golden/udf_printf-3-9c568a0473888396bd46507e8b330c36 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_printf-4-91728e546b450bdcbb05ef30f13be475 b/sql/hive/src/test/resources/golden/udf_printf-4-91728e546b450bdcbb05ef30f13be475 new file mode 100644 index 0000000000000..39cb945991403 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_printf-4-91728e546b450bdcbb05ef30f13be475 @@ -0,0 +1 @@ +Hello World 100 days diff --git a/sql/hive/src/test/resources/golden/udf_printf-5-3141a0421605b091ee5a9e99d7d605fb b/sql/hive/src/test/resources/golden/udf_printf-5-3141a0421605b091ee5a9e99d7d605fb new file mode 100644 index 0000000000000..04bf5e552a576 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_printf-5-3141a0421605b091ee5a9e99d7d605fb @@ -0,0 +1 @@ +All Type Test: false, A, 15000, 1.234000e+01, +27183.2401, 2300.41, 32, corret, 0x1.002p8 diff --git a/sql/hive/src/test/resources/golden/udf_printf-6-ec37b73012f3cbbbc0422744b0db8294 b/sql/hive/src/test/resources/golden/udf_printf-6-ec37b73012f3cbbbc0422744b0db8294 new file mode 100644 index 0000000000000..2e9f7509968a3 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_printf-6-ec37b73012f3cbbbc0422744b0db8294 @@ -0,0 +1 @@ +Color red, String Null: null, number1 123456, number2 00089, Integer Null: null, hex 0xff, float 3.14 Double Null: null diff --git a/sql/hive/src/test/resources/golden/udf_printf-7-5769f3a5b3300ca1d8b861229e976126 b/sql/hive/src/test/resources/golden/udf_printf-7-5769f3a5b3300ca1d8b861229e976126 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-10-51822ac740629bebd81d2abda6e1144 b/sql/hive/src/test/resources/golden/udf_to_boolean-10-51822ac740629bebd81d2abda6e1144 new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_to_boolean-10-51822ac740629bebd81d2abda6e1144 @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-11-441306cae24618c49ec63445a31bf16b b/sql/hive/src/test/resources/golden/udf_to_boolean-11-441306cae24618c49ec63445a31bf16b new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_to_boolean-11-441306cae24618c49ec63445a31bf16b @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-12-bfcc534e73e320a1cfad9c584678d870 b/sql/hive/src/test/resources/golden/udf_to_boolean-12-bfcc534e73e320a1cfad9c584678d870 new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_to_boolean-12-bfcc534e73e320a1cfad9c584678d870 @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-13-a2bddaa5db1841bb4617239b9f17a06d b/sql/hive/src/test/resources/golden/udf_to_boolean-13-a2bddaa5db1841bb4617239b9f17a06d new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_to_boolean-13-a2bddaa5db1841bb4617239b9f17a06d @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-14-773801b833cf72d35016916b786275b5 b/sql/hive/src/test/resources/golden/udf_to_boolean-14-773801b833cf72d35016916b786275b5 new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_to_boolean-14-773801b833cf72d35016916b786275b5 @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-15-4071ed0ff57b53963d5ee662fa9db0b0 b/sql/hive/src/test/resources/golden/udf_to_boolean-15-4071ed0ff57b53963d5ee662fa9db0b0 new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_to_boolean-15-4071ed0ff57b53963d5ee662fa9db0b0 @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-16-6b441df08afdc0c6c4a82670997dabb5 b/sql/hive/src/test/resources/golden/udf_to_boolean-16-6b441df08afdc0c6c4a82670997dabb5 new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_to_boolean-16-6b441df08afdc0c6c4a82670997dabb5 @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-17-85342c694d7f35e7eedb24e850d0c7df b/sql/hive/src/test/resources/golden/udf_to_boolean-17-85342c694d7f35e7eedb24e850d0c7df new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_to_boolean-17-85342c694d7f35e7eedb24e850d0c7df @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-18-fcd7af0e71d3e2d934239ba606e3ed87 b/sql/hive/src/test/resources/golden/udf_to_boolean-18-fcd7af0e71d3e2d934239ba606e3ed87 new file mode 100644 index 0000000000000..7951defec192a --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_to_boolean-18-fcd7af0e71d3e2d934239ba606e3ed87 @@ -0,0 +1 @@ +NULL diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-19-dcdb12fe551aa68a56921822f5d1a343 b/sql/hive/src/test/resources/golden/udf_to_boolean-19-dcdb12fe551aa68a56921822f5d1a343 new file mode 100644 index 0000000000000..7951defec192a --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_to_boolean-19-dcdb12fe551aa68a56921822f5d1a343 @@ -0,0 +1 @@ +NULL diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-20-131900d39d9a20b431731a32fb9715f8 b/sql/hive/src/test/resources/golden/udf_to_boolean-20-131900d39d9a20b431731a32fb9715f8 new file mode 100644 index 0000000000000..7951defec192a --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_to_boolean-20-131900d39d9a20b431731a32fb9715f8 @@ -0,0 +1 @@ +NULL diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-21-a5e28f4eb819e5a5e292e279f2990a7a b/sql/hive/src/test/resources/golden/udf_to_boolean-21-a5e28f4eb819e5a5e292e279f2990a7a new file mode 100644 index 0000000000000..7951defec192a --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_to_boolean-21-a5e28f4eb819e5a5e292e279f2990a7a @@ -0,0 +1 @@ +NULL diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-22-93278c10d642fa242f303d89b3b1961d b/sql/hive/src/test/resources/golden/udf_to_boolean-22-93278c10d642fa242f303d89b3b1961d new file mode 100644 index 0000000000000..7951defec192a --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_to_boolean-22-93278c10d642fa242f303d89b3b1961d @@ -0,0 +1 @@ +NULL diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-23-828558020ce907ffa7e847762a5e2358 b/sql/hive/src/test/resources/golden/udf_to_boolean-23-828558020ce907ffa7e847762a5e2358 new file mode 100644 index 0000000000000..7951defec192a --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_to_boolean-23-828558020ce907ffa7e847762a5e2358 @@ -0,0 +1 @@ +NULL diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-24-e8ca597d87932af16c0cf29d662e92da b/sql/hive/src/test/resources/golden/udf_to_boolean-24-e8ca597d87932af16c0cf29d662e92da new file mode 100644 index 0000000000000..7951defec192a --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_to_boolean-24-e8ca597d87932af16c0cf29d662e92da @@ -0,0 +1 @@ +NULL diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-25-86245727f90de9ce65a12c97a03a5635 b/sql/hive/src/test/resources/golden/udf_to_boolean-25-86245727f90de9ce65a12c97a03a5635 new file mode 100644 index 0000000000000..7951defec192a --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_to_boolean-25-86245727f90de9ce65a12c97a03a5635 @@ -0,0 +1 @@ +NULL diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-26-552d7ec5a4e0c93dc59a61973e2d63a2 b/sql/hive/src/test/resources/golden/udf_to_boolean-26-552d7ec5a4e0c93dc59a61973e2d63a2 new file mode 100644 index 0000000000000..7951defec192a --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_to_boolean-26-552d7ec5a4e0c93dc59a61973e2d63a2 @@ -0,0 +1 @@ +NULL diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-27-b61509b01b2fe3e7e4b72fedc74ff4f9 b/sql/hive/src/test/resources/golden/udf_to_boolean-27-b61509b01b2fe3e7e4b72fedc74ff4f9 new file mode 100644 index 0000000000000..7951defec192a --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_to_boolean-27-b61509b01b2fe3e7e4b72fedc74ff4f9 @@ -0,0 +1 @@ +NULL diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-8-37229f303635a030f6cab20e0381f51f b/sql/hive/src/test/resources/golden/udf_to_boolean-8-37229f303635a030f6cab20e0381f51f new file mode 100644 index 0000000000000..27ba77ddaf615 --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_to_boolean-8-37229f303635a030f6cab20e0381f51f @@ -0,0 +1 @@ +true diff --git a/sql/hive/src/test/resources/golden/udf_to_boolean-9-be623247e4dbf119b43458b72d1be017 b/sql/hive/src/test/resources/golden/udf_to_boolean-9-be623247e4dbf119b43458b72d1be017 new file mode 100644 index 0000000000000..c508d5366f70b --- /dev/null +++ b/sql/hive/src/test/resources/golden/udf_to_boolean-9-be623247e4dbf119b43458b72d1be017 @@ -0,0 +1 @@ +false diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala new file mode 100644 index 0000000000000..79ec1f1cde019 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.execution.SparkLogicalPlan +import org.apache.spark.sql.columnar.InMemoryColumnarTableScan +import org.apache.spark.sql.hive.execution.HiveComparisonTest + +class CachedTableSuite extends HiveComparisonTest { + TestHive.loadTestTable("src") + + test("cache table") { + TestHive.cacheTable("src") + } + + createQueryTest("read from cached table", + "SELECT * FROM src LIMIT 1", reset = false) + + test("check that table is cached and uncache") { + TestHive.table("src").queryExecution.analyzed match { + case SparkLogicalPlan(_ : InMemoryColumnarTableScan) => // Found evidence of caching + case noCache => fail(s"No cache node found in plan $noCache") + } + TestHive.uncacheTable("src") + } + + createQueryTest("read from uncached table", + "SELECT * FROM src LIMIT 1", reset = false) + + test("make sure table is uncached") { + TestHive.table("src").queryExecution.analyzed match { + case cachePlan @ SparkLogicalPlan(_ : InMemoryColumnarTableScan) => + fail(s"Table still cached after uncache: $cachePlan") + case noCache => // Table uncached successfully + } + } + + test("correct error on uncache of non-cached table") { + intercept[IllegalArgumentException] { + TestHive.uncacheTable("src") + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveSuite.scala new file mode 100644 index 0000000000000..8137f99b227f4 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveSuite.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.api.java + +import org.scalatest.FunSuite + +import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.hive.TestHive + +// Implicits +import scala.collection.JavaConversions._ + +class JavaHiveSQLSuite extends FunSuite { + ignore("SELECT * FROM src") { + val javaCtx = new JavaSparkContext(TestSQLContext.sparkContext) + // There is a little trickery here to avoid instantiating two HiveContexts in the same JVM + val javaSqlCtx = new JavaHiveContext(javaCtx) { + override val sqlContext = TestHive + } + + assert( + javaSqlCtx.hql("SELECT * FROM src").collect().map(_.getInt(0)) === + TestHive.sql("SELECT * FROM src").collect().map(_.getInt(0)).toSeq) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index c7a350ef94edd..3cc4562a88d66 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -125,7 +125,7 @@ abstract class HiveComparisonTest } protected def prepareAnswer( - hiveQuery: TestHive.type#SqlQueryExecution, + hiveQuery: TestHive.type#HiveQLQueryExecution, answer: Seq[String]): Seq[String] = { val orderedAnswer = hiveQuery.logical match { // Clean out non-deterministic time schema info. @@ -170,7 +170,7 @@ abstract class HiveComparisonTest } val installHooksCommand = "(?i)SET.*hooks".r - def createQueryTest(testCaseName: String, sql: String) { + def createQueryTest(testCaseName: String, sql: String, reset: Boolean = true) { // If test sharding is enable, skip tests that are not in the correct shard. shardInfo.foreach { case (shardId, numShards) if testCaseName.hashCode % numShards != shardId => return @@ -227,8 +227,8 @@ abstract class HiveComparisonTest try { // MINOR HACK: You must run a query before calling reset the first time. - TestHive.sql("SHOW TABLES") - TestHive.reset() + TestHive.hql("SHOW TABLES") + if (reset) { TestHive.reset() } val hiveCacheFiles = queryList.zipWithIndex.map { case (queryString, i) => @@ -256,7 +256,7 @@ abstract class HiveComparisonTest hiveCachedResults } else { - val hiveQueries = queryList.map(new TestHive.SqlQueryExecution(_)) + val hiveQueries = queryList.map(new TestHive.HiveQLQueryExecution(_)) // Make sure we can at least parse everything before attempting hive execution. hiveQueries.foreach(_.logical) val computedResults = (queryList.zipWithIndex, hiveQueries, hiveCacheFiles).zipped.map { @@ -295,14 +295,14 @@ abstract class HiveComparisonTest fail(errorMessage) } }.toSeq - TestHive.reset() + if (reset) { TestHive.reset() } computedResults } // Run w/ catalyst val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) => - val query = new TestHive.SqlQueryExecution(queryString) + val query = new TestHive.HiveQLQueryExecution(queryString) try { (query, prepareAnswer(query, query.stringResult())) } catch { case e: Exception => val errorMessage = @@ -359,7 +359,7 @@ abstract class HiveComparisonTest // When we encounter an error we check to see if the environment is still okay by running a simple query. // If this fails then we halt testing since something must have gone seriously wrong. try { - new TestHive.SqlQueryExecution("SELECT key FROM src").stringResult() + new TestHive.HiveQLQueryExecution("SELECT key FROM src").stringResult() TestHive.runSqlHive("SELECT key FROM src") } catch { case e: Exception => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index f74b0fbb97c83..f76e16bc1afc5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -42,6 +42,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest { "bucket_num_reducers", "column_access_stats", "concatenate_inherit_table_location", + "describe_pretty", + "describe_syntax", + "orc_ends_with_nulls", // Setting a default property does not seem to get reset and thus changes the answer for many // subsequent tests. @@ -80,7 +83,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest { "index_auto_update", "index_auto_self_join", "index_stale.*", - "type_cast_1", "index_compression", "index_bitmap_compression", "index_auto_multiple", @@ -237,9 +239,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest { "compute_stats_binary", "compute_stats_boolean", "compute_stats_double", - "compute_stats_table", + "compute_stats_empty_table", "compute_stats_long", "compute_stats_string", + "compute_stats_table", "convert_enum_to_string", "correlationoptimizer11", "correlationoptimizer15", @@ -266,8 +269,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest { "desc_non_existent_tbl", "describe_comment_indent", "describe_database_json", - "describe_pretty", - "describe_syntax", + "describe_formatted_view_partitioned", + "describe_formatted_view_partitioned_json", "describe_table_json", "diff_part_input_formats", "disable_file_format_check", @@ -339,8 +342,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest { "input11_limit", "input12", "input12_hadoop20", + "input14", "input19", "input1_limit", + "input21", "input22", "input23", "input24", @@ -355,6 +360,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest { "input7", "input8", "input9", + "inputddl4", + "inputddl7", + "inputddl8", "input_limit", "input_part0", "input_part1", @@ -368,9 +376,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest { "input_part7", "input_part8", "input_part9", - "inputddl4", - "inputddl7", - "inputddl8", + "input_testsequencefile", + "insert1", + "insert2_overwrite_partitions", "insert_compressed", "join0", "join1", @@ -385,6 +393,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest { "join17", "join18", "join19", + "join_1to1", "join2", "join20", "join21", @@ -400,6 +409,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest { "join30", "join31", "join32", + "join32_lessSize", "join33", "join34", "join35", @@ -415,13 +425,14 @@ class HiveCompatibilitySuite extends HiveQueryFileTest { "join7", "join8", "join9", - "join_1to1", "join_array", "join_casesensitive", "join_empty", "join_filters", "join_hive_626", + "join_map_ppr", "join_nulls", + "join_rc", "join_reorder2", "join_reorder3", "join_reorder4", @@ -435,22 +446,32 @@ class HiveCompatibilitySuite extends HiveQueryFileTest { "literal_string", "load_dyn_part7", "load_file_with_space_in_the_name", + "loadpart1", "louter_join_ppr", "mapjoin_distinct", "mapjoin_mapjoin", "mapjoin_subquery", "mapjoin_subquery2", "mapjoin_test_outer", + "mapreduce1", + "mapreduce2", "mapreduce3", + "mapreduce4", + "mapreduce5", + "mapreduce6", "mapreduce7", + "mapreduce8", "merge1", "merge2", "mergejoins", "mergejoins_mixed", + "multigroupby_singlemr", + "multi_insert_gby", + "multi_insert_gby3", + "multi_insert_lateral_view", + "multi_join_union", "multiMapJoin1", "multiMapJoin2", - "multi_join_union", - "multigroupby_singlemr", "noalias_subq1", "nomore_ambiguous_table_col", "nonblock_op_deduplicate", @@ -466,16 +487,30 @@ class HiveCompatibilitySuite extends HiveQueryFileTest { "nullinput2", "nullscript", "optional_outer", + "orc_dictionary_threshold", + "orc_empty_files", "order", "order2", "outer_join_ppr", + "parallel", + "parenthesis_star_by", + "partcols1", "part_inherit_tbl_props", "part_inherit_tbl_props_empty", "part_inherit_tbl_props_with_star", "partition_schema1", + "partition_serde_format", "partition_varchar1", + "partition_wise_fileformat4", + "partition_wise_fileformat5", + "partition_wise_fileformat6", + "partition_wise_fileformat7", + "partition_wise_fileformat9", "plan_json", "ppd1", + "ppd2", + "ppd_clusterby", + "ppd_constant_expr", "ppd_constant_where", "ppd_gby", "ppd_gby2", @@ -491,6 +526,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest { "ppd_outer_join5", "ppd_random", "ppd_repeated_alias", + "ppd_transform", "ppd_udf_col", "ppd_union", "ppr_allchildsarenull", @@ -503,7 +539,15 @@ class HiveCompatibilitySuite extends HiveQueryFileTest { "query_with_semi", "quote1", "quote2", + "rcfile_columnar", + "rcfile_lazydecompress", + "rcfile_null_value", + "rcfile_toleratecorruptions", + "rcfile_union", + "reduce_deduplicate", + "reduce_deduplicate_exclude_gby", "reduce_deduplicate_exclude_join", + "reducesink_dedup", "rename_column", "router_join_ppr", "select_as_omitted", @@ -531,6 +575,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest { "smb_mapjoin_3", "smb_mapjoin_4", "smb_mapjoin_5", + "smb_mapjoin_6", + "smb_mapjoin_7", "smb_mapjoin_8", "sort", "sort_merge_join_desc_1", @@ -541,21 +587,27 @@ class HiveCompatibilitySuite extends HiveQueryFileTest { "sort_merge_join_desc_6", "sort_merge_join_desc_7", "stats0", + "stats_aggregator_error_1", "stats_empty_partition", + "stats_publisher_error_1", "subq2", "tablename_with_select", + "timestamp_comparison", "touch", + "transform_ppr1", + "transform_ppr2", + "type_cast_1", "type_widening", "udaf_collect_set", "udaf_corr", "udaf_covar_pop", "udaf_covar_samp", + "udaf_histogram_numeric", + "udf_10_trims", "udf2", "udf6", + "udf8", "udf9", - "udf_10_trims", - "udf_E", - "udf_PI", "udf_abs", "udf_acos", "udf_add", @@ -585,13 +637,14 @@ class HiveCompatibilitySuite extends HiveQueryFileTest { "udf_cos", "udf_count", "udf_date_add", - "udf_date_sub", "udf_datediff", + "udf_date_sub", "udf_day", "udf_dayofmonth", "udf_degrees", "udf_div", "udf_double", + "udf_E", "udf_exp", "udf_field", "udf_find_in_set", @@ -631,6 +684,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest { "udf_nvl", "udf_or", "udf_parse_url", + "udf_PI", "udf_positive", "udf_pow", "udf_power", @@ -671,9 +725,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest { "udf_trim", "udf_ucase", "udf_upper", + "udf_variance", "udf_var_pop", "udf_var_samp", - "udf_variance", "udf_weekofyear", "udf_when", "udf_xpath", @@ -703,8 +757,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest { "union27", "union28", "union29", + "union3", "union30", "union31", + "union33", "union34", "union4", "union5", @@ -714,6 +770,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest { "union9", "union_lateralview", "union_ppr", + "union_remove_11", "union_remove_3", "union_remove_6", "union_script", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index c184ebe288af4..a09667ac84b01 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -23,6 +23,16 @@ import org.apache.spark.sql.hive.TestHive._ * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. */ class HiveQuerySuite extends HiveComparisonTest { + + test("Query expressed in SQL") { + assert(sql("SELECT 1").collect() === Array(Seq(1))) + } + + test("Query expressed in HiveQL") { + hql("FROM src SELECT key").collect() + hiveql("FROM src SELECT key").collect() + } + createQueryTest("Simple Average", "SELECT AVG(key) FROM src") @@ -133,7 +143,11 @@ class HiveQuerySuite extends HiveComparisonTest { "SELECT * FROM src LATERAL VIEW explode(map(key+3,key+4)) D as k, v") test("sampling") { - sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s") + hql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s") } + test("SchemaRDD toString") { + hql("SHOW TABLES").toString + hql("SELECT * FROM src").toString + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index d77900ddc950c..8883e5b16d4da 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -48,7 +48,7 @@ class HiveResolutionSuite extends HiveComparisonTest { createQueryTest("attr", "SELECT key FROM src a ORDER BY key LIMIT 1") - createQueryTest("alias.*", + createQueryTest("alias.star", "SELECT a.* FROM src a ORDER BY key LIMIT 1") test("case insensitivity with scala reflection") { @@ -56,7 +56,7 @@ class HiveResolutionSuite extends HiveComparisonTest { TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2)) :: Nil) .registerAsTable("caseSensitivityTest") - sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") + hql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 1318ac1968dad..d9ccb93e23923 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -136,7 +136,7 @@ class PruningSuite extends HiveComparisonTest { expectedScannedColumns: Seq[String], expectedPartValues: Seq[Seq[String]]) = { test(s"$testCaseName - pruning test") { - val plan = new TestHive.SqlQueryExecution(sql).executedPlan + val plan = new TestHive.HiveQLQueryExecution(sql).executedPlan val actualOutputColumns = plan.output.map(_.name) val (actualScannedColumns, actualPartValues) = plan.collect { case p @ HiveTableScan(columns, relation, _) => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala index 05ad85b622ac8..aade62eb8f84e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -17,147 +17,138 @@ package org.apache.spark.sql.parquet -import java.io.File - import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.getTempFilePath +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row} +import org.apache.spark.sql.catalyst.types.{DataType, StringType, IntegerType} +import org.apache.spark.sql.{parquet, SchemaRDD} import org.apache.spark.sql.hive.TestHive +import org.apache.spark.util.Utils + +// Implicits +import org.apache.spark.sql.hive.TestHive._ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { - val filename = getTempFilePath("parquettest").getCanonicalFile.toURI.toString - - // runs a SQL and optionally resolves one Parquet table - def runQuery( - querystr: String, - tableName: Option[String] = None, - filename: Option[String] = None): Array[Row] = { - - // call to resolve references in order to get CREATE TABLE AS to work - val query = TestHive - .parseSql(querystr) - val finalQuery = - if (tableName.nonEmpty && filename.nonEmpty) - resolveParquetTable(tableName.get, filename.get, query) - else - query - TestHive.executePlan(finalQuery) - .toRdd - .collect() - } - // stores a query output to a Parquet file - def storeQuery(querystr: String, filename: String): Unit = { - val query = WriteToFile( - filename, - TestHive.parseSql(querystr)) - TestHive - .executePlan(query) - .stringResult() - } + val dirname = Utils.createTempDir() - /** - * TODO: This function is necessary as long as there is no notion of a Catalog for - * Parquet tables. Once such a thing exists this functionality should be moved there. - */ - def resolveParquetTable(tableName: String, filename: String, plan: LogicalPlan): LogicalPlan = { - TestHive.loadTestTable("src") // may not be loaded now - plan.transform { - case relation @ UnresolvedRelation(databaseName, name, alias) => - if (name == tableName) - ParquetRelation(tableName, filename) - else - relation - case op @ InsertIntoCreatedTable(databaseName, name, child) => - if (name == tableName) { - // note: at this stage the plan is not yet analyzed but Parquet needs to know the schema - // and for that we need the child to be resolved - val relation = ParquetRelation.create( - filename, - TestHive.analyzer(child), - TestHive.sparkContext.hadoopConfiguration, - Some(tableName)) - InsertIntoTable( - relation.asInstanceOf[BaseRelation], - Map.empty, - child, - overwrite = false) - } else - op - } - } + var testRDD: SchemaRDD = null override def beforeAll() { // write test data - ParquetTestData.writeFile() - // Override initial Parquet test table - TestHive.catalog.registerTable(Some[String]("parquet"), "testsource", ParquetTestData.testData) + ParquetTestData.writeFile + testRDD = parquetFile(ParquetTestData.testDir.toString) + testRDD.registerAsTable("testsource") } override def afterAll() { - ParquetTestData.testFile.delete() + Utils.deleteRecursively(ParquetTestData.testDir) + Utils.deleteRecursively(dirname) + reset() // drop all tables that were registered as part of the tests } + // in case tests are failing we delete before and after each test override def beforeEach() { - new File(filename).getAbsoluteFile.delete() + Utils.deleteRecursively(dirname) } override def afterEach() { - new File(filename).getAbsoluteFile.delete() + Utils.deleteRecursively(dirname) } test("SELECT on Parquet table") { - val rdd = runQuery("SELECT * FROM parquet.testsource") + val rdd = hql("SELECT * FROM testsource").collect() assert(rdd != null) assert(rdd.forall(_.size == 6)) } test("Simple column projection + filter on Parquet table") { - val rdd = runQuery("SELECT myboolean, mylong FROM parquet.testsource WHERE myboolean=true") + val rdd = hql("SELECT myboolean, mylong FROM testsource WHERE myboolean=true").collect() assert(rdd.size === 5, "Filter returned incorrect number of rows") assert(rdd.forall(_.getBoolean(0)), "Filter returned incorrect Boolean field value") } - test("Converting Hive to Parquet Table via WriteToFile") { - storeQuery("SELECT * FROM src", filename) - val rddOne = runQuery("SELECT * FROM src").sortBy(_.getInt(0)) - val rddTwo = runQuery("SELECT * from ptable", Some("ptable"), Some(filename)).sortBy(_.getInt(0)) + test("Converting Hive to Parquet Table via saveAsParquetFile") { + hql("SELECT * FROM src").saveAsParquetFile(dirname.getAbsolutePath) + parquetFile(dirname.getAbsolutePath).registerAsTable("ptable") + val rddOne = hql("SELECT * FROM src").collect().sortBy(_.getInt(0)) + val rddTwo = hql("SELECT * from ptable").collect().sortBy(_.getInt(0)) compareRDDs(rddOne, rddTwo, "src (Hive)", Seq("key:Int", "value:String")) } test("INSERT OVERWRITE TABLE Parquet table") { - storeQuery("SELECT * FROM parquet.testsource", filename) - runQuery("INSERT OVERWRITE TABLE ptable SELECT * FROM parquet.testsource", Some("ptable"), Some(filename)) - runQuery("INSERT OVERWRITE TABLE ptable SELECT * FROM parquet.testsource", Some("ptable"), Some(filename)) - val rddCopy = runQuery("SELECT * FROM ptable", Some("ptable"), Some(filename)) - val rddOrig = runQuery("SELECT * FROM parquet.testsource") - compareRDDs(rddOrig, rddCopy, "parquet.testsource", ParquetTestData.testSchemaFieldNames) + hql("SELECT * FROM testsource").saveAsParquetFile(dirname.getAbsolutePath) + parquetFile(dirname.getAbsolutePath).registerAsTable("ptable") + // let's do three overwrites for good measure + hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() + hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() + hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() + val rddCopy = hql("SELECT * FROM ptable").collect() + val rddOrig = hql("SELECT * FROM testsource").collect() + assert(rddCopy.size === rddOrig.size, "INSERT OVERWRITE changed size of table??") + compareRDDs(rddOrig, rddCopy, "testsource", ParquetTestData.testSchemaFieldNames) } - test("CREATE TABLE AS Parquet table") { - runQuery("CREATE TABLE ptable AS SELECT * FROM src", Some("ptable"), Some(filename)) - val rddCopy = runQuery("SELECT * FROM ptable", Some("ptable"), Some(filename)) + test("CREATE TABLE of Parquet table") { + createParquetFile(dirname.getAbsolutePath, ("key", IntegerType), ("value", StringType)) + .registerAsTable("tmp") + val rddCopy = + hql("INSERT INTO TABLE tmp SELECT * FROM src") + .collect() .sortBy[Int](_.apply(0) match { case x: Int => x case _ => 0 }) - val rddOrig = runQuery("SELECT * FROM src").sortBy(_.getInt(0)) + val rddOrig = hql("SELECT * FROM src") + .collect() + .sortBy(_.getInt(0)) compareRDDs(rddOrig, rddCopy, "src (Hive)", Seq("key:Int", "value:String")) } + test("Appending to Parquet table") { + createParquetFile(dirname.getAbsolutePath, ("key", IntegerType), ("value", StringType)) + .registerAsTable("tmpnew") + hql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect() + hql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect() + hql("INSERT INTO TABLE tmpnew SELECT * FROM src").collect() + val rddCopies = hql("SELECT * FROM tmpnew").collect() + val rddOrig = hql("SELECT * FROM src").collect() + assert(rddCopies.size === 3 * rddOrig.size, "number of copied rows via INSERT INTO did not match correct number") + } + + test("Appending to and then overwriting Parquet table") { + createParquetFile(dirname.getAbsolutePath, ("key", IntegerType), ("value", StringType)) + .registerAsTable("tmp") + hql("INSERT INTO TABLE tmp SELECT * FROM src").collect() + hql("INSERT INTO TABLE tmp SELECT * FROM src").collect() + hql("INSERT OVERWRITE TABLE tmp SELECT * FROM src").collect() + val rddCopies = hql("SELECT * FROM tmp").collect() + val rddOrig = hql("SELECT * FROM src").collect() + assert(rddCopies.size === rddOrig.size, "INSERT OVERWRITE did not actually overwrite") + } + private def compareRDDs(rddOne: Array[Row], rddTwo: Array[Row], tableName: String, fieldNames: Seq[String]) { var counter = 0 (rddOne, rddTwo).zipped.foreach { (a,b) => (a,b).zipped.toArray.zipWithIndex.foreach { - case ((value_1:Array[Byte], value_2:Array[Byte]), index) => - assert(new String(value_1) === new String(value_2), s"table $tableName row $counter field ${fieldNames(index)} don't match") case ((value_1, value_2), index) => assert(value_1 === value_2, s"table $tableName row $counter field ${fieldNames(index)} don't match") } counter = counter + 1 } } + + /** + * Creates an empty SchemaRDD backed by a ParquetRelation. + * + * TODO: since this is so experimental it is better to have it here and not + * in SQLContext. Also note that when creating new AttributeReferences + * one needs to take care not to create duplicate Attribute ID's. + */ + private def createParquetFile(path: String, schema: (Tuple2[String, DataType])*): SchemaRDD = { + val attributes = schema.map(t => new AttributeReference(t._1, t._2)()) + new SchemaRDD( + TestHive, + parquet.ParquetRelation.createEmpty(path, attributes, sparkContext.hadoopConfiguration)) + } } diff --git a/streaming/pom.xml b/streaming/pom.xml index 1953cc6883378..93b1c5a37aff9 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -96,7 +96,6 @@ org.apache.maven.plugins maven-jar-plugin - 2.2 diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index baf80fe2a91b7..ac56ff709c1c4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -59,7 +59,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) } } -private[streaming] +private[streaming] object Checkpoint extends Logging { val PREFIX = "checkpoint-" val REGEX = (PREFIX + """([\d]+)([\w\.]*)""").r @@ -79,7 +79,7 @@ object Checkpoint extends Logging { def sortFunc(path1: Path, path2: Path): Boolean = { val (time1, bk1) = path1.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } val (time2, bk2) = path2.getName match { case REGEX(x, y) => (x.toLong, !y.isEmpty) } - (time1 < time2) || (time1 == time2 && bk1) + (time1 < time2) || (time1 == time2 && bk1) } val path = new Path(checkpointDir) @@ -95,7 +95,7 @@ object Checkpoint extends Logging { } } else { logInfo("Checkpoint directory " + path + " does not exist") - Seq.empty + Seq.empty } } } @@ -160,7 +160,7 @@ class CheckpointWriter( }) } - // All done, print success + // All done, print success val finishTime = System.currentTimeMillis() logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + checkpointFile + "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " ms") @@ -194,19 +194,19 @@ class CheckpointWriter( } } - def stop() { - synchronized { - if (stopped) { - return - } - stopped = true - } + def stop(): Unit = synchronized { + if (stopped) return + executor.shutdown() val startTime = System.currentTimeMillis() val terminated = executor.awaitTermination(10, java.util.concurrent.TimeUnit.SECONDS) + if (!terminated) { + executor.shutdownNow() + } val endTime = System.currentTimeMillis() logInfo("CheckpointWriter executor terminated ? " + terminated + ", waited for " + (endTime - startTime) + " ms.") + stopped = true } private def fs = synchronized { @@ -227,14 +227,14 @@ object CheckpointReader extends Logging { { val checkpointPath = new Path(checkpointDir) def fs = checkpointPath.getFileSystem(hadoopConf) - - // Try to find the checkpoint files + + // Try to find the checkpoint files val checkpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs).reverse if (checkpointFiles.isEmpty) { return None } - // Try to read the checkpoint files in the order + // Try to read the checkpoint files in the order logInfo("Checkpoint files found: " + checkpointFiles.mkString(",")) val compressionCodec = CompressionCodec.createCodec(conf) checkpointFiles.foreach(file => { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala b/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala index 16479a01272aa..ad4f3fdd14ad6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Interval.scala @@ -20,11 +20,11 @@ package org.apache.spark.streaming private[streaming] class Interval(val beginTime: Time, val endTime: Time) { def this(beginMs: Long, endMs: Long) = this(new Time(beginMs), new Time(endMs)) - + def duration(): Duration = endTime - beginTime def + (time: Duration): Interval = { - new Interval(beginTime + time, endTime + time) + new Interval(beginTime + time, endTime + time) } def - (time: Duration): Interval = { @@ -40,9 +40,9 @@ class Interval(val beginTime: Time, val endTime: Time) { } def <= (that: Interval) = (this < that || this == that) - + def > (that: Interval) = !(this <= that) - + def >= (that: Interval) = !(this < that) override def toString = "[" + beginTime + ", " + endTime + "]" diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 42a70ead7e40f..9a2e5201fdc3d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -17,30 +17,28 @@ package org.apache.spark.streaming -import scala.collection.mutable.Queue -import scala.collection.Map -import scala.reflect.ClassTag - import java.io.InputStream import java.util.concurrent.atomic.AtomicInteger -import akka.actor.Props -import akka.actor.SupervisorStrategy -import org.apache.hadoop.io.LongWritable -import org.apache.hadoop.io.Text +import scala.collection.Map +import scala.collection.mutable.Queue +import scala.reflect.ClassTag + +import akka.actor.{Props, SupervisorStrategy} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.hadoop.mapreduce.lib.input.TextInputFormat -import org.apache.hadoop.fs.Path - import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.MetadataCleaner import org.apache.spark.streaming.dstream._ +import org.apache.spark.streaming.receiver.NetworkReceiver import org.apache.spark.streaming.receivers._ import org.apache.spark.streaming.scheduler._ -import org.apache.hadoop.conf.Configuration -import org.apache.spark.streaming.receiver.NetworkReceiver +import org.apache.spark.streaming.ui.StreamingTab +import org.apache.spark.util.MetadataCleaner /** * Main entry point for Spark Streaming functionality. It provides methods used to create @@ -159,6 +157,17 @@ class StreamingContext private[streaming] ( private[streaming] val waiter = new ContextWaiter + private[streaming] val uiTab = new StreamingTab(this) + + /** Enumeration to identify current state of the StreamingContext */ + private[streaming] object StreamingContextState extends Enumeration { + type CheckpointState = Value + val Initialized, Started, Stopped = Value + } + + import StreamingContextState._ + private[streaming] var state = Initialized + /** * Return the associated Spark context */ @@ -406,9 +415,18 @@ class StreamingContext private[streaming] ( /** * Start the execution of the streams. */ - def start() = synchronized { + def start(): Unit = synchronized { + // Throw exception if the context has already been started once + // or if a stopped context is being started again + if (state == Started) { + throw new SparkException("StreamingContext has already been started") + } + if (state == Stopped) { + throw new SparkException("StreamingContext has already been stopped") + } validate() scheduler.start() + state = Started } /** @@ -429,14 +447,38 @@ class StreamingContext private[streaming] ( } /** - * Stop the execution of the streams. + * Stop the execution of the streams immediately (does not wait for all received data + * to be processed). * @param stopSparkContext Stop the associated SparkContext or not + * */ def stop(stopSparkContext: Boolean = true): Unit = synchronized { - scheduler.stop() + stop(stopSparkContext, false) + } + + /** + * Stop the execution of the streams, with option of ensuring all received data + * has been processed. + * @param stopSparkContext Stop the associated SparkContext or not + * @param stopGracefully Stop gracefully by waiting for the processing of all + * received data to be completed + */ + def stop(stopSparkContext: Boolean, stopGracefully: Boolean): Unit = synchronized { + // Warn (but not fail) if context is stopped twice, + // or context is stopped before starting + if (state == Initialized) { + logWarning("StreamingContext has not been started yet") + return + } + if (state == Stopped) { + logWarning("StreamingContext has already been stopped") + return + } // no need to throw an exception as its okay to stop twice + scheduler.stop(stopGracefully) logInfo("StreamingContext stopped successfully") waiter.notifyStop() if (stopSparkContext) sc.stop() + state = Stopped } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Time.scala b/streaming/src/main/scala/org/apache/spark/streaming/Time.scala index 2678334f53844..6a6b00a778b48 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Time.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Time.scala @@ -32,7 +32,7 @@ case class Time(private val millis: Long) { def <= (that: Time): Boolean = (this.millis <= that.millis) def > (that: Time): Boolean = (this.millis > that.millis) - + def >= (that: Time): Boolean = (this.millis >= that.millis) def + (that: Duration): Time = new Time(millis + that.milliseconds) @@ -43,7 +43,7 @@ case class Time(private val millis: Long) { def floor(that: Duration): Time = { val t = that.milliseconds - val m = math.floor(this.millis / t).toLong + val m = math.floor(this.millis / t).toLong new Time(m * t) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index ac451d1913aaa..2ac943d7bf781 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.api.java -import java.lang.{Long => JLong} +import java.lang.{Long => JLong, Iterable => JIterable} import java.util.{List => JList} import scala.collection.JavaConversions._ @@ -115,15 +115,15 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to * generate the RDDs with Spark's default number of partitions. */ - def groupByKey(): JavaPairDStream[K, JList[V]] = - dstream.groupByKey().mapValues(seqAsJavaList _) + def groupByKey(): JavaPairDStream[K, JIterable[V]] = + dstream.groupByKey().mapValues(asJavaIterable _) /** * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to * generate the RDDs with `numPartitions` partitions. */ - def groupByKey(numPartitions: Int): JavaPairDStream[K, JList[V]] = - dstream.groupByKey(numPartitions).mapValues(seqAsJavaList _) + def groupByKey(numPartitions: Int): JavaPairDStream[K, JIterable[V]] = + dstream.groupByKey(numPartitions).mapValues(asJavaIterable _) /** * Return a new DStream by applying `groupByKey` on each RDD of `this` DStream. @@ -131,8 +131,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * single sequence to generate the RDDs of the new DStream. org.apache.spark.Partitioner * is used to control the partitioning of each RDD. */ - def groupByKey(partitioner: Partitioner): JavaPairDStream[K, JList[V]] = - dstream.groupByKey(partitioner).mapValues(seqAsJavaList _) + def groupByKey(partitioner: Partitioner): JavaPairDStream[K, JIterable[V]] = + dstream.groupByKey(partitioner).mapValues(asJavaIterable _) /** * Return a new DStream by applying `reduceByKey` to each RDD. The values for each key are @@ -196,8 +196,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval */ - def groupByKeyAndWindow(windowDuration: Duration): JavaPairDStream[K, JList[V]] = { - dstream.groupByKeyAndWindow(windowDuration).mapValues(seqAsJavaList _) + def groupByKeyAndWindow(windowDuration: Duration): JavaPairDStream[K, JIterable[V]] = { + dstream.groupByKeyAndWindow(windowDuration).mapValues(asJavaIterable _) } /** @@ -211,8 +211,8 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * DStream's batching interval */ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration) - : JavaPairDStream[K, JList[V]] = { - dstream.groupByKeyAndWindow(windowDuration, slideDuration).mapValues(seqAsJavaList _) + : JavaPairDStream[K, JIterable[V]] = { + dstream.groupByKeyAndWindow(windowDuration, slideDuration).mapValues(asJavaIterable _) } /** @@ -227,9 +227,9 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * @param numPartitions Number of partitions of each RDD in the new DStream. */ def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) - :JavaPairDStream[K, JList[V]] = { + :JavaPairDStream[K, JIterable[V]] = { dstream.groupByKeyAndWindow(windowDuration, slideDuration, numPartitions) - .mapValues(seqAsJavaList _) + .mapValues(asJavaIterable _) } /** @@ -247,9 +247,9 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( windowDuration: Duration, slideDuration: Duration, partitioner: Partitioner - ):JavaPairDStream[K, JList[V]] = { + ):JavaPairDStream[K, JIterable[V]] = { dstream.groupByKeyAndWindow(windowDuration, slideDuration, partitioner) - .mapValues(seqAsJavaList _) + .mapValues(asJavaIterable _) } /** @@ -518,9 +518,9 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * Hash partitioning is used to generate the RDDs with Spark's default number * of partitions. */ - def cogroup[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (JList[V], JList[W])] = { + def cogroup[W](other: JavaPairDStream[K, W]): JavaPairDStream[K, (JIterable[V], JIterable[W])] = { implicit val cm: ClassTag[W] = fakeClassTag - dstream.cogroup(other.dstream).mapValues(t => (seqAsJavaList(t._1), seqAsJavaList((t._2)))) + dstream.cogroup(other.dstream).mapValues(t => (asJavaIterable(t._1), asJavaIterable((t._2)))) } /** @@ -530,10 +530,10 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( def cogroup[W]( other: JavaPairDStream[K, W], numPartitions: Int - ): JavaPairDStream[K, (JList[V], JList[W])] = { + ): JavaPairDStream[K, (JIterable[V], JIterable[W])] = { implicit val cm: ClassTag[W] = fakeClassTag dstream.cogroup(other.dstream, numPartitions) - .mapValues(t => (seqAsJavaList(t._1), seqAsJavaList((t._2)))) + .mapValues(t => (asJavaIterable(t._1), asJavaIterable((t._2)))) } /** @@ -543,10 +543,10 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( def cogroup[W]( other: JavaPairDStream[K, W], partitioner: Partitioner - ): JavaPairDStream[K, (JList[V], JList[W])] = { + ): JavaPairDStream[K, (JIterable[V], JIterable[W])] = { implicit val cm: ClassTag[W] = fakeClassTag dstream.cogroup(other.dstream, partitioner) - .mapValues(t => (seqAsJavaList(t._1), seqAsJavaList((t._2)))) + .mapValues(t => (asJavaIterable(t._1), asJavaIterable((t._2)))) } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index b705d2ec9a58e..c800602d0959b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -509,8 +509,16 @@ class JavaStreamingContext(val ssc: StreamingContext) { * Stop the execution of the streams. * @param stopSparkContext Stop the associated SparkContext or not */ - def stop(stopSparkContext: Boolean): Unit = { - ssc.stop(stopSparkContext) + def stop(stopSparkContext: Boolean) = ssc.stop(stopSparkContext) + + /** + * Stop the execution of the streams. + * @param stopSparkContext Stop the associated SparkContext or not + * @param stopGracefully Stop gracefully by waiting for the processing of all + * received data to be completed + */ + def stop(stopSparkContext: Boolean, stopGracefully: Boolean) = { + ssc.stop(stopSparkContext, stopGracefully) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index d48b51aa69565..a7e5215437e54 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -341,9 +341,11 @@ abstract class DStream[T: ClassTag] ( */ private[streaming] def clearMetadata(time: Time) { val oldRDDs = generatedRDDs.filter(_._1 <= (time - rememberDuration)) + logDebug("Clearing references to old RDDs: [" + + oldRDDs.map(x => s"${x._1} -> ${x._2.id}").mkString(", ") + "]") generatedRDDs --= oldRDDs.keys if (ssc.conf.getBoolean("spark.streaming.unpersist", false)) { - logDebug("Unpersisting old RDDs: " + oldRDDs.keys.mkString(", ")) + logDebug("Unpersisting old RDDs: " + oldRDDs.values.map(_.id).mkString(", ")) oldRDDs.values.foreach(_.unpersist(false)) } logDebug("Cleared " + oldRDDs.size + " RDDs that were older than " + @@ -351,15 +353,6 @@ abstract class DStream[T: ClassTag] ( dependencies.foreach(_.clearMetadata(time)) } - /* Adds metadata to the Stream while it is running. - * This method should be overwritten by sublcasses of InputDStream. - */ - private[streaming] def addMetadata(metadata: Any) { - if (metadata != null) { - logInfo("Dropping Metadata: " + metadata.toString) - } - } - /** * Refresh the list of checkpointed RDDs that will be saved along with checkpoint of * this stream. This is an internal method that should not be called directly. This is diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala index 903e3f3c9b713..f33c0ceafdf42 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala @@ -51,7 +51,7 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) .map(x => (x._1, x._2.getCheckpointFile.get)) logDebug("Current checkpoint files:\n" + checkpointFiles.toSeq.mkString("\n")) - // Add the checkpoint files to the data to be serialized + // Add the checkpoint files to the data to be serialized if (!checkpointFiles.isEmpty) { currentCheckpointFiles.clear() currentCheckpointFiles ++= checkpointFiles diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 8a6051622e2d5..e878285f6a854 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -232,7 +232,7 @@ class FileInputDStream[K: ClassTag, V: ClassTag, F <: NewInputFormat[K,V] : Clas } logDebug("Accepted " + path) } catch { - case fnfe: java.io.FileNotFoundException => + case fnfe: java.io.FileNotFoundException => logWarning("Error finding new files", fnfe) reset() return false diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala index 423c2ae72d691..0750ef4b3dfc2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala @@ -17,13 +17,14 @@ package org.apache.spark.streaming.dstream -import scala.Array +import scala.collection.mutable.HashMap import scala.reflect.ClassTag import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.BlockId import org.apache.spark.streaming._ import org.apache.spark.streaming.receiver.NetworkReceiver +import org.apache.spark.streaming.scheduler.ReceivedBlockInfo /** * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]] @@ -38,8 +39,10 @@ import org.apache.spark.streaming.receiver.NetworkReceiver abstract class NetworkInputDStream[T: ClassTag](@transient ssc_ : StreamingContext) extends InputDStream[T](ssc_) { - // This is an unique identifier that is used to match the network receiver with the - // corresponding network input stream. + /** Keeps all received blocks information */ + private lazy val receivedBlockInfo = new HashMap[Time, Array[ReceivedBlockInfo]] + + /** This is an unique identifier for the network input stream. */ val id = ssc.getNewNetworkStreamId() /** @@ -54,20 +57,38 @@ abstract class NetworkInputDStream[T: ClassTag](@transient ssc_ : StreamingConte def stop() {} + /** Ask NetworkInputTracker for received data blocks and generates RDDs with them. */ override def compute(validTime: Time): Option[RDD[T]] = { // If this is called for any time before the start time of the context, // then this returns an empty RDD. This may happen when recovering from a // master failure if (validTime >= graph.startTime) { - val blockIds = ssc.scheduler.networkInputTracker.getBlockIds(id, validTime) + val blockInfo = ssc.scheduler.networkInputTracker.getReceivedBlockInfo(id) + receivedBlockInfo(validTime) = blockInfo + val blockIds = blockInfo.map(_.blockId.asInstanceOf[BlockId]) Some(new BlockRDD[T](ssc.sc, blockIds)) } else { Some(new BlockRDD[T](ssc.sc, Array[BlockId]())) } } -} - - + /** Get information on received blocks. */ + private[streaming] def getReceivedBlockInfo(time: Time) = { + receivedBlockInfo(time) + } + /** + * Clear metadata that are older than `rememberDuration` of this DStream. + * This is an internal method that should not be called directly. This + * implementation overrides the default implementation to clear received + * block information. + */ + private[streaming] override def clearMetadata(time: Time) { + super.clearMetadata(time) + val oldReceivedBlocks = receivedBlockInfo.filter(_._1 <= (time - rememberDuration)) + receivedBlockInfo --= oldReceivedBlocks.keys + logDebug("Cleared " + oldReceivedBlocks.size + " RDDs that were older than " + + (time - rememberDuration) + ": " + oldReceivedBlocks.keys.mkString(", ")) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index 2473496949360..354bc132dcdc0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -51,7 +51,7 @@ class PairDStreamFunctions[K: ClassTag, V: ClassTag](self: DStream[(K,V)]) * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to * generate the RDDs with Spark's default number of partitions. */ - def groupByKey(): DStream[(K, Seq[V])] = { + def groupByKey(): DStream[(K, Iterable[V])] = { groupByKey(defaultPartitioner()) } @@ -59,7 +59,7 @@ class PairDStreamFunctions[K: ClassTag, V: ClassTag](self: DStream[(K,V)]) * Return a new DStream by applying `groupByKey` to each RDD. Hash partitioning is used to * generate the RDDs with `numPartitions` partitions. */ - def groupByKey(numPartitions: Int): DStream[(K, Seq[V])] = { + def groupByKey(numPartitions: Int): DStream[(K, Iterable[V])] = { groupByKey(defaultPartitioner(numPartitions)) } @@ -67,12 +67,12 @@ class PairDStreamFunctions[K: ClassTag, V: ClassTag](self: DStream[(K,V)]) * Return a new DStream by applying `groupByKey` on each RDD. The supplied * org.apache.spark.Partitioner is used to control the partitioning of each RDD. */ - def groupByKey(partitioner: Partitioner): DStream[(K, Seq[V])] = { + def groupByKey(partitioner: Partitioner): DStream[(K, Iterable[V])] = { val createCombiner = (v: V) => ArrayBuffer[V](v) val mergeValue = (c: ArrayBuffer[V], v: V) => (c += v) val mergeCombiner = (c1: ArrayBuffer[V], c2: ArrayBuffer[V]) => (c1 ++ c2) combineByKey(createCombiner, mergeValue, mergeCombiner, partitioner) - .asInstanceOf[DStream[(K, Seq[V])]] + .asInstanceOf[DStream[(K, Iterable[V])]] } /** @@ -126,7 +126,7 @@ class PairDStreamFunctions[K: ClassTag, V: ClassTag](self: DStream[(K,V)]) * @param windowDuration width of the window; must be a multiple of this DStream's * batching interval */ - def groupByKeyAndWindow(windowDuration: Duration): DStream[(K, Seq[V])] = { + def groupByKeyAndWindow(windowDuration: Duration): DStream[(K, Iterable[V])] = { groupByKeyAndWindow(windowDuration, self.slideDuration, defaultPartitioner()) } @@ -140,7 +140,8 @@ class PairDStreamFunctions[K: ClassTag, V: ClassTag](self: DStream[(K,V)]) * the new DStream will generate RDDs); must be a multiple of this * DStream's batching interval */ - def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration): DStream[(K, Seq[V])] = + def groupByKeyAndWindow(windowDuration: Duration, slideDuration: Duration) + : DStream[(K, Iterable[V])] = { groupByKeyAndWindow(windowDuration, slideDuration, defaultPartitioner()) } @@ -161,7 +162,7 @@ class PairDStreamFunctions[K: ClassTag, V: ClassTag](self: DStream[(K,V)]) windowDuration: Duration, slideDuration: Duration, numPartitions: Int - ): DStream[(K, Seq[V])] = { + ): DStream[(K, Iterable[V])] = { groupByKeyAndWindow(windowDuration, slideDuration, defaultPartitioner(numPartitions)) } @@ -180,14 +181,14 @@ class PairDStreamFunctions[K: ClassTag, V: ClassTag](self: DStream[(K,V)]) windowDuration: Duration, slideDuration: Duration, partitioner: Partitioner - ): DStream[(K, Seq[V])] = { - val createCombiner = (v: Seq[V]) => new ArrayBuffer[V] ++= v - val mergeValue = (buf: ArrayBuffer[V], v: Seq[V]) => buf ++= v + ): DStream[(K, Iterable[V])] = { + val createCombiner = (v: Iterable[V]) => new ArrayBuffer[V] ++= v + val mergeValue = (buf: ArrayBuffer[V], v: Iterable[V]) => buf ++= v val mergeCombiner = (buf1: ArrayBuffer[V], buf2: ArrayBuffer[V]) => buf1 ++= buf2 self.groupByKey(partitioner) .window(windowDuration, slideDuration) .combineByKey[ArrayBuffer[V]](createCombiner, mergeValue, mergeCombiner, partitioner) - .asInstanceOf[DStream[(K, Seq[V])]] + .asInstanceOf[DStream[(K, Iterable[V])]] } /** @@ -438,7 +439,7 @@ class PairDStreamFunctions[K: ClassTag, V: ClassTag](self: DStream[(K,V)]) * Hash partitioning is used to generate the RDDs with Spark's default number * of partitions. */ - def cogroup[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (Seq[V], Seq[W]))] = { + def cogroup[W: ClassTag](other: DStream[(K, W)]): DStream[(K, (Iterable[V], Iterable[W]))] = { cogroup(other, defaultPartitioner()) } @@ -447,7 +448,7 @@ class PairDStreamFunctions[K: ClassTag, V: ClassTag](self: DStream[(K,V)]) * Hash partitioning is used to generate the RDDs with `numPartitions` partitions. */ def cogroup[W: ClassTag](other: DStream[(K, W)], numPartitions: Int) - : DStream[(K, (Seq[V], Seq[W]))] = { + : DStream[(K, (Iterable[V], Iterable[W]))] = { cogroup(other, defaultPartitioner(numPartitions)) } @@ -458,7 +459,7 @@ class PairDStreamFunctions[K: ClassTag, V: ClassTag](self: DStream[(K,V)]) def cogroup[W: ClassTag]( other: DStream[(K, W)], partitioner: Partitioner - ): DStream[(K, (Seq[V], Seq[W]))] = { + ): DStream[(K, (Iterable[V], Iterable[W]))] = { self.transformWith( other, (rdd1: RDD[(K, V)], rdd2: RDD[(K, W)]) => rdd1.cogroup(rdd2, partitioner) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala index 97325f8ea3117..6376cff78b78a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala @@ -31,11 +31,11 @@ class QueueInputDStream[T: ClassTag]( oneAtATime: Boolean, defaultRDD: RDD[T] ) extends InputDStream[T](ssc) { - + override def start() { } - + override def stop() { } - + override def compute(validTime: Time): Option[RDD[T]] = { val buffer = new ArrayBuffer[RDD[T]]() if (oneAtATime && queue.size > 0) { @@ -55,5 +55,5 @@ class QueueInputDStream[T: ClassTag]( None } } - + } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala index 701e4920ec9cc..731cb84cd45ad 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -65,7 +65,6 @@ class SocketReceiver[T: ClassTag]( def onStop() { if (socket != null) socket.close() } - } private[streaming] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index 5f7d3ba26c656..7e22268767de7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -56,9 +56,14 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( // first map the cogrouped tuple to tuples of required type, // and then apply the update function val updateFuncLocal = updateFunc - val finalFunc = (iterator: Iterator[(K, (Seq[V], Seq[S]))]) => { + val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => { val i = iterator.map(t => { - (t._1, t._2._1, t._2._2.headOption) + val itr = t._2._2.iterator + val headOption = itr.hasNext match { + case true => Some(itr.next()) + case false => None + } + (t._1, t._2._1.toSeq, headOption) }) updateFuncLocal(i) } @@ -90,8 +95,8 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( // first map the grouped tuple to tuples of required type, // and then apply the update function val updateFuncLocal = updateFunc - val finalFunc = (iterator: Iterator[(K, Seq[V])]) => { - updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2, None))) + val finalFunc = (iterator: Iterator[(K, Iterable[V])]) => { + updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2.toSeq, None))) } val groupedRDD = parentRDD.groupByKey(partitioner) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala index 24289b714f99e..775b6bfd065c0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala @@ -32,7 +32,7 @@ class WindowedDStream[T: ClassTag]( extends DStream[T](parent.ssc) { if (!_windowDuration.isMultipleOf(parent.slideDuration)) { - throw new Exception("The window duration of windowed DStream (" + _slideDuration + ") " + + throw new Exception("The window duration of windowed DStream (" + _windowDuration + ") " + "must be a multiple of the slide duration of parent DStream (" + parent.slideDuration + ")") } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index 5157e20927533..661c11f8de53c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -48,14 +48,16 @@ private[streaming] class BlockGenerator( private case class Block(id: StreamBlockId, buffer: ArrayBuffer[Any]) + private val clock = new SystemClock() private val blockInterval = conf.getLong("spark.streaming.blockInterval", 200) private val blockIntervalTimer = - new RecurringTimer(new SystemClock(), blockInterval, updateCurrentBuffer) + new RecurringTimer(clock, blockInterval, updateCurrentBuffer, + "BlockGenerator") private val blocksForPushing = new ArrayBlockingQueue[Block](10) private val blockPushingThread = new Thread() { override def run() { keepPushingBlocks() } } - private var currentBuffer = new ArrayBuffer[Any] - private var stopped = false + @volatile private var currentBuffer = new ArrayBuffer[Any] + @volatile private var stopped = false /** Start block generating and pushing threads. */ def start() { @@ -66,21 +68,15 @@ private[streaming] class BlockGenerator( /** Stop all threads. */ def stop() { - // Stop generating blocks - blockIntervalTimer.stop() - - // Mark as stopped - synchronized { stopped = true } - - // Wait for all blocks to be pushed - logDebug("Waiting for block pushing thread to terminate") + blockIntervalTimer.stop(false) + stopped = true blockPushingThread.join() logInfo("Stopped BlockGenerator") } /** * Push a single data item into the buffer. All received data items - * will be periodically coallesced into blocks and pushed into BlockManager. + * will be periodically pushed into BlockManager. */ def += (data: Any): Unit = synchronized { currentBuffer += data @@ -108,9 +104,8 @@ private[streaming] class BlockGenerator( /** Keep pushing blocks to the BlockManager. */ private def keepPushingBlocks() { logInfo("Started block pushing thread") - try { - while(!isStopped) { + while(!stopped) { Option(blocksForPushing.poll(100, TimeUnit.MILLISECONDS)) match { case Some(block) => pushBlock(block) case None => @@ -142,6 +137,4 @@ private[streaming] class BlockGenerator( listener.onPushBlock(block.id, block.buffer) logInfo("Pushed block " + block.id) } - - private def isStopped = synchronized { stopped } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutorImpl.scala index 5ac28405462f4..173eb88276684 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/NetworkReceiverExecutorImpl.scala @@ -28,8 +28,13 @@ import akka.pattern.ask import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.storage.StreamBlockId -import org.apache.spark.streaming.scheduler.{AddBlocks, DeregisterReceiver, RegisterReceiver} -import org.apache.spark.util.AkkaUtils +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.storage.StreamBlockId +import org.apache.spark.streaming.scheduler.DeregisterReceiver +import org.apache.spark.streaming.scheduler.AddBlock +import scala.Some +import org.apache.spark.streaming.scheduler.RegisterReceiver /** * Concrete implementation of [[org.apache.spark.streaming.receiver.NetworkReceiverExecutor]] @@ -62,7 +67,9 @@ private[streaming] class NetworkReceiverExecutorImpl( Props(new Actor { override def preStart() { logInfo("Registered receiver " + receiverId) - val future = trackerActor.ask(RegisterReceiver(receiverId, self))(askTimeout) + val msg = RegisterReceiver( + receiverId, receiver.getClass.getSimpleName, Utils.localHostName(), self) + val future = trackerActor.ask(msg)(askTimeout) Await.result(future, askTimeout) } @@ -71,7 +78,7 @@ private[streaming] class NetworkReceiverExecutorImpl( logInfo("Received stop signal") stop() } - }), "NetworkReceiver-" + receiverId) + }), "NetworkReceiver-" + receiverId + "-" + System.currentTimeMillis()) /** Unique block ids if one wants to add blocks directly */ private val newBlockId = new AtomicLong(System.currentTimeMillis()) @@ -95,7 +102,7 @@ private[streaming] class NetworkReceiverExecutorImpl( blockGenerator += (data) } - /** Push a block of received data into block generator. */ + /** Push a block of received data as an ArrayBuffer into block generator. */ def pushArrayBuffer( arrayBuffer: ArrayBuffer[_], optionalMetadata: Option[Any], @@ -106,10 +113,10 @@ private[streaming] class NetworkReceiverExecutorImpl( blockManager.put(blockId, arrayBuffer.asInstanceOf[ArrayBuffer[Any]], storageLevel, tellMaster = true) logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms") - reportPushedBlock(blockId, optionalMetadata) + reportPushedBlock(blockId, arrayBuffer.size, optionalMetadata) } - /** Push a block of received data into block generator. */ + /** Push a block of received data as an iterator into block generator. */ def pushIterator( iterator: Iterator[_], optionalMetadata: Option[Any], @@ -119,10 +126,10 @@ private[streaming] class NetworkReceiverExecutorImpl( val time = System.currentTimeMillis blockManager.put(blockId, iterator, storageLevel, tellMaster = true) logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms") - reportPushedBlock(blockId, optionalMetadata) + reportPushedBlock(blockId, -1, optionalMetadata) } - /** Push a block (as bytes) into the block generator. */ + /** Push a block of received data as bytes into the block generator. */ def pushBytes( bytes: ByteBuffer, optionalMetadata: Option[Any], @@ -132,12 +139,12 @@ private[streaming] class NetworkReceiverExecutorImpl( val time = System.currentTimeMillis blockManager.putBytes(blockId, bytes, storageLevel, tellMaster = true) logDebug("Pushed block " + blockId + " in " + (System.currentTimeMillis - time) + " ms") - reportPushedBlock(blockId, optionalMetadata) + reportPushedBlock(blockId, -1, optionalMetadata) } /** Report pushed block */ - def reportPushedBlock(blockId: StreamBlockId, optionalMetadata: Option[Any]) { - trackerActor ! AddBlocks(receiverId, Array(blockId), optionalMetadata.orNull) + def reportPushedBlock(blockId: StreamBlockId, numRecords: Long, optionalMetadata: Option[Any]) { + trackerActor ! AddBlock(ReceivedBlockInfo(receiverId, blockId, numRecords, optionalMetadata.orNull)) logDebug("Reported block " + blockId) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala index 66c736e114372..82b06b880644b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receivers/ActorReceiver.scala @@ -44,7 +44,7 @@ object ReceiverSupervisorStrategy { * the API for pushing received data into Spark Streaming for being processed. * * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html - * + * * @example {{{ * class MyActor extends Actor with Receiver{ * def receive { @@ -165,10 +165,10 @@ private[streaming] class ActorReceiver[T: ClassTag]( def onStart() = { supervisor logInfo("Supervision tree for receivers initialized at:" + supervisor.path) + } def onStop() = { supervisor ! PoisonPill } - } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala index 7f3cd2f8eb1fd..9c69a2a4e21f5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/BatchInfo.scala @@ -29,6 +29,7 @@ import org.apache.spark.streaming.Time */ case class BatchInfo( batchTime: Time, + receivedBlockInfo: Map[Int, Array[ReceivedBlockInfo]], submissionTime: Long, processingStartTime: Option[Long], processingEndTime: Option[Long] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index c7306248b1950..e564eccba2df5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -39,16 +39,22 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { private val ssc = jobScheduler.ssc private val graph = ssc.graph + val clock = { val clockClass = ssc.sc.conf.get( "spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock") Class.forName(clockClass).newInstance().asInstanceOf[Clock] } + private val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds, - longTime => eventActor ! GenerateJobs(new Time(longTime))) - private lazy val checkpointWriter = - if (ssc.checkpointDuration != null && ssc.checkpointDir != null) { - new CheckpointWriter(this, ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration) + longTime => eventActor ! GenerateJobs(new Time(longTime)), "JobGenerator") + + // This is marked lazy so that this is initialized after checkpoint duration has been set + // in the context and the generator has been started. + private lazy val shouldCheckpoint = ssc.checkpointDuration != null && ssc.checkpointDir != null + + private lazy val checkpointWriter = if (shouldCheckpoint) { + new CheckpointWriter(this, ssc.conf, ssc.checkpointDir, ssc.sparkContext.hadoopConfiguration) } else { null } @@ -57,17 +63,16 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { // This not being null means the scheduler has been started and not stopped private var eventActor: ActorRef = null + // last batch whose completion,checkpointing and metadata cleanup has been completed + private var lastProcessedBatch: Time = null + /** Start generation of jobs */ - def start() = synchronized { - if (eventActor != null) { - throw new SparkException("JobGenerator already started") - } + def start(): Unit = synchronized { + if (eventActor != null) return // generator has already been started eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { def receive = { - case event: JobGeneratorEvent => - logDebug("Got event of type " + event.getClass.getName) - processEvent(event) + case event: JobGeneratorEvent => processEvent(event) } }), "JobGenerator") if (ssc.isCheckpointPresent) { @@ -77,30 +82,79 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { } } - /** Stop generation of jobs */ - def stop() = synchronized { - if (eventActor != null) { - timer.stop() - ssc.env.actorSystem.stop(eventActor) - if (checkpointWriter != null) checkpointWriter.stop() - ssc.graph.stop() - logInfo("JobGenerator stopped") + /** + * Stop generation of jobs. processReceivedData = true makes this wait until jobs + * of current ongoing time interval has been generated, processed and corresponding + * checkpoints written. + */ + def stop(processReceivedData: Boolean): Unit = synchronized { + if (eventActor == null) return // generator has already been stopped + + if (processReceivedData) { + logInfo("Stopping JobGenerator gracefully") + val timeWhenStopStarted = System.currentTimeMillis() + val stopTimeout = 10 * ssc.graph.batchDuration.milliseconds + val pollTime = 100 + + // To prevent graceful stop to get stuck permanently + def hasTimedOut = { + val timedOut = System.currentTimeMillis() - timeWhenStopStarted > stopTimeout + if (timedOut) logWarning("Timed out while stopping the job generator") + timedOut + } + + // Wait until all the received blocks in the network input tracker has + // been consumed by network input DStreams, and jobs have been generated with them + logInfo("Waiting for all received blocks to be consumed for job generation") + while(!hasTimedOut && jobScheduler.networkInputTracker.hasMoreReceivedBlockIds) { + Thread.sleep(pollTime) + } + logInfo("Waited for all received blocks to be consumed for job generation") + + // Stop generating jobs + val stopTime = timer.stop(false) + graph.stop() + logInfo("Stopped generation timer") + + // Wait for the jobs to complete and checkpoints to be written + def haveAllBatchesBeenProcessed = { + lastProcessedBatch != null && lastProcessedBatch.milliseconds == stopTime + } + logInfo("Waiting for jobs to be processed and checkpoints to be written") + while (!hasTimedOut && !haveAllBatchesBeenProcessed) { + Thread.sleep(pollTime) + } + logInfo("Waited for jobs to be processed and checkpoints to be written") + } else { + logInfo("Stopping JobGenerator immediately") + // Stop timer and graph immediately, ignore unprocessed data and pending jobs + timer.stop(true) + graph.stop() } + + // Stop the actor and checkpoint writer + if (shouldCheckpoint) checkpointWriter.stop() + ssc.env.actorSystem.stop(eventActor) + logInfo("Stopped JobGenerator") } /** - * On batch completion, clear old metadata and checkpoint computation. + * Callback called when a batch has been completely processed. */ def onBatchCompletion(time: Time) { eventActor ! ClearMetadata(time) } - + + /** + * Callback called when the checkpoint of a batch has been written. + */ def onCheckpointCompletion(time: Time) { eventActor ! ClearCheckpointData(time) } /** Processes all events */ private def processEvent(event: JobGeneratorEvent) { + logDebug("Got event " + event) event match { case GenerateJobs(time) => generateJobs(time) case ClearMetadata(time) => clearMetadata(time) @@ -114,7 +168,7 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { val startTime = new Time(timer.getStartTime()) graph.start(startTime - graph.batchDuration) timer.start(startTime.milliseconds) - logInfo("JobGenerator started at " + startTime) + logInfo("Started JobGenerator at " + startTime) } /** Restarts the generator based on the information in checkpoint */ @@ -147,20 +201,27 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " + timesToReschedule.mkString(", ")) timesToReschedule.foreach(time => - jobScheduler.runJobs(time, graph.generateJobs(time)) + jobScheduler.submitJobSet(JobSet(time, graph.generateJobs(time))) ) // Restart the timer timer.start(restartTime.milliseconds) - logInfo("JobGenerator restarted at " + restartTime) + logInfo("Restarted JobGenerator at " + restartTime) } /** Generate jobs and perform checkpoint for the given `time`. */ private def generateJobs(time: Time) { SparkEnv.set(ssc.env) Try(graph.generateJobs(time)) match { - case Success(jobs) => jobScheduler.runJobs(time, jobs) - case Failure(e) => jobScheduler.reportError("Error generating jobs for time " + time, e) + case Success(jobs) => + val receivedBlockInfo = graph.getNetworkInputStreams.map { stream => + val streamId = stream.id + val receivedBlockInfo = stream.getReceivedBlockInfo(time) + (streamId, receivedBlockInfo) + }.toMap + jobScheduler.submitJobSet(JobSet(time, jobs, receivedBlockInfo)) + case Failure(e) => + jobScheduler.reportError("Error generating jobs for time " + time, e) } eventActor ! DoCheckpoint(time) } @@ -168,20 +229,32 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging { /** Clear DStream metadata for the given `time`. */ private def clearMetadata(time: Time) { ssc.graph.clearMetadata(time) - eventActor ! DoCheckpoint(time) + + // If checkpointing is enabled, then checkpoint, + // else mark batch to be fully processed + if (shouldCheckpoint) { + eventActor ! DoCheckpoint(time) + } else { + markBatchFullyProcessed(time) + } } /** Clear DStream checkpoint data for the given `time`. */ private def clearCheckpointData(time: Time) { ssc.graph.clearCheckpointData(time) + markBatchFullyProcessed(time) } /** Perform checkpoint for the give `time`. */ - private def doCheckpoint(time: Time) = synchronized { - if (checkpointWriter != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) { + private def doCheckpoint(time: Time) { + if (shouldCheckpoint && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) { logInfo("Checkpointing graph for time " + time) ssc.graph.updateCheckpointData(time) checkpointWriter.write(new Checkpoint(ssc, time)) } } + + private def markBatchFullyProcessed(time: Time) { + lastProcessedBatch = time + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala index de675d3c7fb94..d9ada99b472ac 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobScheduler.scala @@ -39,7 +39,7 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { private val jobSets = new ConcurrentHashMap[Time, JobSet] private val numConcurrentJobs = ssc.conf.getInt("spark.streaming.concurrentJobs", 1) - private val executor = Executors.newFixedThreadPool(numConcurrentJobs) + private val jobExecutor = Executors.newFixedThreadPool(numConcurrentJobs) private val jobGenerator = new JobGenerator(this) val clock = jobGenerator.clock val listenerBus = new StreamingListenerBus() @@ -50,46 +50,63 @@ class JobScheduler(val ssc: StreamingContext) extends Logging { private var eventActor: ActorRef = null - def start() = synchronized { - if (eventActor != null) { - throw new SparkException("JobScheduler already started") - } + def start(): Unit = synchronized { + if (eventActor != null) return // scheduler has already been started + logDebug("Starting JobScheduler") eventActor = ssc.env.actorSystem.actorOf(Props(new Actor { def receive = { case event: JobSchedulerEvent => processEvent(event) } }), "JobScheduler") + listenerBus.start() networkInputTracker = new NetworkInputTracker(ssc) networkInputTracker.start() - Thread.sleep(1000) jobGenerator.start() - logInfo("JobScheduler started") + logInfo("Started JobScheduler") } - def stop() = synchronized { - if (eventActor != null) { - jobGenerator.stop() - networkInputTracker.stop() - executor.shutdown() - if (!executor.awaitTermination(2, TimeUnit.SECONDS)) { - executor.shutdownNow() - } - listenerBus.stop() - ssc.env.actorSystem.stop(eventActor) - logInfo("JobScheduler stopped") + def stop(processAllReceivedData: Boolean): Unit = synchronized { + if (eventActor == null) return // scheduler has already been stopped + logDebug("Stopping JobScheduler") + + // First, stop receiving + networkInputTracker.stop() + + // Second, stop generating jobs. If it has to process all received data, + // then this will wait for all the processing through JobScheduler to be over. + jobGenerator.stop(processAllReceivedData) + + // Stop the executor for receiving new jobs + logDebug("Stopping job executor") + jobExecutor.shutdown() + + // Wait for the queued jobs to complete if indicated + val terminated = if (processAllReceivedData) { + jobExecutor.awaitTermination(1, TimeUnit.HOURS) // just a very large period of time + } else { + jobExecutor.awaitTermination(2, TimeUnit.SECONDS) } + if (!terminated) { + jobExecutor.shutdownNow() + } + logDebug("Stopped job executor") + + // Stop everything else + listenerBus.stop() + ssc.env.actorSystem.stop(eventActor) + eventActor = null + logInfo("Stopped JobScheduler") } - def runJobs(time: Time, jobs: Seq[Job]) { - if (jobs.isEmpty) { - logInfo("No jobs added for time " + time) + def submitJobSet(jobSet: JobSet) { + if (jobSet.jobs.isEmpty) { + logInfo("No jobs added for time " + jobSet.time) } else { - val jobSet = new JobSet(time, jobs) - jobSets.put(time, jobSet) - jobSet.jobs.foreach(job => executor.execute(new JobHandler(job))) - logInfo("Added jobs for time " + time) + jobSets.put(jobSet.time, jobSet) + jobSet.jobs.foreach(job => jobExecutor.execute(new JobHandler(job))) + logInfo("Added jobs for time " + jobSet.time) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index fcf303aee6cd7..a69d74362173e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -24,7 +24,11 @@ import org.apache.spark.streaming.Time * belong to the same batch. */ private[streaming] -case class JobSet(time: Time, jobs: Seq[Job]) { +case class JobSet( + time: Time, + jobs: Seq[Job], + receivedBlockInfo: Map[Int, Array[ReceivedBlockInfo]] = Map.empty + ) { private val incompleteJobs = new HashSet[Job]() private val submissionTime = System.currentTimeMillis() // when this jobset was submitted @@ -60,6 +64,7 @@ case class JobSet(time: Time, jobs: Seq[Job]) { def toBatchInfo: BatchInfo = { new BatchInfo( time, + receivedBlockInfo, submissionTime, if (processingStartTime >= 0 ) Some(processingStartTime) else None, if (processingEndTime >= 0 ) Some(processingEndTime) else None diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala index cb0021143381b..c80defb23f071 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/NetworkInputTracker.scala @@ -17,21 +17,41 @@ package org.apache.spark.streaming.scheduler -import scala.collection.mutable.{HashMap, Queue} +import scala.collection.mutable.{HashMap, SynchronizedMap, SynchronizedQueue} import akka.actor._ - import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.SparkContext._ -import org.apache.spark.storage.BlockId +import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.receiver.{NetworkReceiver, NetworkReceiverExecutorImpl, StopReceiver} import org.apache.spark.util.AkkaUtils +/** Information about receiver */ +case class ReceiverInfo(streamId: Int, typ: String, location: String) { + override def toString = s"$typ-$streamId" +} + +/** Information about blocks received by the network receiver */ +case class ReceivedBlockInfo( + streamId: Int, + blockId: StreamBlockId, + numRecords: Long, + metadata: Any + ) + +/** + * Messages used by the NetworkReceiver and the NetworkInputTracker to communicate + * with each other. + */ private[streaming] sealed trait NetworkInputTrackerMessage -private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) - extends NetworkInputTrackerMessage -private[streaming] case class AddBlocks(streamId: Int, blockIds: Seq[BlockId], metadata: Any) +private[streaming] case class RegisterReceiver( + streamId: Int, + typ: String, + host: String, + receiverActor: ActorRef + ) extends NetworkInputTrackerMessage +private[streaming] case class AddBlock(receivedBlockInfo: ReceivedBlockInfo) extends NetworkInputTrackerMessage private[streaming] case class DeregisterReceiver(streamId: Int, msg: String) extends NetworkInputTrackerMessage @@ -47,10 +67,11 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging { val networkInputStreams = ssc.graph.getNetworkInputStreams() val networkInputStreamMap = Map(networkInputStreams.map(x => (x.id, x)): _*) val receiverExecutor = new ReceiverExecutor() - val receiverInfo = new HashMap[Int, ActorRef] - val receivedBlockIds = new HashMap[Int, Queue[BlockId]] + val receiverInfo = new HashMap[Int, ActorRef] with SynchronizedMap[Int, ActorRef] + val receivedBlockInfo = new HashMap[Int, SynchronizedQueue[ReceivedBlockInfo]] + with SynchronizedMap[Int, SynchronizedQueue[ReceivedBlockInfo]] val timeout = AkkaUtils.askTimeout(ssc.conf) - + val listenerBus = ssc.scheduler.listenerBus // actor is created when generator starts. // This not being null means the tracker has been started and not stopped @@ -58,7 +79,7 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging { var currentTime: Time = null /** Start the actor and receiver execution thread. */ - def start() { + def start() = synchronized { if (actor != null) { throw new SparkException("NetworkInputTracker already started") } @@ -72,72 +93,110 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging { } /** Stop the receiver execution thread. */ - def stop() { + def stop() = synchronized { if (!networkInputStreams.isEmpty && actor != null) { - receiverExecutor.interrupt() - receiverExecutor.stopReceivers() + // First, stop the receivers + receiverExecutor.stop() + + // Finally, stop the actor ssc.env.actorSystem.stop(actor) + actor = null logInfo("NetworkInputTracker stopped") } } /** Return all the blocks received from a receiver. */ - def getBlockIds(receiverId: Int, time: Time): Array[BlockId] = synchronized { - val queue = receivedBlockIds.synchronized { - receivedBlockIds.getOrElse(receiverId, new Queue[BlockId]()) - } - val result = queue.synchronized { - queue.dequeueAll(x => true) + def getReceivedBlockInfo(streamId: Int): Array[ReceivedBlockInfo] = { + val receivedBlockInfo = getReceivedBlockInfoQueue(streamId).dequeueAll(x => true) + logInfo("Stream " + streamId + " received " + receivedBlockInfo.size + " blocks") + receivedBlockInfo.toArray + } + + private def getReceivedBlockInfoQueue(streamId: Int) = { + receivedBlockInfo.getOrElseUpdate(streamId, new SynchronizedQueue[ReceivedBlockInfo]) + } + + /** Register a receiver */ + def registerReceiver( + streamId: Int, + typ: String, + host: String, + receiverActor: ActorRef, + sender: ActorRef + ) { + if (!networkInputStreamMap.contains(streamId)) { + throw new Exception("Register received for unexpected id " + streamId) } - logInfo("Stream " + receiverId + " received " + result.size + " blocks") - result.toArray + receiverInfo += ((streamId, receiverActor)) + ssc.scheduler.listenerBus.post(StreamingListenerReceiverStarted( + ReceiverInfo(streamId, typ, host) + )) + logInfo("Registered receiver for network stream " + streamId + " from " + sender.path.address) + } + + /** Deregister a receiver */ + def deregisterReceiver(streamId: Int, message: String) { + receiverInfo -= streamId + logError("Deregistered receiver for network stream " + streamId + " with message:\n" + message) + } + + /** Add new blocks for the given stream */ + def addBlocks(receivedBlockInfo: ReceivedBlockInfo) { + getReceivedBlockInfoQueue(receivedBlockInfo.streamId) += receivedBlockInfo + logDebug("Stream " + receivedBlockInfo.streamId + " received new blocks: " + + receivedBlockInfo.blockId) + } + + /** Check if any blocks are left to be processed */ + def hasMoreReceivedBlockIds: Boolean = { + !receivedBlockInfo.values.forall(_.isEmpty) } /** Actor to receive messages from the receivers. */ private class NetworkInputTrackerActor extends Actor { def receive = { - case RegisterReceiver(streamId, receiverActor) => { - if (!networkInputStreamMap.contains(streamId)) { - throw new Exception("Register received for unexpected id " + streamId) - } - receiverInfo += ((streamId, receiverActor)) - logInfo("Registered receiver for network stream " + streamId + " from " - + sender.path.address) + case RegisterReceiver(streamId, typ, host, receiverActor) => + registerReceiver(streamId, typ, host, receiverActor, sender) + sender ! true + case AddBlock(receivedBlockInfo) => + addBlocks(receivedBlockInfo) + case DeregisterReceiver(streamId, message) => + deregisterReceiver(streamId, message) sender ! true - } - case AddBlocks(streamId, blockIds, metadata) => { - val tmp = receivedBlockIds.synchronized { - if (!receivedBlockIds.contains(streamId)) { - receivedBlockIds += ((streamId, new Queue[BlockId])) - } - receivedBlockIds(streamId) - } - tmp.synchronized { - tmp ++= blockIds - } - networkInputStreamMap(streamId).addMetadata(metadata) - } - case DeregisterReceiver(streamId, msg) => { - receiverInfo -= streamId - logError("De-registered receiver for network stream " + streamId - + " with message " + msg) - // TODO: Do something about the corresponding NetworkInputDStream - } } } /** This thread class runs all the receivers on the cluster. */ - class ReceiverExecutor extends Thread { - val env = ssc.env - - override def run() { - try { - SparkEnv.set(env) - startReceivers() - } catch { - case ie: InterruptedException => logInfo("ReceiverExecutor interrupted") - } finally { - stopReceivers() + class ReceiverExecutor { + @transient val env = ssc.env + @transient val thread = new Thread() { + override def run() { + try { + SparkEnv.set(env) + startReceivers() + } catch { + case ie: InterruptedException => logInfo("ReceiverExecutor interrupted") + } + } + } + + def start() { + thread.start() + } + + def stop() { + // Send the stop signal to all the receivers + stopReceivers() + + // Wait for the Spark job that runs the receivers to be over + // That is, for the receivers to quit gracefully. + thread.join(10000) + + // Check if all the receivers have been deregistered or not + if (!receiverInfo.isEmpty) { + logWarning("All of the receivers have not deregistered, " + receiverInfo) + } else { + logInfo("All of the receivers have deregistered successfully") } } @@ -145,7 +204,7 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging { * Get the receivers from the NetworkInputDStreams, distributes them to the * worker nodes as a parallel collection, and runs them. */ - def startReceivers() { + private def startReceivers() { val receivers = networkInputStreams.map(nis => { val rcvr = nis.getReceiver() rcvr.setReceiverId(nis.id) @@ -182,13 +241,16 @@ class NetworkInputTracker(ssc: StreamingContext) extends Logging { } // Distribute the receivers and start them + logInfo("Starting " + receivers.length + " receivers") ssc.sparkContext.runJob(tempRDD, startReceiver) + logInfo("All of the receivers have been terminated") } /** Stops the receivers. */ - def stopReceivers() { + private def stopReceivers() { // Signal the receivers to stop receiverInfo.values.foreach(_ ! StopReceiver) + logInfo("Sent stop signal to all " + receiverInfo.size + " receivers") } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala index 461ea3506477f..5db40ebbeb1de 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala @@ -23,8 +23,11 @@ import org.apache.spark.util.Distribution /** Base trait for events related to StreamingListener */ sealed trait StreamingListenerEvent +case class StreamingListenerBatchSubmitted(batchInfo: BatchInfo) extends StreamingListenerEvent case class StreamingListenerBatchCompleted(batchInfo: BatchInfo) extends StreamingListenerEvent case class StreamingListenerBatchStarted(batchInfo: BatchInfo) extends StreamingListenerEvent +case class StreamingListenerReceiverStarted(receiverInfo: ReceiverInfo) + extends StreamingListenerEvent /** An event used in the listener to shutdown the listener daemon thread. */ private[scheduler] case object StreamingListenerShutdown extends StreamingListenerEvent @@ -34,14 +37,17 @@ private[scheduler] case object StreamingListenerShutdown extends StreamingListen * computation. */ trait StreamingListener { - /** - * Called when processing of a batch has completed - */ + + /** Called when a receiver has been started */ + def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) { } + + /** Called when a batch of jobs has been submitted for processing. */ + def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted) { } + + /** Called when processing of a batch of jobs has completed. */ def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) { } - /** - * Called when processing of a batch has started - */ + /** Called when processing of a batch of jobs has started. */ def onBatchStarted(batchStarted: StreamingListenerBatchStarted) { } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala index 18811fc2b01d8..ea03dfc7bfeea 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala @@ -38,6 +38,10 @@ private[spark] class StreamingListenerBus() extends Logging { while (true) { val event = eventQueue.take event match { + case receiverStarted: StreamingListenerReceiverStarted => + listeners.foreach(_.onReceiverStarted(receiverStarted)) + case batchSubmitted: StreamingListenerBatchSubmitted => + listeners.foreach(_.onBatchSubmitted(batchSubmitted)) case batchStarted: StreamingListenerBatchStarted => listeners.foreach(_.onBatchStarted(batchStarted)) case batchCompleted: StreamingListenerBatchCompleted => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala new file mode 100644 index 0000000000000..8b025b09ed34d --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.ui + +import org.apache.spark.streaming.{Time, StreamingContext} +import org.apache.spark.streaming.scheduler._ +import scala.collection.mutable.{Queue, HashMap} +import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted +import org.apache.spark.streaming.scheduler.StreamingListenerBatchStarted +import org.apache.spark.streaming.scheduler.BatchInfo +import org.apache.spark.streaming.scheduler.ReceiverInfo +import org.apache.spark.streaming.scheduler.StreamingListenerBatchSubmitted +import org.apache.spark.util.Distribution + + +private[ui] class StreamingJobProgressListener(ssc: StreamingContext) extends StreamingListener { + + private val waitingBatchInfos = new HashMap[Time, BatchInfo] + private val runningBatchInfos = new HashMap[Time, BatchInfo] + private val completedaBatchInfos = new Queue[BatchInfo] + private val batchInfoLimit = ssc.conf.getInt("spark.streaming.ui.retainedBatches", 100) + private var totalCompletedBatches = 0L + private val receiverInfos = new HashMap[Int, ReceiverInfo] + + val batchDuration = ssc.graph.batchDuration.milliseconds + + override def onReceiverStarted(receiverStarted: StreamingListenerReceiverStarted) = { + synchronized { + receiverInfos.put(receiverStarted.receiverInfo.streamId, receiverStarted.receiverInfo) + } + } + + override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted) = synchronized { + runningBatchInfos(batchSubmitted.batchInfo.batchTime) = batchSubmitted.batchInfo + } + + override def onBatchStarted(batchStarted: StreamingListenerBatchStarted) = synchronized { + runningBatchInfos(batchStarted.batchInfo.batchTime) = batchStarted.batchInfo + waitingBatchInfos.remove(batchStarted.batchInfo.batchTime) + } + + override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) = synchronized { + waitingBatchInfos.remove(batchCompleted.batchInfo.batchTime) + runningBatchInfos.remove(batchCompleted.batchInfo.batchTime) + completedaBatchInfos.enqueue(batchCompleted.batchInfo) + if (completedaBatchInfos.size > batchInfoLimit) completedaBatchInfos.dequeue() + totalCompletedBatches += 1L + } + + def numNetworkReceivers = synchronized { + ssc.graph.getNetworkInputStreams().size + } + + def numTotalCompletedBatches: Long = synchronized { + totalCompletedBatches + } + + def numUnprocessedBatches: Long = synchronized { + waitingBatchInfos.size + runningBatchInfos.size + } + + def waitingBatches: Seq[BatchInfo] = synchronized { + waitingBatchInfos.values.toSeq + } + + def runningBatches: Seq[BatchInfo] = synchronized { + runningBatchInfos.values.toSeq + } + + def retainedCompletedBatches: Seq[BatchInfo] = synchronized { + completedaBatchInfos.toSeq + } + + def processingDelayDistribution: Option[Distribution] = synchronized { + extractDistribution(_.processingDelay) + } + + def schedulingDelayDistribution: Option[Distribution] = synchronized { + extractDistribution(_.schedulingDelay) + } + + def totalDelayDistribution: Option[Distribution] = synchronized { + extractDistribution(_.totalDelay) + } + + def receivedRecordsDistributions: Map[Int, Option[Distribution]] = synchronized { + val latestBatchInfos = retainedBatches.reverse.take(batchInfoLimit) + val latestBlockInfos = latestBatchInfos.map(_.receivedBlockInfo) + (0 until numNetworkReceivers).map { receiverId => + val blockInfoOfParticularReceiver = latestBlockInfos.map { batchInfo => + batchInfo.get(receiverId).getOrElse(Array.empty) + } + val recordsOfParticularReceiver = blockInfoOfParticularReceiver.map { blockInfo => + // calculate records per second for each batch + blockInfo.map(_.numRecords).sum.toDouble * 1000 / batchDuration + } + val distributionOption = Distribution(recordsOfParticularReceiver) + (receiverId, distributionOption) + }.toMap + } + + def lastReceivedBatchRecords: Map[Int, Long] = { + val lastReceivedBlockInfoOption = lastReceivedBatch.map(_.receivedBlockInfo) + lastReceivedBlockInfoOption.map { lastReceivedBlockInfo => + (0 until numNetworkReceivers).map { receiverId => + (receiverId, lastReceivedBlockInfo(receiverId).map(_.numRecords).sum) + }.toMap + }.getOrElse { + (0 until numNetworkReceivers).map(receiverId => (receiverId, 0L)).toMap + } + } + + def receiverInfo(receiverId: Int): Option[ReceiverInfo] = { + receiverInfos.get(receiverId) + } + + def lastCompletedBatch: Option[BatchInfo] = { + completedaBatchInfos.sortBy(_.batchTime)(Time.ordering).lastOption + } + + def lastReceivedBatch: Option[BatchInfo] = { + retainedBatches.lastOption + } + + private def retainedBatches: Seq[BatchInfo] = synchronized { + (waitingBatchInfos.values.toSeq ++ + runningBatchInfos.values.toSeq ++ completedaBatchInfos).sortBy(_.batchTime)(Time.ordering) + } + + private def extractDistribution(getMetric: BatchInfo => Option[Long]): Option[Distribution] = { + Distribution(completedaBatchInfos.flatMap(getMetric(_)).map(_.toDouble)) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala new file mode 100644 index 0000000000000..6607437db560a --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingPage.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.ui + +import java.util.Calendar +import javax.servlet.http.HttpServletRequest + +import scala.xml.Node + +import org.apache.spark.Logging +import org.apache.spark.ui._ +import org.apache.spark.ui.UIUtils._ +import org.apache.spark.util.Distribution + +/** Page for Spark Web UI that shows statistics of a streaming job */ +private[ui] class StreamingPage(parent: StreamingTab) + extends WebUIPage("") with Logging { + + private val listener = parent.listener + private val startTime = Calendar.getInstance().getTime() + private val emptyCell = "-" + + /** Render the page */ + def render(request: HttpServletRequest): Seq[Node] = { + val content = + generateBasicStats() ++

      ++ +

      Statistics over last {listener.retainedCompletedBatches.size} processed batches

      ++ + generateNetworkStatsTable() ++ + generateBatchStatsTable() + UIUtils.headerSparkPage( + content, parent.basePath, parent.appName, "Streaming", parent.headerTabs, parent, Some(5000)) + } + + /** Generate basic stats of the streaming program */ + private def generateBasicStats(): Seq[Node] = { + val timeSinceStart = System.currentTimeMillis() - startTime.getTime +
        +
      • + Started at: {startTime.toString} +
      • +
      • + Time since start: {formatDurationVerbose(timeSinceStart)} +
      • +
      • + Network receivers: {listener.numNetworkReceivers} +
      • +
      • + Batch interval: {formatDurationVerbose(listener.batchDuration)} +
      • +
      • + Processed batches: {listener.numTotalCompletedBatches} +
      • +
      • + Waiting batches: {listener.numUnprocessedBatches} +
      • +
      + } + + /** Generate stats of data received over the network the streaming program */ + private def generateNetworkStatsTable(): Seq[Node] = { + val receivedRecordDistributions = listener.receivedRecordsDistributions + val lastBatchReceivedRecord = listener.lastReceivedBatchRecords + val table = if (receivedRecordDistributions.size > 0) { + val headerRow = Seq( + "Receiver", + "Location", + "Records in last batch\n[" + formatDate(Calendar.getInstance().getTime()) + "]", + "Minimum rate\n[records/sec]", + "25th percentile rate\n[records/sec]", + "Median rate\n[records/sec]", + "75th percentile rate\n[records/sec]", + "Maximum rate\n[records/sec]" + ) + val dataRows = (0 until listener.numNetworkReceivers).map { receiverId => + val receiverInfo = listener.receiverInfo(receiverId) + val receiverName = receiverInfo.map(_.toString).getOrElse(s"Receiver-$receiverId") + val receiverLocation = receiverInfo.map(_.location).getOrElse(emptyCell) + val receiverLastBatchRecords = formatDurationVerbose(lastBatchReceivedRecord(receiverId)) + val receivedRecordStats = receivedRecordDistributions(receiverId).map { d => + d.getQuantiles().map(r => formatDurationVerbose(r.toLong)) + }.getOrElse { + Seq(emptyCell, emptyCell, emptyCell, emptyCell, emptyCell) + } + Seq(receiverName, receiverLocation, receiverLastBatchRecords) ++ receivedRecordStats + } + Some(listingTable(headerRow, dataRows)) + } else { + None + } + + val content = +
      Network Input Statistics
      ++ +
      {table.getOrElse("No network receivers")}
      + + content + } + + /** Generate stats of batch jobs of the streaming program */ + private def generateBatchStatsTable(): Seq[Node] = { + val numBatches = listener.retainedCompletedBatches.size + val lastCompletedBatch = listener.lastCompletedBatch + val table = if (numBatches > 0) { + val processingDelayQuantilesRow = { + Seq( + "Processing Time", + formatDurationOption(lastCompletedBatch.flatMap(_.processingDelay)) + ) ++ getQuantiles(listener.processingDelayDistribution) + } + val schedulingDelayQuantilesRow = { + Seq( + "Scheduling Delay", + formatDurationOption(lastCompletedBatch.flatMap(_.schedulingDelay)) + ) ++ getQuantiles(listener.schedulingDelayDistribution) + } + val totalDelayQuantilesRow = { + Seq( + "Total Delay", + formatDurationOption(lastCompletedBatch.flatMap(_.totalDelay)) + ) ++ getQuantiles(listener.totalDelayDistribution) + } + val headerRow = Seq("Metric", "Last batch", "Minimum", "25th percentile", + "Median", "75th percentile", "Maximum") + val dataRows: Seq[Seq[String]] = Seq( + processingDelayQuantilesRow, + schedulingDelayQuantilesRow, + totalDelayQuantilesRow + ) + Some(listingTable(headerRow, dataRows)) + } else { + None + } + + val content = +
      Batch Processing Statistics
      ++ +
      +
        + {table.getOrElse("No statistics have been generated yet.")} +
      +
      + + content + } + + + /** + * Returns a human-readable string representing a duration such as "5 second 35 ms" + */ + private def formatDurationOption(msOption: Option[Long]): String = { + msOption.map(formatDurationVerbose).getOrElse(emptyCell) + } + + /** Get quantiles for any time distribution */ + private def getQuantiles(timeDistributionOption: Option[Distribution]) = { + timeDistributionOption.get.getQuantiles().map { ms => formatDurationVerbose(ms.toLong) } + } + + /** Generate HTML table from string data */ + private def listingTable(headers: Seq[String], data: Seq[Seq[String]]) = { + def generateDataRow(data: Seq[String]): Seq[Node] = { +
      {data.map(d => )} + } + UIUtils.listingTable(headers, generateDataRow, data, fixedWidth = true) + } +} + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala new file mode 100644 index 0000000000000..51448d15c6516 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.ui + +import org.apache.spark.Logging +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.ui.WebUITab + +/** Spark Web UI tab that shows statistics of a streaming job */ +private[spark] class StreamingTab(ssc: StreamingContext) + extends WebUITab(ssc.sc.ui, "streaming") with Logging { + + val parent = ssc.sc.ui + val appName = parent.appName + val basePath = parent.basePath + val listener = new StreamingJobProgressListener(ssc) + + ssc.addStreamingListener(listener) + attachPage(new StreamingPage(this)) + parent.attachTab(this) +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala index c3a849d2769a7..39145a3ab081a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala @@ -19,46 +19,43 @@ package org.apache.spark.streaming.util private[streaming] trait Clock { - def currentTime(): Long + def currentTime(): Long def waitTillTime(targetTime: Long): Long } private[streaming] class SystemClock() extends Clock { - + val minPollTime = 25L - + def currentTime(): Long = { System.currentTimeMillis() - } - + } + def waitTillTime(targetTime: Long): Long = { var currentTime = 0L currentTime = System.currentTimeMillis() - + var waitTime = targetTime - currentTime if (waitTime <= 0) { return currentTime } - + val pollTime = { if (waitTime / 10.0 > minPollTime) { (waitTime / 10.0).toLong } else { - minPollTime - } + minPollTime + } } - - + while (true) { currentTime = System.currentTimeMillis() waitTime = targetTime - currentTime - if (waitTime <= 0) { - return currentTime } - val sleepTime = + val sleepTime = if (waitTime < pollTime) { waitTime } else { @@ -72,7 +69,7 @@ class SystemClock() extends Clock { private[streaming] class ManualClock() extends Clock { - + var time = 0L def currentTime() = time @@ -88,13 +85,13 @@ class ManualClock() extends Clock { this.synchronized { time += timeToAdd this.notifyAll() - } + } } def waitTillTime(targetTime: Long): Long = { this.synchronized { while (time < targetTime) { this.wait(100) - } + } } currentTime() } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala index 07021ebb5802a..bbf57ef9275c0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextHelper.scala @@ -19,18 +19,17 @@ package org.apache.spark.streaming.util import org.apache.spark.SparkContext import org.apache.spark.SparkContext._ -import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap} +import org.apache.spark.util.collection.OpenHashMap import scala.collection.JavaConversions.mapAsScalaMap private[streaming] object RawTextHelper { /** - * Splits lines and counts the words in them using specialized object-to-long hashmap - * (to avoid boxing-unboxing overhead of Long in java/scala HashMap) + * Splits lines and counts the words. */ def splitAndCountPartitions(iter: Iterator[String]): Iterator[(String, Long)] = { - val map = new OLMap[String] + val map = new OpenHashMap[String,Long] var i = 0 var j = 0 while (iter.hasNext) { @@ -43,25 +42,27 @@ object RawTextHelper { } if (j > i) { val w = s.substring(i, j) - val c = map.getLong(w) - map.put(w, c + 1) + map.changeValue(w, 1L, _ + 1L) } i = j while (i < s.length && s.charAt(i) == ' ') { i += 1 } } + map.toIterator.map { + case (k, v) => (k, v) + } } map.toIterator.map{case (k, v) => (k, v)} } - /** + /** * Gets the top k words in terms of word counts. Assumes that each word exists only once * in the `data` iterator (that is, the counts have been reduced). */ def topK(data: Iterator[(String, Long)], k: Int): Iterator[(String, Long)] = { val taken = new Array[(String, Long)](k) - + var i = 0 var len = 0 var done = false @@ -93,7 +94,7 @@ object RawTextHelper { } taken.toIterator } - + /** * Warms up the SparkContext in master and slave by running tasks to force JIT kick in * before real workload starts. @@ -106,11 +107,11 @@ object RawTextHelper { .count() } } - - def add(v1: Long, v2: Long) = (v1 + v2) - def subtract(v1: Long, v2: Long) = (v1 - v2) + def add(v1: Long, v2: Long) = (v1 + v2) + + def subtract(v1: Long, v2: Long) = (v1 - v2) - def max(v1: Long, v2: Long) = math.max(v1, v2) + def max(v1: Long, v2: Long) = math.max(v1, v2) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala index 684b38e8b3102..a7850812bd612 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala @@ -17,14 +17,12 @@ package org.apache.spark.streaming.util -import java.io.IOException +import java.io.{ByteArrayOutputStream, IOException} import java.net.ServerSocket import java.nio.ByteBuffer import scala.io.Source -import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream - import org.apache.spark.{SparkConf, Logging} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.IntParam @@ -45,16 +43,15 @@ object RawTextSender extends Logging { // Repeat the input data multiple times to fill in a buffer val lines = Source.fromFile(file).getLines().toArray - val bufferStream = new FastByteArrayOutputStream(blockSize + 1000) + val bufferStream = new ByteArrayOutputStream(blockSize + 1000) val ser = new KryoSerializer(new SparkConf()).newInstance() val serStream = ser.serializeStream(bufferStream) var i = 0 - while (bufferStream.position < blockSize) { + while (bufferStream.size < blockSize) { serStream.writeObject(lines(i)) i = (i + 1) % lines.length } - bufferStream.trim() - val array = bufferStream.array + val array = bufferStream.toByteArray val countBuf = ByteBuffer.wrap(new Array[Byte](4)) countBuf.putInt(array.length) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala index 559c2473851b3..e016377c94c0d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala @@ -17,44 +17,84 @@ package org.apache.spark.streaming.util +import org.apache.spark.Logging + private[streaming] -class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => Unit) { - - private val thread = new Thread("RecurringTimer") { - override def run() { loop } +class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: String) + extends Logging { + + private val thread = new Thread("RecurringTimer - " + name) { + setDaemon(true) + override def run() { loop } } - - private var nextTime = 0L + @volatile private var prevTime = -1L + @volatile private var nextTime = -1L + @volatile private var stopped = false + + /** + * Get the time when this timer will fire if it is started right now. + * The time will be a multiple of this timer's period and more than + * current system time. + */ def getStartTime(): Long = { (math.floor(clock.currentTime.toDouble / period) + 1).toLong * period } + /** + * Get the time when the timer will fire if it is restarted right now. + * This time depends on when the timer was started the first time, and was stopped + * for whatever reason. The time must be a multiple of this timer's period and + * more than current time. + */ def getRestartTime(originalStartTime: Long): Long = { val gap = clock.currentTime - originalStartTime (math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime } - def start(startTime: Long): Long = { + /** + * Start at the given start time. + */ + def start(startTime: Long): Long = synchronized { nextTime = startTime thread.start() + logInfo("Started timer for " + name + " at time " + nextTime) nextTime } + /** + * Start at the earliest time it can start based on the period. + */ def start(): Long = { start(getStartTime()) } - def stop() { - thread.interrupt() + /** + * Stop the timer, and return the last time the callback was made. + * interruptTimer = true will interrupt the callback + * if it is in progress (not guaranteed to give correct time in this case). + */ + def stop(interruptTimer: Boolean): Long = synchronized { + if (!stopped) { + stopped = true + if (interruptTimer) thread.interrupt() + thread.join() + logInfo("Stopped timer for " + name + " after time " + prevTime) + } + prevTime } - + + /** + * Repeatedly call the callback every interval. + */ private def loop() { try { - while (true) { + while (!stopped) { clock.waitTillTime(nextTime) callback(nextTime) + prevTime = nextTime nextTime += period + logDebug("Callback for " + name + " called at time " + prevTime) } } catch { case e: InterruptedException => @@ -64,20 +104,20 @@ class RecurringTimer(val clock: Clock, val period: Long, val callback: (Long) => private[streaming] object RecurringTimer { - + def main(args: Array[String]) { var lastRecurTime = 0L val period = 1000 - + def onRecur(time: Long) { val currentTime = System.currentTimeMillis() println("" + currentTime + ": " + (currentTime - lastRecurTime)) lastRecurTime = currentTime } - val timer = new RecurringTimer(new SystemClock(), period, onRecur) + val timer = new RecurringTimer(new SystemClock(), period, onRecur, "Test") timer.start() Thread.sleep(30 * 1000) - timer.stop() + timer.stop(true) } } diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index e93bf18b6d0b9..a0b1bbc34fa7c 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -23,6 +23,7 @@ import org.junit.Test; import java.io.*; import java.util.*; +import java.lang.Iterable; import com.google.common.base.Optional; import com.google.common.collect.Lists; @@ -45,6 +46,18 @@ // see http://stackoverflow.com/questions/758570/. public class JavaAPISuite extends LocalJavaStreamingContext implements Serializable { + public void equalIterator(Iterator a, Iterator b) { + while (a.hasNext() && b.hasNext()) { + Assert.assertEquals(a.next(), b.next()); + } + Assert.assertEquals(a.hasNext(), b.hasNext()); + } + + public void equalIterable(Iterable a, Iterable b) { + equalIterator(a.iterator(), b.iterator()); + } + + @SuppressWarnings("unchecked") @Test public void testCount() { @@ -1016,11 +1029,24 @@ public void testPairGroupByKey() { JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream> grouped = pairStream.groupByKey(); + JavaPairDStream> grouped = pairStream.groupByKey(); JavaTestUtils.attachTestOutputStream(grouped); - List>>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); + List>>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected.size(), result.size()); + Iterator>>> resultItr = result.iterator(); + Iterator>>> expectedItr = expected.iterator(); + while (resultItr.hasNext() && expectedItr.hasNext()) { + Iterator>> resultElements = resultItr.next().iterator(); + Iterator>> expectedElements = expectedItr.next().iterator(); + while (resultElements.hasNext() && expectedElements.hasNext()) { + Tuple2> resultElement = resultElements.next(); + Tuple2> expectedElement = expectedElements.next(); + Assert.assertEquals(expectedElement._1(), resultElement._1()); + equalIterable(expectedElement._2(), resultElement._2()); + } + Assert.assertEquals(resultElements.hasNext(), expectedElements.hasNext()); + } } @SuppressWarnings("unchecked") @@ -1128,7 +1154,7 @@ public void testGroupByKeyAndWindow() { JavaDStream> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1); JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); - JavaPairDStream> groupWindowed = + JavaPairDStream> groupWindowed = pairStream.groupByKeyAndWindow(new Duration(2000), new Duration(1000)); JavaTestUtils.attachTestOutputStream(groupWindowed); List>>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -1471,11 +1497,25 @@ public void testCoGroup() { ssc, stringStringKVStream2, 1); JavaPairDStream pairStream2 = JavaPairDStream.fromJavaDStream(stream2); - JavaPairDStream, List>> grouped = pairStream1.cogroup(pairStream2); + JavaPairDStream, Iterable>> grouped = pairStream1.cogroup(pairStream2); JavaTestUtils.attachTestOutputStream(grouped); - List, List>>>> result = JavaTestUtils.runStreams(ssc, 2, 2); - - Assert.assertEquals(expected, result); + List, Iterable>>>> result = JavaTestUtils.runStreams(ssc, 2, 2); + + Assert.assertEquals(expected.size(), result.size()); + Iterator, Iterable>>>> resultItr = result.iterator(); + Iterator, List>>>> expectedItr = expected.iterator(); + while (resultItr.hasNext() && expectedItr.hasNext()) { + Iterator, Iterable>>> resultElements = resultItr.next().iterator(); + Iterator, List>>> expectedElements = expectedItr.next().iterator(); + while (resultElements.hasNext() && expectedElements.hasNext()) { + Tuple2, Iterable>> resultElement = resultElements.next(); + Tuple2, List>> expectedElement = expectedElements.next(); + Assert.assertEquals(expectedElement._1(), resultElement._1()); + equalIterable(expectedElement._2()._1(), resultElement._2()._1()); + equalIterable(expectedElement._2()._2(), resultElement._2()._2()); + } + Assert.assertEquals(resultElements.hasNext(), expectedElements.hasNext()); + } } @SuppressWarnings("unchecked") @@ -1633,7 +1673,7 @@ public void testSocketTextStream() { @Test public void testSocketString() { - + class Converter implements Function> { public Iterable call(InputStream in) throws IOException { BufferedReader reader = new BufferedReader(new InputStreamReader(in)); diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index bcb0c28bf07a0..8aec27e39478a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -117,7 +117,7 @@ class BasicOperationsSuite extends TestSuiteBase { test("groupByKey") { testOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), - (s: DStream[String]) => s.map(x => (x, 1)).groupByKey(), + (s: DStream[String]) => s.map(x => (x, 1)).groupByKey().mapValues(_.toSeq), Seq( Seq(("a", Seq(1, 1)), ("b", Seq(1))), Seq(("", Seq(1, 1))), Seq() ), true ) @@ -251,7 +251,7 @@ class BasicOperationsSuite extends TestSuiteBase { Seq( ) ) val operation = (s1: DStream[String], s2: DStream[String]) => { - s1.map(x => (x,1)).cogroup(s2.map(x => (x, "x"))) + s1.map(x => (x,1)).cogroup(s2.map(x => (x, "x"))).mapValues(x => (x._1.toSeq, x._2.toSeq)) } testOperation(inputData1, inputData2, operation, outputData, true) } @@ -324,7 +324,7 @@ class BasicOperationsSuite extends TestSuiteBase { val updateStateOperation = (s: DStream[String]) => { val updateFunc = (values: Seq[Int], state: Option[Int]) => { - Some(values.foldLeft(0)(_ + _) + state.getOrElse(0)) + Some(values.sum + state.getOrElse(0)) } s.map(x => (x, 1)).updateStateByKey[Int](updateFunc) } @@ -359,7 +359,7 @@ class BasicOperationsSuite extends TestSuiteBase { // updateFunc clears a state when a StateObject is seen without new values twice in a row val updateFunc = (values: Seq[Int], state: Option[StateObject]) => { val stateObj = state.getOrElse(new StateObject) - values.foldLeft(0)(_ + _) match { + values.sum match { case 0 => stateObj.expireCounter += 1 // no new values case n => { // has new values, increment and reset expireCounter stateObj.counter += n diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 4ae23184d7c80..a5d68ab1777e3 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -144,8 +144,8 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") } - - test("actor input stream") { + // TODO: This test makes assumptions about Thread.sleep() and is flaky + ignore("actor input stream") { // Start the server val testServer = new TestServer() val port = testServer.port @@ -239,11 +239,11 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { /** This is a server to test the network input stream */ -class TestServer() extends Logging { +class TestServer(portToBind: Int = 0) extends Logging { val queue = new ArrayBlockingQueue[String](100) - val serverSocket = new ServerSocket(0) + val serverSocket = new ServerSocket(portToBind) val servingThread = new Thread() { override def run() { @@ -282,7 +282,7 @@ class TestServer() extends Logging { def start() { servingThread.start() } - def send(msg: String) { queue.add(msg) } + def send(msg: String) { queue.put(msg) } def stop() { servingThread.interrupt() } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 6e16bbfb4a109..1b81f2643cc51 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -17,19 +17,23 @@ package org.apache.spark.streaming -import org.scalatest.{FunSuite, BeforeAndAfter} -import org.scalatest.exceptions.TestFailedDueToTimeoutException +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException} +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.receiver.NetworkReceiver +import org.apache.spark.util.{MetadataCleaner, Utils} +import org.scalatest.{BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Timeouts +import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkException, SparkConf, SparkContext} -import org.apache.spark.util.{Utils, MetadataCleaner} -import org.apache.spark.streaming.dstream.DStream -class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts { +class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts with Logging { val master = "local[2]" val appName = this.getClass.getSimpleName - val batchDuration = Seconds(1) + val batchDuration = Milliseconds(500) val sparkHome = "someDir" val envPair = "key" -> "value" val ttl = StreamingContext.DEFAULT_CLEANER_TTL + 100 @@ -51,7 +55,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts { sc = null } } - +/* test("from no conf constructor") { ssc = new StreamingContext(master, appName, batchDuration) assert(ssc.sparkContext.conf.get("spark.master") === master) @@ -108,19 +112,31 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) myConf.set("spark.cleaner.ttl", ttl.toString) val ssc1 = new StreamingContext(myConf, batchDuration) + addInputStream(ssc1).register + ssc1.start() val cp = new Checkpoint(ssc1, Time(1000)) assert(MetadataCleaner.getDelaySeconds(cp.sparkConf) === ttl) ssc1.stop() val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp)) assert(MetadataCleaner.getDelaySeconds(newCp.sparkConf) === ttl) - ssc = new StreamingContext(null, cp, null) + ssc = new StreamingContext(null, newCp, null) assert(MetadataCleaner.getDelaySeconds(ssc.conf) === ttl) } - test("start multiple times") { + test("start and stop state check") { ssc = new StreamingContext(master, appName, batchDuration) addInputStream(ssc).register + assert(ssc.state === ssc.StreamingContextState.Initialized) + ssc.start() + assert(ssc.state === ssc.StreamingContextState.Started) + ssc.stop() + assert(ssc.state === ssc.StreamingContextState.Stopped) + } + + test("start multiple times") { + ssc = new StreamingContext(master, appName, batchDuration) + addInputStream(ssc).register ssc.start() intercept[SparkException] { ssc.start() @@ -133,7 +149,17 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts { ssc.start() ssc.stop() ssc.stop() - ssc = null + } + + test("stop before start and start after stop") { + ssc = new StreamingContext(master, appName, batchDuration) + addInputStream(ssc).register + ssc.stop() // stop before start should not throw exception + ssc.start() + ssc.stop() + intercept[SparkException] { + ssc.start() // start after stop should throw exception + } } test("stop only streaming context") { @@ -142,14 +168,44 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts { addInputStream(ssc).register ssc.start() ssc.stop(false) - ssc = null assert(sc.makeRDD(1 to 100).collect().size === 100) ssc = new StreamingContext(sc, batchDuration) addInputStream(ssc).register ssc.start() ssc.stop() } +*/ + test("stop gracefully") { + val conf = new SparkConf().setMaster(master).setAppName(appName) + conf.set("spark.cleaner.ttl", "3600") + sc = new SparkContext(conf) + for (i <- 1 to 4) { + logInfo("==================================") + ssc = new StreamingContext(sc, batchDuration) + var runningCount = 0 + TestReceiver.counter.set(1) + val input = ssc.networkStream(new TestReceiver) + input.count.foreachRDD(rdd => { + val count = rdd.first() + runningCount += count.toInt + logInfo("Count = " + count + ", Running count = " + runningCount) + }) + ssc.start() + ssc.awaitTermination(500) + ssc.stop(stopSparkContext = false, stopGracefully = true) + logInfo("Running count = " + runningCount) + logInfo("TestReceiver.counter = " + TestReceiver.counter.get()) + assert(runningCount > 0) + assert( + (TestReceiver.counter.get() == runningCount + 1) || + (TestReceiver.counter.get() == runningCount + 2), + "Received records = " + TestReceiver.counter.get() + ", " + + "processed records = " + runningCount + ) + } + } +/* test("awaitTermination") { ssc = new StreamingContext(master, appName, batchDuration) val inputStream = addInputStream(ssc) @@ -202,7 +258,6 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts { test("awaitTermination with error in job generation") { ssc = new StreamingContext(master, appName, batchDuration) val inputStream = addInputStream(ssc) - inputStream.transform(rdd => { throw new TestException("error in transform"); rdd }).register val exception = intercept[TestException] { ssc.start() @@ -210,7 +265,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts { } assert(exception.getMessage.contains("transform"), "Expected exception not thrown") } - +*/ def addInputStream(s: StreamingContext): DStream[Int] = { val input = (1 to 100).map(i => (1 to i)) val inputStream = new TestInputStream(s, input, 1) @@ -219,3 +274,23 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts { } class TestException(msg: String) extends Exception(msg) + +/** Custom receiver for testing whether all data received by a receiver gets processed or not */ +class TestReceiver extends NetworkReceiver[Int](StorageLevel.MEMORY_ONLY) with Logging { + def onStart() { + try { + while(true) { + store(TestReceiver.counter.getAndIncrement) + Thread.sleep(0) + } + } finally { + logInfo("Receiving stopped at count value of " + TestReceiver.counter.get()) + } + } + + def onStop() { } +} + +object TestReceiver { + val counter = new AtomicInteger(1) +} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 201630672ab4c..aa2d5c2fc2454 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -277,7 +277,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { assert(timeTaken < maxWaitTimeMillis, "Operation timed out after " + timeTaken + " ms") assert(output.size === numExpectedOutput, "Unexpected number of outputs generated") - Thread.sleep(500) // Give some time for the forgetting old RDDs to complete + Thread.sleep(100) // Give some time for the forgetting old RDDs to complete } catch { case e: Exception => {e.printStackTrace(); throw e} } finally { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala new file mode 100644 index 0000000000000..35538ec188f67 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming + +import scala.io.Source + +import org.scalatest.FunSuite +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +class UISuite extends FunSuite { + + test("streaming tab in spark UI") { + val ssc = new StreamingContext("local", "test", Seconds(1)) + eventually(timeout(10 seconds), interval(50 milliseconds)) { + val html = Source.fromURL(ssc.sparkContext.ui.appUIAddress).mkString + assert(!html.contains("random data that should not be present")) + // test if streaming tab exist + assert(html.toLowerCase.contains("streaming")) + // test if other Spark tabs still exist + assert(html.toLowerCase.contains("stages")) + } + + eventually(timeout(10 seconds), interval(50 milliseconds)) { + val html = Source.fromURL( + ssc.sparkContext.ui.appUIAddress.stripSuffix("/") + "/streaming").mkString + assert(html.toLowerCase.contains("batch")) + assert(html.toLowerCase.contains("network")) + } + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala similarity index 97% rename from core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala rename to tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index 8cea302eb14c3..8e8c35615a711 100644 --- a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.storage +package org.apache.spark.tools import java.util.concurrent.{CountDownLatch, Executors} import java.util.concurrent.atomic.AtomicLong @@ -25,7 +25,7 @@ import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.Utils /** - * Utility for micro-benchmarking shuffle write performance. + * Internal utility for micro-benchmarking shuffle write performance. * * Writes simulated shuffle output from several threads and records the observed throughput. */ diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 910484ed5432a..67ec95c8fc04f 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -234,7 +234,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, assert(sparkContext != null || count >= numTries) if (null != sparkContext) { - uiAddress = sparkContext.ui.appUIAddress + uiAddress = sparkContext.ui.appUIHostPort this.yarnAllocator = YarnAllocationHandler.newAllocator( yarnConf, resourceManager, diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 981e8b05f602d..3469b7decedf6 100644 --- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -81,7 +81,8 @@ class ExecutorRunnable( credentials.writeTokenStorageToStream(dob) ctx.setContainerTokens(ByteBuffer.wrap(dob.getData())) - val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores) + val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores, + localResources.contains(ClientBase.LOG4J_PROP)) logInfo("Setting up executor with commands: " + commands) ctx.setCommands(commands) diff --git a/yarn/common/src/main/resources/log4j-spark-container.properties b/yarn/common/src/main/resources/log4j-spark-container.properties new file mode 100644 index 0000000000000..a1e37a0be27dd --- /dev/null +++ b/yarn/common/src/main/resources/log4j-spark-container.properties @@ -0,0 +1,24 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. See accompanying LICENSE file. + +# Set everything to be logged to the console +log4j.rootCategory=INFO, console +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.err +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + +# Settings to quiet third party logs that are too verbose +log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO +log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index 6568003bf1008..eb42922aea228 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -266,11 +266,11 @@ trait ClientBase extends Logging { localResources: HashMap[String, LocalResource], stagingDir: String): HashMap[String, String] = { logInfo("Setting up the launch environment") - val log4jConfLocalRes = localResources.getOrElse(ClientBase.LOG4J_PROP, null) val env = new HashMap[String, String]() - ClientBase.populateClasspath(yarnConf, sparkConf, log4jConfLocalRes != null, env) + ClientBase.populateClasspath(yarnConf, sparkConf, localResources.contains(ClientBase.LOG4J_PROP), + env) env("SPARK_YARN_MODE") = "true" env("SPARK_YARN_STAGING_DIR") = stagingDir env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName() @@ -344,15 +344,13 @@ trait ClientBase extends Logging { JAVA_OPTS += " " + env("SPARK_JAVA_OPTS") } - // Command for the ApplicationMaster - var javaCommand = "java" - val javaHome = System.getenv("JAVA_HOME") - if ((javaHome != null && !javaHome.isEmpty()) || env.isDefinedAt("JAVA_HOME")) { - javaCommand = Environment.JAVA_HOME.$() + "/bin/java" + if (!localResources.contains(ClientBase.LOG4J_PROP)) { + JAVA_OPTS += " " + YarnSparkHadoopUtil.getLoggingArgsForContainerCommandLine() } + // Command for the ApplicationMaster val commands = List[String]( - javaCommand + + Environment.JAVA_HOME.$() + "/bin/java" + " -server " + JAVA_OPTS + " " + args.amClass + diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala index da0a6f74efcd5..b3696c5fe7183 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala @@ -50,7 +50,8 @@ trait ExecutorRunnableUtil extends Logging { slaveId: String, hostname: String, executorMemory: Int, - executorCores: Int) = { + executorCores: Int, + userSpecifiedLogFile: Boolean) = { // Extra options for the JVM var JAVA_OPTS = "" // Set the JVM memory @@ -63,6 +64,10 @@ trait ExecutorRunnableUtil extends Logging { JAVA_OPTS += " -Djava.io.tmpdir=" + new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + " " + if (!userSpecifiedLogFile) { + JAVA_OPTS += " " + YarnSparkHadoopUtil.getLoggingArgsForContainerCommandLine() + } + // Commenting it out for now - so that people can refer to the properties if required. Remove // it once cpuset version is pushed out. // The context is, default gc for server class machines end up using all cores to do gc - hence @@ -88,13 +93,8 @@ trait ExecutorRunnableUtil extends Logging { } */ - var javaCommand = "java" - val javaHome = System.getenv("JAVA_HOME") - if ((javaHome != null && !javaHome.isEmpty()) || env.isDefinedAt("JAVA_HOME")) { - javaCommand = Environment.JAVA_HOME.$() + "/bin/java" - } - - val commands = List[String](javaCommand + + val commands = List[String]( + Environment.JAVA_HOME.$() + "/bin/java" + " -server " + // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling. // Not killing the task leaves various aspects of the executor and (to some extent) the jvm in diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 4c6e1dcd6dac3..314a7550ada71 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -22,6 +22,7 @@ import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.api.ApplicationConstants import org.apache.hadoop.conf.Configuration import org.apache.spark.deploy.SparkHadoopUtil @@ -67,3 +68,9 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { } } + +object YarnSparkHadoopUtil { + def getLoggingArgsForContainerCommandLine(): String = { + "-Dlog4j.configuration=log4j-spark-container.properties" + } +} diff --git a/yarn/pom.xml b/yarn/pom.xml index 35e31760c1f02..3342cb65edcd1 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -167,6 +167,12 @@ target/scala-${scala.binary.version}/classes target/scala-${scala.binary.version}/test-classes + + + + ../common/src/main/resources + + diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 30735cbfdf26e..61af0f9ac5ca0 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -18,7 +18,6 @@ package org.apache.spark.deploy.yarn import java.io.IOException -import java.net.Socket import java.util.concurrent.CopyOnWriteArrayList import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} @@ -36,7 +35,7 @@ import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{ConverterUtils, Records} -import org.apache.hadoop.yarn.webapp.util.WebAppUtils; +import org.apache.hadoop.yarn.webapp.util.WebAppUtils import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.deploy.SparkHadoopUtil @@ -221,7 +220,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, assert(sparkContext != null || numTries >= maxNumTries) if (sparkContext != null) { - uiAddress = sparkContext.ui.appUIAddress + uiAddress = sparkContext.ui.appUIHostPort this.yarnAllocator = YarnAllocationHandler.newAllocator( yarnConf, amClient, diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index 53c403f7d0913..81d9d1b5c9280 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -78,7 +78,8 @@ class ExecutorRunnable( credentials.writeTokenStorageToStream(dob) ctx.setTokens(ByteBuffer.wrap(dob.getData())) - val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores) + val commands = prepareCommand(masterAddress, slaveId, hostname, executorMemory, executorCores, + localResources.contains(ClientBase.LOG4J_PROP)) logInfo("Setting up executor with commands: " + commands) ctx.setCommands(commands)
      Storage LevelMeaning
      MEMORY_AND_DISK_SER Similar to MEMORY_ONLY_SER, but spill partitions that don't fit in memory to disk instead of recomputing them - on the fly each time they're needed. Similar to MEMORY_ONLY_SER, but spill partitions that don't fit in memory to disk instead of + recomputing them on the fly each time they're needed.
      OFF_HEAP Store RDD in a serialized format in Tachyon. + This is generally more space-efficient than deserialized objects, especially when using a + fast serializer, but more CPU-intensive to read. + This also significantly reduces the overheads of GC. +
      DISK_ONLY
      {d}