From c3f2a8588e19aab814ac5cdd86575bb5558d5e46 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 25 Sep 2014 23:20:17 +0530 Subject: [PATCH 01/50] SPARK-2932 [STREAMING] Move MasterFailureTest out of "main" source directory (HT @vanzin) Whatever the reason was for having this test class in `main`, if there is one, appear to be moot. This may have been a result of earlier streaming test reorganization. This simply puts `MasterFailureTest` back under `test/`, removes some redundant copied code, and touches up a few tiny inspection warnings along the way. Author: Sean Owen Closes #2399 from srowen/SPARK-2932 and squashes the following commits: 3909411 [Sean Owen] Move MasterFailureTest to src/test, and remove redundant TestOutputStream --- .../apache/spark/streaming/FailureSuite.scala | 1 - .../spark/streaming}/MasterFailureTest.scala | 43 ++++--------------- 2 files changed, 8 insertions(+), 36 deletions(-) rename streaming/src/{main/scala/org/apache/spark/streaming/util => test/scala/org/apache/spark/streaming}/MasterFailureTest.scala (91%) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala index 92e1b76d28301..40434b1f9b709 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.streaming import org.apache.spark.Logging -import org.apache.spark.streaming.util.MasterFailureTest import org.apache.spark.util.Utils import java.io.File diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala similarity index 91% rename from streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala rename to streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index 98e17ff92e205..c53c01706083a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -15,20 +15,18 @@ * limitations under the License. */ -package org.apache.spark.streaming.util +package org.apache.spark.streaming import org.apache.spark.Logging -import org.apache.spark.rdd.RDD -import org.apache.spark.streaming._ -import org.apache.spark.streaming.dstream.{DStream, ForEachDStream} +import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.Utils -import StreamingContext._ +import org.apache.spark.streaming.StreamingContext._ import scala.util.Random -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag -import java.io.{File, ObjectInputStream, IOException} +import java.io.{File, IOException} import java.nio.charset.Charset import java.util.UUID @@ -91,7 +89,7 @@ object MasterFailureTest extends Logging { // Input: time=1 ==> [ a ] , time=2 ==> [ a, a ] , time=3 ==> [ a, a, a ] , ... val input = (1 to numBatches).map(i => (1 to i).map(_ => "a").mkString(" ")).toSeq // Expected output: time=1 ==> [ (a, 1) ] , time=2 ==> [ (a, 3) ] , time=3 ==> [ (a,6) ] , ... 
- val expectedOutput = (1L to numBatches).map(i => (1L to i).reduce(_ + _)).map(j => ("a", j)) + val expectedOutput = (1L to numBatches).map(i => (1L to i).sum).map(j => ("a", j)) val operation = (st: DStream[String]) => { val updateFunc = (values: Seq[Long], state: Option[Long]) => { @@ -218,7 +216,7 @@ object MasterFailureTest extends Logging { while(!isLastOutputGenerated && !isTimedOut) { // Get the output buffer - val outputBuffer = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[T]].output + val outputBuffer = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[T]].output def output = outputBuffer.flatMap(x => x) // Start the thread to kill the streaming after some time @@ -239,7 +237,7 @@ object MasterFailureTest extends Logging { while (!killed && !isLastOutputGenerated && !isTimedOut) { Thread.sleep(100) timeRan = System.currentTimeMillis() - startTime - isLastOutputGenerated = (!output.isEmpty && output.last == lastExpectedOutput) + isLastOutputGenerated = (output.nonEmpty && output.last == lastExpectedOutput) isTimedOut = (timeRan + totalTimeRan > maxTimeToRun) } } catch { @@ -313,31 +311,6 @@ object MasterFailureTest extends Logging { } } -/** - * This is a output stream just for testing. All the output is collected into a - * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. - */ -private[streaming] -class TestOutputStream[T: ClassTag]( - parent: DStream[T], - val output: ArrayBuffer[Seq[T]] = new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]] - ) extends ForEachDStream[T]( - parent, - (rdd: RDD[T], t: Time) => { - val collected = rdd.collect() - output += collected - } - ) { - - // This is to clear the output buffer every it is read from a checkpoint - @throws(classOf[IOException]) - private def readObject(ois: ObjectInputStream) { - ois.defaultReadObject() - output.clear() - } -} - - /** * Thread to kill streaming context after a random period of time. */ From 9b56e249e09d8da20f703b9381c5c3c8a1a1d4a9 Mon Sep 17 00:00:00 2001 From: epahomov Date: Thu, 25 Sep 2014 14:50:12 -0700 Subject: [PATCH 02/50] [SPARK-3690] Closing shuffle writers we swallow more important exception Author: epahomov Closes #2537 from epahomov/SPARK-3690 and squashes the following commits: a0b7de4 [epahomov] [SPARK-3690] Closing shuffle writers we swallow more important exception --- .../org/apache/spark/scheduler/ShuffleMapTask.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 381eff2147e95..a98ee118254a3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -69,8 +69,13 @@ private[spark] class ShuffleMapTask( return writer.stop(success = true).get } catch { case e: Exception => - if (writer != null) { - writer.stop(success = false) + try { + if (writer != null) { + writer.stop(success = false) + } + } catch { + case e: Exception => + log.debug("Could not stop writer", e) } throw e } finally { From ff637c9380a6342fd0a4dde0710ec23856751dd4 Mon Sep 17 00:00:00 2001 From: Aaron Staple Date: Thu, 25 Sep 2014 16:11:00 -0700 Subject: [PATCH 03/50] [SPARK-1484][MLLIB] Warn when running an iterative algorithm on uncached data. 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add warnings to KMeans, GeneralizedLinearAlgorithm, and computeSVD when called with input data that is not cached. KMeans is implemented iteratively, and I believe that GeneralizedLinearAlgorithm’s current optimizers are iterative and its future optimizers are also likely to be iterative. RowMatrix’s computeSVD is iterative against an RDD when run in DistARPACK mode. ALS and DecisionTree are iterative as well, but they implement RDD caching internally so do not require a warning. I added a warning to GeneralizedLinearAlgorithm rather than inside its optimizers, where the iteration actually occurs, because internally GeneralizedLinearAlgorithm maps its input data to an uncached RDD before passing it to an optimizer. (In other words, the warning would be printed for every GeneralizedLinearAlgorithm run, regardless of whether its input is cached, if the warning were in GradientDescent or other optimizer.) I assume that use of an uncached RDD by GeneralizedLinearAlgorithm is intentional, and that the mapping there (adding label, intercepts and scaling) is a lightweight operation. Arguably a user calling an optimizer such as GradientDescent will be knowledgable enough to cache their data without needing a log warning, so lack of a warning in the optimizers may be ok. Some of the documentation examples making use of these iterative algorithms did not cache their training RDDs (while others did). I updated the examples to always cache. I also fixed some (unrelated) minor errors in the documentation examples. Author: Aaron Staple Closes #2347 from staple/SPARK-1484 and squashes the following commits: bd49701 [Aaron Staple] Address review comments. ab2d4a4 [Aaron Staple] Disable warnings on python code path. a7a0f99 [Aaron Staple] Change code comments per review comments. 7cca1dc [Aaron Staple] Change warning message text. c77e939 [Aaron Staple] [SPARK-1484][MLLIB] Warn when running an iterative algorithm on uncached data. 3b6c511 [Aaron Staple] Minor doc example fixes. 
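For concreteness, a minimal sketch of the pattern this change encourages: cache the parsed training RDD before handing it to an iterative algorithm such as KMeans. It mirrors the updated documentation example in this patch; the input path and parameter values are placeholders, and the snippet is illustrative rather than part of the commit itself.

```scala
import org.apache.spark.mllib.clustering.KMeans
import org.apache.spark.mllib.linalg.Vectors

// Parse the points and cache them, since KMeans will pass over the data many times.
val data = sc.textFile("data/mllib/kmeans_data.txt")
val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache()

// Train with k = 2 clusters and at most 20 iterations (placeholder values).
val model = KMeans.train(parsedData, 2, 20)
```

Leaving out the `.cache()` call still works, but with this patch the run logs the new uncached-input warning at the start and end of training.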
--- docs/mllib-clustering.md | 3 +- docs/mllib-linear-methods.md | 9 ++-- docs/mllib-optimization.md | 1 + .../mllib/api/python/PythonMLLibAPI.scala | 54 ++++++++++--------- .../spark/mllib/clustering/KMeans.scala | 22 ++++++++ .../mllib/linalg/distributed/RowMatrix.scala | 11 ++++ .../GeneralizedLinearAlgorithm.scala | 21 ++++++++ 7 files changed, 91 insertions(+), 30 deletions(-) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index dfd9cd572888c..d10bd63746629 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -52,7 +52,7 @@ import org.apache.spark.mllib.linalg.Vectors // Load and parse the data val data = sc.textFile("data/mllib/kmeans_data.txt") -val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))) +val parsedData = data.map(s => Vectors.dense(s.split(' ').map(_.toDouble))).cache() // Cluster the data into two classes using KMeans val numClusters = 2 @@ -100,6 +100,7 @@ public class KMeansExample { } } ); + parsedData.cache(); // Cluster the data into two classes using KMeans int numClusters = 2; diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 9137f9dc1b692..d31bec3e1bd01 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -396,7 +396,7 @@ val data = sc.textFile("data/mllib/ridge-data/lpsa.data") val parsedData = data.map { line => val parts = line.split(',') LabeledPoint(parts(0).toDouble, Vectors.dense(parts(1).split(' ').map(_.toDouble))) -} +}.cache() // Building the model val numIterations = 100 @@ -455,6 +455,7 @@ public class LinearRegression { } } ); + parsedData.cache(); // Building the model int numIterations = 100; @@ -470,7 +471,7 @@ public class LinearRegression { } } ); - JavaRDD MSE = new JavaDoubleRDD(valuesAndPreds.map( + double MSE = new JavaDoubleRDD(valuesAndPreds.map( new Function, Object>() { public Object call(Tuple2 pair) { return Math.pow(pair._1() - pair._2(), 2.0); @@ -553,8 +554,8 @@ but in practice you will likely want to use unlabeled vectors for test data. 
{% highlight scala %} -val trainingData = ssc.textFileStream('/training/data/dir').map(LabeledPoint.parse) -val testData = ssc.textFileStream('/testing/data/dir').map(LabeledPoint.parse) +val trainingData = ssc.textFileStream("/training/data/dir").map(LabeledPoint.parse).cache() +val testData = ssc.textFileStream("/testing/data/dir").map(LabeledPoint.parse) {% endhighlight %} diff --git a/docs/mllib-optimization.md b/docs/mllib-optimization.md index 26ce5f3c501ff..45141c235be90 100644 --- a/docs/mllib-optimization.md +++ b/docs/mllib-optimization.md @@ -217,6 +217,7 @@ import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLUtils import org.apache.spark.mllib.classification.LogisticRegressionModel +import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater} val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") val numFeatures = data.take(1)(0).features.size diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 9164c294ac7b8..e9f41758581e3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -67,11 +67,13 @@ class PythonMLLibAPI extends Serializable { MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions) private def trainRegressionModel( - trainFunc: (RDD[LabeledPoint], Vector) => GeneralizedLinearModel, + learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel], data: JavaRDD[LabeledPoint], initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = { val initialWeights = SerDe.loads(initialWeightsBA).asInstanceOf[Vector] - val model = trainFunc(data.rdd, initialWeights) + // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD. 
+ learner.disableUncachedWarning() + val model = learner.run(data.rdd, initialWeights) val ret = new java.util.LinkedList[java.lang.Object]() ret.add(SerDe.dumps(model.weights)) ret.add(model.intercept: java.lang.Double) @@ -106,8 +108,7 @@ class PythonMLLibAPI extends Serializable { + " Can only be initialized using the following string values: [l1, l2, none].") } trainRegressionModel( - (data, initialWeights) => - lrAlg.run(data, initialWeights), + lrAlg, data, initialWeightsBA) } @@ -122,15 +123,14 @@ class PythonMLLibAPI extends Serializable { regParam: Double, miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + val lassoAlg = new LassoWithSGD() + lassoAlg.optimizer + .setNumIterations(numIterations) + .setRegParam(regParam) + .setStepSize(stepSize) + .setMiniBatchFraction(miniBatchFraction) trainRegressionModel( - (data, initialWeights) => - LassoWithSGD.train( - data, - numIterations, - stepSize, - regParam, - miniBatchFraction, - initialWeights), + lassoAlg, data, initialWeightsBA) } @@ -145,15 +145,14 @@ class PythonMLLibAPI extends Serializable { regParam: Double, miniBatchFraction: Double, initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = { + val ridgeAlg = new RidgeRegressionWithSGD() + ridgeAlg.optimizer + .setNumIterations(numIterations) + .setRegParam(regParam) + .setStepSize(stepSize) + .setMiniBatchFraction(miniBatchFraction) trainRegressionModel( - (data, initialWeights) => - RidgeRegressionWithSGD.train( - data, - numIterations, - stepSize, - regParam, - miniBatchFraction, - initialWeights), + ridgeAlg, data, initialWeightsBA) } @@ -186,8 +185,7 @@ class PythonMLLibAPI extends Serializable { + " Can only be initialized using the following string values: [l1, l2, none].") } trainRegressionModel( - (data, initialWeights) => - SVMAlg.run(data, initialWeights), + SVMAlg, data, initialWeightsBA) } @@ -220,8 +218,7 @@ class PythonMLLibAPI extends Serializable { + " Can only be initialized using the following string values: [l1, l2, none].") } trainRegressionModel( - (data, initialWeights) => - LogRegAlg.run(data, initialWeights), + LogRegAlg, data, initialWeightsBA) } @@ -249,7 +246,14 @@ class PythonMLLibAPI extends Serializable { maxIterations: Int, runs: Int, initializationMode: String): KMeansModel = { - KMeans.train(data.rdd, k, maxIterations, runs, initializationMode) + val kMeansAlg = new KMeans() + .setK(k) + .setMaxIterations(maxIterations) + .setRuns(runs) + .setInitializationMode(initializationMode) + // Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD. + .disableUncachedWarning() + return kMeansAlg.run(data.rdd) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index fce8fe29f6e40..7443f232ec3e7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -27,6 +27,7 @@ import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.XORShiftRandom /** @@ -112,11 +113,26 @@ class KMeans private ( this } + /** Whether a warning should be logged if the input RDD is uncached. */ + private var warnOnUncachedInput = true + + /** Disable warnings about uncached input. 
*/ + private[spark] def disableUncachedWarning(): this.type = { + warnOnUncachedInput = false + this + } + /** * Train a K-means model on the given set of points; `data` should be cached for high * performance, because this is an iterative algorithm. */ def run(data: RDD[Vector]): KMeansModel = { + + if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) { + logWarning("The input data is not directly cached, which may hurt performance if its" + + " parent RDDs are also uncached.") + } + // Compute squared norms and cache them. val norms = data.map(v => breezeNorm(v.toBreeze, 2.0)) norms.persist() @@ -125,6 +141,12 @@ class KMeans private ( } val model = runBreeze(breezeData) norms.unpersist() + + // Warn at the end of the run as well, for increased visibility. + if (warnOnUncachedInput && data.getStorageLevel == StorageLevel.NONE) { + logWarning("The input data was not directly cached, which may hurt performance if its" + + " parent RDDs are also uncached.") + } model } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 2e414a73be8e0..4174f45d231c7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.Logging import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} +import org.apache.spark.storage.StorageLevel /** * :: Experimental :: @@ -231,6 +232,10 @@ class RowMatrix( val brzSvd.SVD(uFull: BDM[Double], sigmaSquaresFull: BDV[Double], _) = brzSvd(G) (sigmaSquaresFull, uFull) case SVDMode.DistARPACK => + if (rows.getStorageLevel == StorageLevel.NONE) { + logWarning("The input data is not directly cached, which may hurt performance if its" + + " parent RDDs are also uncached.") + } require(k < n, s"k must be smaller than n in dist-eigs mode but got k=$k and n=$n.") EigenValueDecomposition.symmetricEigs(multiplyGramianMatrixBy, n, k, tol, maxIter) } @@ -256,6 +261,12 @@ class RowMatrix( logWarning(s"Requested $k singular values but only found $sk nonzeros.") } + // Warn at the end of the run as well, for increased visibility. 
+ if (computeMode == SVDMode.DistARPACK && rows.getStorageLevel == StorageLevel.NONE) { + logWarning("The input data was not directly cached, which may hurt performance if its" + + " parent RDDs are also uncached.") + } + val s = Vectors.dense(Arrays.copyOfRange(sigmas.data, 0, sk)) val V = Matrices.dense(n, sk, Arrays.copyOfRange(u.data, 0, n * sk)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 20c1fdd2269ce..d0fe4179685ca 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.MLUtils._ +import org.apache.spark.storage.StorageLevel /** * :: DeveloperApi :: @@ -133,6 +134,15 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] this } + /** Whether a warning should be logged if the input RDD is uncached. */ + private var warnOnUncachedInput = true + + /** Disable warnings about uncached input. */ + private[spark] def disableUncachedWarning(): this.type = { + warnOnUncachedInput = false + this + } + /** * Run the algorithm with the configured parameters on an input * RDD of LabeledPoint entries. @@ -149,6 +159,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ def run(input: RDD[LabeledPoint], initialWeights: Vector): M = { + if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) { + logWarning("The input data is not directly cached, which may hurt performance if its" + + " parent RDDs are also uncached.") + } + // Check the data properties before running the optimizer if (validateData && !validators.forall(func => func(input))) { throw new SparkException("Input validation failed.") @@ -223,6 +238,12 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] weights = scaler.transform(weights) } + // Warn at the end of the run as well, for increased visibility. 
+ if (warnOnUncachedInput && input.getStorageLevel == StorageLevel.NONE) { + logWarning("The input data was not directly cached, which may hurt performance if its" + + " parent RDDs are also uncached.") + } + createModel(weights, intercept) } } From 0dc868e787a3bc69c1b8e90d916a6dcea8dbcd6d Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Thu, 25 Sep 2014 16:49:15 -0700 Subject: [PATCH 04/50] [SPARK-3584] sbin/slaves doesn't work when we use password authentication for SSH Author: Kousuke Saruta Closes #2444 from sarutak/slaves-scripts-modification and squashes the following commits: eff7394 [Kousuke Saruta] Improve the description about Cluster Launch Script in docs/spark-standalone.md 7858225 [Kousuke Saruta] Modified sbin/slaves to use the environment variable "SPARK_SSH_FOREGROUND" as a flag 53d7121 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into slaves-scripts-modification e570431 [Kousuke Saruta] Added a description for SPARK_SSH_FOREGROUND variable 7120a0c [Kousuke Saruta] Added a description about default host for sbin/slaves 1bba8a9 [Kousuke Saruta] Added SPARK_SSH_FOREGROUND flag to sbin/slaves 88e2f17 [Kousuke Saruta] Merge branch 'master' of git://git.apache.org/spark into slaves-scripts-modification 297e75d [Kousuke Saruta] Modified sbin/slaves not to export HOSTLIST --- .gitignore | 1 + .rat-excludes | 1 + conf/{slaves => slaves.template} | 0 docs/spark-standalone.md | 7 ++++++- sbin/slaves.sh | 31 ++++++++++++++++++++++--------- 5 files changed, 30 insertions(+), 10 deletions(-) rename conf/{slaves => slaves.template} (100%) diff --git a/.gitignore b/.gitignore index 7779980b74a22..34939e3a97aaa 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,7 @@ conf/*.cmd conf/*.properties conf/*.conf conf/*.xml +conf/slaves docs/_site docs/api target/ diff --git a/.rat-excludes b/.rat-excludes index 9fc99d7fca35d..b14ad53720f32 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -19,6 +19,7 @@ log4j.properties log4j.properties.template metrics.properties.template slaves +slaves.template spark-env.sh spark-env.cmd spark-env.sh.template diff --git a/conf/slaves b/conf/slaves.template similarity index 100% rename from conf/slaves rename to conf/slaves.template diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 29b5491861bf3..58103fab20819 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -62,7 +62,12 @@ Finally, the following configuration options can be passed to the master and wor # Cluster Launch Scripts -To launch a Spark standalone cluster with the launch scripts, you need to create a file called `conf/slaves` in your Spark directory, which should contain the hostnames of all the machines where you would like to start Spark workers, one per line. The master machine must be able to access each of the slave machines via password-less `ssh` (using a private key). For testing, you can just put `localhost` in this file. +To launch a Spark standalone cluster with the launch scripts, you should create a file called conf/slaves in your Spark directory, +which must contain the hostnames of all the machines where you intend to start Spark workers, one per line. +If conf/slaves does not exist, the launch scripts defaults to a single machine (localhost), which is useful for testing. +Note, the master machine accesses each of the worker machines via ssh. By default, ssh is run in parallel and requires password-less (using a private key) access to be setup. 
+If you do not have a password-less setup, you can set the environment variable SPARK_SSH_FOREGROUND and serially provide a password for each worker. + Once you've set up this file, you can launch or stop your cluster with the following shell scripts, based on Hadoop's deploy scripts, and available in `SPARK_HOME/bin`: diff --git a/sbin/slaves.sh b/sbin/slaves.sh index 1d4dc5edf9858..cdad47ee2e594 100755 --- a/sbin/slaves.sh +++ b/sbin/slaves.sh @@ -44,7 +44,9 @@ sbin="`cd "$sbin"; pwd`" # If the slaves file is specified in the command line, # then it takes precedence over the definition in # spark-env.sh. Save it here. -HOSTLIST="$SPARK_SLAVES" +if [ -f "$SPARK_SLAVES" ]; then + HOSTLIST=`cat "$SPARK_SLAVES"` +fi # Check if --config is passed as an argument. It is an optional parameter. # Exit if the argument is not a directory. @@ -67,23 +69,34 @@ fi if [ "$HOSTLIST" = "" ]; then if [ "$SPARK_SLAVES" = "" ]; then - export HOSTLIST="${SPARK_CONF_DIR}/slaves" + if [ -f "${SPARK_CONF_DIR}/slaves" ]; then + HOSTLIST=`cat "${SPARK_CONF_DIR}/slaves"` + else + HOSTLIST=localhost + fi else - export HOSTLIST="${SPARK_SLAVES}" + HOSTLIST=`cat "${SPARK_SLAVES}"` fi fi + + # By default disable strict host key checking if [ "$SPARK_SSH_OPTS" = "" ]; then SPARK_SSH_OPTS="-o StrictHostKeyChecking=no" fi -for slave in `cat "$HOSTLIST"|sed "s/#.*$//;/^$/d"`; do - ssh $SPARK_SSH_OPTS "$slave" $"${@// /\\ }" \ - 2>&1 | sed "s/^/$slave: /" & - if [ "$SPARK_SLAVE_SLEEP" != "" ]; then - sleep $SPARK_SLAVE_SLEEP - fi +for slave in `echo "$HOSTLIST"|sed "s/#.*$//;/^$/d"`; do + if [ -n "${SPARK_SSH_FOREGROUND}" ]; then + ssh $SPARK_SSH_OPTS "$slave" $"${@// /\\ }" \ + 2>&1 | sed "s/^/$slave: /" + else + ssh $SPARK_SSH_OPTS "$slave" $"${@// /\\ }" \ + 2>&1 | sed "s/^/$slave: /" & + fi + if [ "$SPARK_SLAVE_SLEEP" != "" ]; then + sleep $SPARK_SLAVE_SLEEP + fi done wait From 86bce764983f2b14e1bd87fc3f4f938f7a217e1b Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 25 Sep 2014 18:24:01 -0700 Subject: [PATCH 05/50] SPARK-2634: Change MapOutputTrackerWorker.mapStatuses to ConcurrentHashMap MapOutputTrackerWorker.mapStatuses is used concurrently, it should be thread-safe. This bug has already been fixed in #1328. Nevertheless, considering #1328 won't be merged soon, I send this trivial fix and hope this issue can be solved soon. Author: zsxwing Closes #1541 from zsxwing/SPARK-2634 and squashes the following commits: d450053 [zsxwing] SPARK-2634: Change MapOutputTrackerWorker.mapStatuses to ConcurrentHashMap --- .../main/scala/org/apache/spark/MapOutputTracker.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 51705c895a55c..f92189b707fb5 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -18,10 +18,12 @@ package org.apache.spark import java.io._ +import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.mutable.{HashSet, HashMap, Map} import scala.concurrent.Await +import scala.collection.JavaConversions._ import akka.actor._ import akka.pattern.ask @@ -84,6 +86,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * On the master, it serves as the source of map outputs recorded from ShuffleMapTasks. 
* On the workers, it simply serves as a cache, in which a miss triggers a fetch from the * master's corresponding HashMap. + * + * Note: because mapStatuses is accessed concurrently, subclasses should make sure it's a + * thread-safe map. */ protected val mapStatuses: Map[Int, Array[MapStatus]] @@ -339,7 +344,8 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) * MapOutputTrackerMaster. */ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) { - protected val mapStatuses = new HashMap[Int, Array[MapStatus]] + protected val mapStatuses: Map[Int, Array[MapStatus]] = + new ConcurrentHashMap[Int, Array[MapStatus]] } private[spark] object MapOutputTracker { From b235e013638685758885842dc3268e9800af3678 Mon Sep 17 00:00:00 2001 From: Hari Shreedharan Date: Thu, 25 Sep 2014 22:56:43 -0700 Subject: [PATCH 06/50] [SPARK-3686][STREAMING] Wait for sink to commit the channel before check... ...ing for the channel size. Author: Hari Shreedharan Closes #2531 from harishreedharan/sparksinksuite-fix and squashes the following commits: 30393c1 [Hari Shreedharan] Use more deterministic method to figure out when batches come in. 6ce9d8b [Hari Shreedharan] [SPARK-3686][STREAMING] Wait for sink to commit the channel before checking for the channel size. --- .../flume/sink/SparkAvroCallbackHandler.scala | 14 +++++++++++- .../streaming/flume/sink/SparkSink.scala | 10 +++++++++ .../flume/sink/TransactionProcessor.scala | 12 ++++++++++ .../streaming/flume/sink/SparkSinkSuite.scala | 22 +++++++++++-------- 4 files changed, 48 insertions(+), 10 deletions(-) diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala index e77cf7bfa54d0..3c656a381bd9b 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.streaming.flume.sink -import java.util.concurrent.{ConcurrentHashMap, Executors} +import java.util.concurrent.{CountDownLatch, ConcurrentHashMap, Executors} import java.util.concurrent.atomic.AtomicLong import scala.collection.JavaConversions._ @@ -58,8 +58,12 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha private val seqBase = RandomStringUtils.randomAlphanumeric(8) private val seqCounter = new AtomicLong(0) + @volatile private var stopped = false + @volatile private var isTest = false + private var testLatch: CountDownLatch = null + /** * Returns a bunch of events to Spark over Avro RPC. * @param n Maximum number of events to return in a batch @@ -90,6 +94,9 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha val processor = new TransactionProcessor( channel, seq, n, transactionTimeout, backOffInterval, this) sequenceNumberToProcessor.put(seq, processor) + if (isTest) { + processor.countDownWhenBatchAcked(testLatch) + } Some(processor) } else { None @@ -141,6 +148,11 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha } } + private[sink] def countDownWhenBatchAcked(latch: CountDownLatch) { + testLatch = latch + isTest = true + } + /** * Shuts down the executor used to process transactions. 
*/ diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala index 98ae7d783aec8..14dffb15fef98 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala @@ -138,6 +138,16 @@ class SparkSink extends AbstractSink with Logging with Configurable { throw new RuntimeException("Server was not started!") ) } + + /** + * Pass in a [[CountDownLatch]] for testing purposes. This batch is counted down when each + * batch is received. The test can simply call await on this latch till the expected number of + * batches are received. + * @param latch + */ + private[flume] def countdownWhenBatchReceived(latch: CountDownLatch) { + handler.foreach(_.countDownWhenBatchAcked(latch)) + } } /** diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala index 13f3aa94be414..ea45b14294df9 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala @@ -62,6 +62,10 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, @volatile private var stopped = false + @volatile private var isTest = false + + private var testLatch: CountDownLatch = null + // The transaction that this processor would handle var txOpt: Option[Transaction] = None @@ -182,6 +186,9 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, rollbackAndClose(tx, close = false) // tx will be closed later anyway } finally { tx.close() + if (isTest) { + testLatch.countDown() + } } } else { logWarning("Spark could not commit transaction, NACK received. 
Rolling back transaction.") @@ -237,4 +244,9 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, processAckOrNack() null } + + private[sink] def countDownWhenBatchAcked(latch: CountDownLatch) { + testLatch = latch + isTest = true + } } diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index 75a6668c6210b..a2b2cc6149d95 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -38,7 +38,7 @@ class SparkSinkSuite extends FunSuite { val channelCapacity = 5000 test("Success with ack") { - val (channel, sink) = initializeChannelAndSink() + val (channel, sink, latch) = initializeChannelAndSink() channel.start() sink.start() @@ -51,6 +51,7 @@ class SparkSinkSuite extends FunSuite { val events = client.getEventBatch(1000) client.ack(events.getSequenceNumber) assert(events.getEvents.size() === 1000) + latch.await(1, TimeUnit.SECONDS) assertChannelIsEmpty(channel) sink.stop() channel.stop() @@ -58,7 +59,7 @@ class SparkSinkSuite extends FunSuite { } test("Failure with nack") { - val (channel, sink) = initializeChannelAndSink() + val (channel, sink, latch) = initializeChannelAndSink() channel.start() sink.start() putEvents(channel, eventsPerBatch) @@ -70,6 +71,7 @@ class SparkSinkSuite extends FunSuite { val events = client.getEventBatch(1000) assert(events.getEvents.size() === 1000) client.nack(events.getSequenceNumber) + latch.await(1, TimeUnit.SECONDS) assert(availableChannelSlots(channel) === 4000) sink.stop() channel.stop() @@ -77,7 +79,7 @@ class SparkSinkSuite extends FunSuite { } test("Failure with timeout") { - val (channel, sink) = initializeChannelAndSink(Map(SparkSinkConfig + val (channel, sink, latch) = initializeChannelAndSink(Map(SparkSinkConfig .CONF_TRANSACTION_TIMEOUT -> 1.toString)) channel.start() sink.start() @@ -88,7 +90,7 @@ class SparkSinkSuite extends FunSuite { val (transceiver, client) = getTransceiverAndClient(address, 1)(0) val events = client.getEventBatch(1000) assert(events.getEvents.size() === 1000) - Thread.sleep(1000) + latch.await(1, TimeUnit.SECONDS) assert(availableChannelSlots(channel) === 4000) sink.stop() channel.stop() @@ -106,7 +108,7 @@ class SparkSinkSuite extends FunSuite { def testMultipleConsumers(failSome: Boolean): Unit = { implicit val executorContext = ExecutionContext .fromExecutorService(Executors.newFixedThreadPool(5)) - val (channel, sink) = initializeChannelAndSink() + val (channel, sink, latch) = initializeChannelAndSink(Map.empty, 5) channel.start() sink.start() (1 to 5).foreach(_ => putEvents(channel, eventsPerBatch)) @@ -136,7 +138,7 @@ class SparkSinkSuite extends FunSuite { } }) batchCounter.await() - TimeUnit.SECONDS.sleep(1) // Allow the sink to commit the transactions. 
+ latch.await(1, TimeUnit.SECONDS) executorContext.shutdown() if(failSome) { assert(availableChannelSlots(channel) === 3000) @@ -148,8 +150,8 @@ class SparkSinkSuite extends FunSuite { transceiversAndClients.foreach(x => x._1.close()) } - private def initializeChannelAndSink(overrides: Map[String, String] = Map.empty): (MemoryChannel, - SparkSink) = { + private def initializeChannelAndSink(overrides: Map[String, String] = Map.empty, + batchCounter: Int = 1): (MemoryChannel, SparkSink, CountDownLatch) = { val channel = new MemoryChannel() val channelContext = new Context() @@ -165,7 +167,9 @@ class SparkSinkSuite extends FunSuite { sinkContext.put(SparkSinkConfig.CONF_PORT, 0.toString) sink.configure(sinkContext) sink.setChannel(channel) - (channel, sink) + val latch = new CountDownLatch(batchCounter) + sink.countdownWhenBatchReceived(latch) + (channel, sink, latch) } private def putEvents(ch: MemoryChannel, count: Int): Unit = { From 1aa549ba9839565274a12c52fa1075b424f138a6 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 26 Sep 2014 09:27:42 -0700 Subject: [PATCH 07/50] [SPARK-3478] [PySpark] Profile the Python tasks This patch add profiling support for PySpark, it will show the profiling results before the driver exits, here is one example: ``` ============================================================ Profile of RDD ============================================================ 5146507 function calls (5146487 primitive calls) in 71.094 seconds Ordered by: internal time, cumulative time ncalls tottime percall cumtime percall filename:lineno(function) 5144576 68.331 0.000 68.331 0.000 statcounter.py:44(merge) 20 2.735 0.137 71.071 3.554 statcounter.py:33(__init__) 20 0.017 0.001 0.017 0.001 {cPickle.dumps} 1024 0.003 0.000 0.003 0.000 t.py:16() 20 0.001 0.000 0.001 0.000 {reduce} 21 0.001 0.000 0.001 0.000 {cPickle.loads} 20 0.001 0.000 0.001 0.000 copy_reg.py:95(_slotnames) 41 0.001 0.000 0.001 0.000 serializers.py:461(read_int) 40 0.001 0.000 0.002 0.000 serializers.py:179(_batched) 62 0.000 0.000 0.000 0.000 {method 'read' of 'file' objects} 20 0.000 0.000 71.072 3.554 rdd.py:863() 20 0.000 0.000 0.001 0.000 serializers.py:198(load_stream) 40/20 0.000 0.000 71.072 3.554 rdd.py:2093(pipeline_func) 41 0.000 0.000 0.002 0.000 serializers.py:130(load_stream) 40 0.000 0.000 71.072 1.777 rdd.py:304(func) 20 0.000 0.000 71.094 3.555 worker.py:82(process) ``` Also, use can show profile result manually by `sc.show_profiles()` or dump it into disk by `sc.dump_profiles(path)`, such as ```python >>> sc._conf.set("spark.python.profile", "true") >>> rdd = sc.parallelize(range(100)).map(str) >>> rdd.count() 100 >>> sc.show_profiles() ============================================================ Profile of RDD ============================================================ 284 function calls (276 primitive calls) in 0.001 seconds Ordered by: internal time, cumulative time ncalls tottime percall cumtime percall filename:lineno(function) 4 0.000 0.000 0.000 0.000 serializers.py:198(load_stream) 4 0.000 0.000 0.000 0.000 {reduce} 12/4 0.000 0.000 0.001 0.000 rdd.py:2092(pipeline_func) 4 0.000 0.000 0.000 0.000 {cPickle.loads} 4 0.000 0.000 0.000 0.000 {cPickle.dumps} 104 0.000 0.000 0.000 0.000 rdd.py:852() 8 0.000 0.000 0.000 0.000 serializers.py:461(read_int) 12 0.000 0.000 0.000 0.000 rdd.py:303(func) ``` The profiling is disabled by default, can be enabled by "spark.python.profile=true". 
Also, users can dump the results into disks automatically for future analysis, by "spark.python.profile.dump=path_to_dump" Author: Davies Liu Closes #2351 from davies/profiler and squashes the following commits: 7ef2aa0 [Davies Liu] bugfix, add tests for show_profiles and dump_profiles() 2b0daf2 [Davies Liu] fix docs 7a56c24 [Davies Liu] bugfix cba9463 [Davies Liu] move show_profiles and dump_profiles to SparkContext fb9565b [Davies Liu] Merge branch 'master' of github.com:apache/spark into profiler 116d52a [Davies Liu] Merge branch 'master' of github.com:apache/spark into profiler 09d02c3 [Davies Liu] Merge branch 'master' into profiler c23865c [Davies Liu] Merge branch 'master' into profiler 15d6f18 [Davies Liu] add docs for two configs dadee1a [Davies Liu] add docs string and clear profiles after show or dump 4f8309d [Davies Liu] address comment, add tests 0a5b6eb [Davies Liu] fix Python UDF 4b20494 [Davies Liu] add profile for python --- docs/configuration.md | 19 +++++++++++++++++ python/pyspark/accumulators.py | 15 +++++++++++++ python/pyspark/context.py | 39 +++++++++++++++++++++++++++++++++- python/pyspark/rdd.py | 10 +++++++-- python/pyspark/sql.py | 2 +- python/pyspark/tests.py | 30 ++++++++++++++++++++++++++ python/pyspark/worker.py | 19 ++++++++++++++--- 7 files changed, 127 insertions(+), 7 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index a6dd7245e1552..791b6f2aa3261 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -206,6 +206,25 @@ Apart from these, the following properties are also available, and may be useful used during aggregation goes above this amount, it will spill the data into disks. + + spark.python.profile + false + + Enable profiling in Python worker, the profile result will show up by `sc.show_profiles()`, + or it will be displayed before the driver exiting. It also can be dumped into disk by + `sc.dump_profiles(path)`. If some of the profile results had been displayed maually, + they will not be displayed automatically before driver exiting. + + + + spark.python.profile.dump + (none) + + The directory which is used to dump the profile result before driver exiting. + The results will be dumped as separated file for each RDD. They can be loaded + by ptats.Stats(). If this is specified, the profile result will not be displayed + automatically. 
+ spark.python.worker.reuse true diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index ccbca67656c8d..b8cdbbe3cf2b6 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -215,6 +215,21 @@ def addInPlace(self, value1, value2): COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) +class PStatsParam(AccumulatorParam): + """PStatsParam is used to merge pstats.Stats""" + + @staticmethod + def zero(value): + return None + + @staticmethod + def addInPlace(value1, value2): + if value1 is None: + return value2 + value1.add(value2) + return value1 + + class _UpdateRequestHandler(SocketServer.StreamRequestHandler): """ diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 8e7b00469e246..abeda19b77d8b 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -20,6 +20,7 @@ import sys from threading import Lock from tempfile import NamedTemporaryFile +import atexit from pyspark import accumulators from pyspark.accumulators import Accumulator @@ -30,7 +31,6 @@ from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ PairDeserializer, CompressedSerializer from pyspark.storagelevel import StorageLevel -from pyspark import rdd from pyspark.rdd import RDD from pyspark.traceback_utils import CallSite, first_spark_call @@ -192,6 +192,9 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self._temp_dir = \ self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath() + # profiling stats collected for each PythonRDD + self._profile_stats = [] + def _initialize_context(self, jconf): """ Initialize SparkContext in function to allow subclass specific initialization @@ -792,6 +795,40 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal) return list(mappedRDD._collect_iterator_through_file(it)) + def _add_profile(self, id, profileAcc): + if not self._profile_stats: + dump_path = self._conf.get("spark.python.profile.dump") + if dump_path: + atexit.register(self.dump_profiles, dump_path) + else: + atexit.register(self.show_profiles) + + self._profile_stats.append([id, profileAcc, False]) + + def show_profiles(self): + """ Print the profile stats to stdout """ + for i, (id, acc, showed) in enumerate(self._profile_stats): + stats = acc.value + if not showed and stats: + print "=" * 60 + print "Profile of RDD" % id + print "=" * 60 + stats.sort_stats("tottime", "cumtime").print_stats() + # mark it as showed + self._profile_stats[i][2] = True + + def dump_profiles(self, path): + """ Dump the profile stats into directory `path` + """ + if not os.path.exists(path): + os.makedirs(path) + for id, acc, _ in self._profile_stats: + stats = acc.value + if stats: + p = os.path.join(path, "rdd_%d.pstats" % id) + stats.dump_stats(p) + self._profile_stats = [] + def _test(): import atexit diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 680140d72d03c..8ed89e2f9769f 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -15,7 +15,6 @@ # limitations under the License. 
# -from base64 import standard_b64encode as b64enc import copy from collections import defaultdict from itertools import chain, ifilter, imap @@ -32,6 +31,7 @@ from random import Random from math import sqrt, log, isinf, isnan +from pyspark.accumulators import PStatsParam from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ PickleSerializer, pack_long, AutoBatchedSerializer @@ -2080,7 +2080,9 @@ def _jrdd(self): return self._jrdd_val if self._bypass_serializer: self._jrdd_deserializer = NoOpSerializer() - command = (self.func, self._prev_jrdd_deserializer, + enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true" + profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None + command = (self.func, profileStats, self._prev_jrdd_deserializer, self._jrdd_deserializer) # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() @@ -2102,6 +2104,10 @@ def _jrdd(self): self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator) self._jrdd_val = python_rdd.asJavaRDD() + + if enable_profile: + self._id = self._jrdd_val.id() + self.ctx._add_profile(self._id, profileStats) return self._jrdd_val def id(self): diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 653195ea438cf..ee5bda8bb43d5 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -974,7 +974,7 @@ def registerFunction(self, name, f, returnType=StringType()): [Row(c0=4)] """ func = lambda _, it: imap(lambda x: f(*x), it) - command = (func, + command = (func, None, BatchedSerializer(PickleSerializer(), 1024), BatchedSerializer(PickleSerializer(), 1024)) ser = CloudPickleSerializer() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index d1bb2033b7a16..e6002afa9c70d 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -632,6 +632,36 @@ def test_distinct(self): self.assertEquals(result.count(), 3) +class TestProfiler(PySparkTestCase): + + def setUp(self): + self._old_sys_path = list(sys.path) + class_name = self.__class__.__name__ + conf = SparkConf().set("spark.python.profile", "true") + self.sc = SparkContext('local[4]', class_name, batchSize=2, conf=conf) + + def test_profiler(self): + + def heavy_foo(x): + for i in range(1 << 20): + x = 1 + rdd = self.sc.parallelize(range(100)) + rdd.foreach(heavy_foo) + profiles = self.sc._profile_stats + self.assertEqual(1, len(profiles)) + id, acc, _ = profiles[0] + stats = acc.value + self.assertTrue(stats is not None) + width, stat_list = stats.get_print_list([]) + func_names = [func_name for fname, n, func_name in stat_list] + self.assertTrue("heavy_foo" in func_names) + + self.sc.show_profiles() + d = tempfile.gettempdir() + self.sc.dump_profiles(d) + self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) + + class TestSQL(PySparkTestCase): def setUp(self): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index c1f6e3e4a1f40..8257dddfee1c3 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,6 +23,8 @@ import time import socket import traceback +import cProfile +import pstats from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry @@ -90,10 +92,21 @@ def main(infile, outfile): command = pickleSer._read_with_length(infile) if isinstance(command, Broadcast): command = pickleSer.loads(command.value) - (func, deserializer, serializer) = command + (func, stats, deserializer, serializer) = command 
init_time = time.time() - iterator = deserializer.load_stream(infile) - serializer.dump_stream(func(split_index, iterator), outfile) + + def process(): + iterator = deserializer.load_stream(infile) + serializer.dump_stream(func(split_index, iterator), outfile) + + if stats: + p = cProfile.Profile() + p.runcall(process) + st = pstats.Stats(p) + st.stream = None # make it picklable + stats.add(st.strip_dirs()) + else: + process() except Exception: try: write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) From d16e161d744b27291fd2ee7e3578917ee14d83f9 Mon Sep 17 00:00:00 2001 From: aniketbhatnagar Date: Fri, 26 Sep 2014 09:47:58 -0700 Subject: [PATCH 08/50] SPARK-3639 | Removed settings master in examples This patch removes setting of master as local in Kinesis examples so that users can set it using submit-job. Author: aniketbhatnagar Closes #2536 from aniketbhatnagar/Kinesis-Examples-Master-Unset and squashes the following commits: c9723ac [aniketbhatnagar] Merge remote-tracking branch 'origin/Kinesis-Examples-Master-Unset' into Kinesis-Examples-Master-Unset fec8ead [aniketbhatnagar] SPARK-3639 | Removed settings master in examples 31cdc59 [aniketbhatnagar] SPARK-3639 | Removed settings master in examples --- .../examples/streaming/JavaKinesisWordCountASL.java | 9 ++++----- .../examples/streaming/KinesisWordCountASL.scala | 13 +++++-------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java index aa917d0575c4c..b0bff27a61c19 100644 --- a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java +++ b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java @@ -71,6 +71,9 @@ * org.apache.spark.examples.streaming.JavaKinesisWordCountASL mySparkStream \ * https://kinesis.us-east-1.amazonaws.com * + * Note that number of workers/threads should be 1 more than the number of receivers. + * This leaves one thread available for actually processing the data. + * * There is a companion helper class called KinesisWordCountProducerASL which puts dummy data * onto the Kinesis stream. * Usage instructions for KinesisWordCountProducerASL are provided in the class definition. @@ -114,12 +117,8 @@ public static void main(String[] args) { /* In this example, we're going to create 1 Kinesis Worker/Receiver/DStream for each shard */ int numStreams = numShards; - /* Must add 1 more thread than the number of receivers or the output won't show properly from the driver */ - int numSparkThreads = numStreams + 1; - /* Setup the Spark config. */ - SparkConf sparkConfig = new SparkConf().setAppName("KinesisWordCount").setMaster( - "local[" + numSparkThreads + "]"); + SparkConf sparkConfig = new SparkConf().setAppName("KinesisWordCount"); /* Kinesis checkpoint interval. Same as batchInterval for this example. 
*/ Duration checkpointInterval = batchInterval; diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala index fffd90de08240..32da0858d1a1d 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -65,6 +65,10 @@ import org.apache.log4j.Level * org.apache.spark.examples.streaming.KinesisWordCountASL mySparkStream \ * https://kinesis.us-east-1.amazonaws.com * + * + * Note that number of workers/threads should be 1 more than the number of receivers. + * This leaves one thread available for actually processing the data. + * * There is a companion helper class below called KinesisWordCountProducerASL which puts * dummy data onto the Kinesis stream. * Usage instructions for KinesisWordCountProducerASL are provided in that class definition. @@ -97,17 +101,10 @@ private object KinesisWordCountASL extends Logging { /* In this example, we're going to create 1 Kinesis Worker/Receiver/DStream for each shard. */ val numStreams = numShards - /* - * numSparkThreads should be 1 more thread than the number of receivers. - * This leaves one thread available for actually processing the data. - */ - val numSparkThreads = numStreams + 1 - /* Setup the and SparkConfig and StreamingContext */ /* Spark Streaming batch interval */ - val batchInterval = Milliseconds(2000) + val batchInterval = Milliseconds(2000) val sparkConfig = new SparkConf().setAppName("KinesisWordCount") - .setMaster(s"local[$numSparkThreads]") val ssc = new StreamingContext(sparkConfig, batchInterval) /* Kinesis checkpoint interval. Same as batchInterval for this example. */ From ec9df6a765701fa41390083df12e1dc1fee50662 Mon Sep 17 00:00:00 2001 From: RJ Nowling Date: Fri, 26 Sep 2014 09:58:47 -0700 Subject: [PATCH 09/50] [SPARK-3614][MLLIB] Add minimumOccurence filtering to IDF This PR for [SPARK-3614](https://issues.apache.org/jira/browse/SPARK-3614) adds functionality for filtering out terms which do not appear in at least a minimum number of documents. This is implemented using a minimumOccurence parameter (default 0). When terms' document frequencies are less than minimumOccurence, their IDFs are set to 0, just like when the DF is 0. As a result, the TF-IDFs for the terms are found to be 0, as if the terms were not present in the documents. This PR makes the following changes: * Add a minimumOccurence parameter to the IDF and DocumentFrequencyAggregator classes. * Create a parameter-less constructor for IDF with a default minimumOccurence value of 0 to remain backwards-compatibility with the original IDF API. 
* Sets the IDFs to 0 for terms which DFs are less than minimumOccurence * Add tests to the Spark IDFSuite and Java JavaTfIdfSuite test suites * Updated the MLLib Feature Extraction programming guide to describe the new feature Author: RJ Nowling Closes #2494 from rnowling/spark-3614-idf-filter and squashes the following commits: 0aa3c63 [RJ Nowling] Fix identation e6523a8 [RJ Nowling] Remove unnecessary toDouble's from IDFSuite bfa82ec [RJ Nowling] Add space after if 30d20b3 [RJ Nowling] Add spaces around equals signs 9013447 [RJ Nowling] Add space before division operator 79978fc [RJ Nowling] Remove unnecessary semi-colon 40fd70c [RJ Nowling] Change minimumOccurence to minDocFreq in code and docs 47850ab [RJ Nowling] Changed minimumOccurence to Int from Long 9fb4093 [RJ Nowling] Remove unnecessary lines from IDF class docs 1fc09d8 [RJ Nowling] Add backwards-compatible constructor to DocumentFrequencyAggregator 1801fd2 [RJ Nowling] Fix style errors in IDF.scala 6897252 [RJ Nowling] Preface minimumOccurence members with val to make them final and immutable a200bab [RJ Nowling] Remove unnecessary else statement 4b974f5 [RJ Nowling] Remove accidentally-added import from testing c0cc643 [RJ Nowling] Add minimumOccurence filtering to IDF --- docs/mllib-feature-extraction.md | 15 ++++++++ .../org/apache/spark/mllib/feature/IDF.scala | 37 +++++++++++++++++-- .../spark/mllib/feature/JavaTfIdfSuite.java | 20 ++++++++++ .../apache/spark/mllib/feature/IDFSuite.scala | 36 +++++++++++++++++- 4 files changed, 103 insertions(+), 5 deletions(-) diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 41a27f6208d1b..1511ae6dda4ed 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -82,6 +82,21 @@ tf.cache() val idf = new IDF().fit(tf) val tfidf: RDD[Vector] = idf.transform(tf) {% endhighlight %} + +MLLib's IDF implementation provides an option for ignoring terms which occur in less than a +minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature +can be used by passing the `minDocFreq` value to the IDF constructor. + +{% highlight scala %} +import org.apache.spark.mllib.feature.IDF + +// ... continue from the previous example +tf.cache() +val idf = new IDF(minDocFreq = 2).fit(tf) +val tfidf: RDD[Vector] = idf.transform(tf) +{% endhighlight %} + + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index d40d5553c1d21..720bb70b08dbf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -30,9 +30,18 @@ import org.apache.spark.rdd.RDD * Inverse document frequency (IDF). * The standard formulation is used: `idf = log((m + 1) / (d(t) + 1))`, where `m` is the total * number of documents and `d(t)` is the number of documents that contain term `t`. + * + * This implementation supports filtering out terms which do not appear in a minimum number + * of documents (controlled by the variable `minDocFreq`). For terms that are not in + * at least `minDocFreq` documents, the IDF is found as 0, resulting in TF-IDFs of 0. + * + * @param minDocFreq minimum of documents in which a term + * should appear for filtering */ @Experimental -class IDF { +class IDF(val minDocFreq: Int) { + + def this() = this(0) // TODO: Allow different IDF formulations. 
@@ -41,7 +50,8 @@ class IDF { * @param dataset an RDD of term frequency vectors */ def fit(dataset: RDD[Vector]): IDFModel = { - val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator)( + val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator( + minDocFreq = minDocFreq))( seqOp = (df, v) => df.add(v), combOp = (df1, df2) => df1.merge(df2) ).idf() @@ -60,13 +70,16 @@ class IDF { private object IDF { /** Document frequency aggregator. */ - class DocumentFrequencyAggregator extends Serializable { + class DocumentFrequencyAggregator(val minDocFreq: Int) extends Serializable { /** number of documents */ private var m = 0L /** document frequency vector */ private var df: BDV[Long] = _ + + def this() = this(0) + /** Adds a new document. */ def add(doc: Vector): this.type = { if (isEmpty) { @@ -123,7 +136,18 @@ private object IDF { val inv = new Array[Double](n) var j = 0 while (j < n) { - inv(j) = math.log((m + 1.0)/ (df(j) + 1.0)) + /* + * If the term is not present in the minimum + * number of documents, set IDF to 0. This + * will cause multiplication in IDFModel to + * set TF-IDF to 0. + * + * Since arrays are initialized to 0 by default, + * we just omit changing those entries. + */ + if(df(j) >= minDocFreq) { + inv(j) = math.log((m + 1.0) / (df(j) + 1.0)) + } j += 1 } Vectors.dense(inv) @@ -140,6 +164,11 @@ class IDFModel private[mllib] (val idf: Vector) extends Serializable { /** * Transforms term frequency (TF) vectors to TF-IDF vectors. + * + * If `minDocFreq` was set for the IDF calculation, + * the terms which occur in fewer than `minDocFreq` + * documents will have an entry of 0. + * * @param dataset an RDD of term frequency vectors * @return an RDD of TF-IDF vectors */ diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java index e8d99f4ae43ae..064263e02cd11 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java @@ -63,4 +63,24 @@ public void tfIdf() { Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15); } } + + @Test + public void tfIdfMinimumDocumentFrequency() { + // The tests are to check Java compatibility. 
+ HashingTF tf = new HashingTF(); + JavaRDD> documents = sc.parallelize(Lists.newArrayList( + Lists.newArrayList("this is a sentence".split(" ")), + Lists.newArrayList("this is another sentence".split(" ")), + Lists.newArrayList("this is still a sentence".split(" "))), 2); + JavaRDD termFreqs = tf.transform(documents); + termFreqs.collect(); + IDF idf = new IDF(2); + JavaRDD tfIdfs = idf.fit(termFreqs).transform(termFreqs); + List localTfIdfs = tfIdfs.collect(); + int indexOfThis = tf.indexOf("this"); + for (Vector v: localTfIdfs) { + Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15); + } + } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala index 53d9c0c640b98..43974f84e3ca8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala @@ -38,7 +38,7 @@ class IDFSuite extends FunSuite with LocalSparkContext { val idf = new IDF val model = idf.fit(termFrequencies) val expected = Vectors.dense(Array(0, 3, 1, 2).map { x => - math.log((m.toDouble + 1.0) / (x + 1.0)) + math.log((m + 1.0) / (x + 1.0)) }) assert(model.idf ~== expected absTol 1e-12) val tfidf = model.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap() @@ -54,4 +54,38 @@ class IDFSuite extends FunSuite with LocalSparkContext { assert(tfidf2.indices === Array(1)) assert(tfidf2.values(0) ~== (1.0 * expected(1)) absTol 1e-12) } + + test("idf minimum document frequency filtering") { + val n = 4 + val localTermFrequencies = Seq( + Vectors.sparse(n, Array(1, 3), Array(1.0, 2.0)), + Vectors.dense(0.0, 1.0, 2.0, 3.0), + Vectors.sparse(n, Array(1), Array(1.0)) + ) + val m = localTermFrequencies.size + val termFrequencies = sc.parallelize(localTermFrequencies, 2) + val idf = new IDF(minDocFreq = 1) + val model = idf.fit(termFrequencies) + val expected = Vectors.dense(Array(0, 3, 1, 2).map { x => + if (x > 0) { + math.log((m + 1.0) / (x + 1.0)) + } else { + 0 + } + }) + assert(model.idf ~== expected absTol 1e-12) + val tfidf = model.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap() + assert(tfidf.size === 3) + val tfidf0 = tfidf(0L).asInstanceOf[SparseVector] + assert(tfidf0.indices === Array(1, 3)) + assert(Vectors.dense(tfidf0.values) ~== + Vectors.dense(1.0 * expected(1), 2.0 * expected(3)) absTol 1e-12) + val tfidf1 = tfidf(1L).asInstanceOf[DenseVector] + assert(Vectors.dense(tfidf1.values) ~== + Vectors.dense(0.0, 1.0 * expected(1), 2.0 * expected(2), 3.0 * expected(3)) absTol 1e-12) + val tfidf2 = tfidf(2L).asInstanceOf[SparseVector] + assert(tfidf2.indices === Array(1)) + assert(tfidf2.values(0) ~== (1.0 * expected(1)) absTol 1e-12) + } + } From 30461c6ac3dcfb05dc1891494ec161601c0fb59f Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Fri, 26 Sep 2014 11:26:53 -0700 Subject: [PATCH 10/50] [SPARK-3695]shuffle fetch fail output should output detailed host and port in error message Author: Daoyuan Wang Closes #2539 from adrian-wang/fetchfail and squashes the following commits: 6c1b1e0 [Daoyuan Wang] shuffle fetch fail output --- .../org/apache/spark/storage/ShuffleBlockFetcherIterator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index d868758a7f549..71b276b5f18e4 100644 --- 
a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -121,7 +121,7 @@ final class ShuffleBlockFetcherIterator( } override def onBlockFetchFailure(e: Throwable): Unit = { - logError("Failed to get block(s) from ${req.address.host}:${req.address.port}", e) + logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) // Note that there is a chance that some blocks have been fetched successfully, but we // still add them to the failed queue. This is fine because when the caller see a // FetchFailedException, it is going to fail the entire task anyway. From 8da10bf14660f1d5b1dab692cb56b9832ab10d40 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 26 Sep 2014 11:50:48 -0700 Subject: [PATCH 11/50] [SPARK-3476] Remove outdated memory checks in Yarn See description in [JIRA](https://issues.apache.org/jira/browse/SPARK-3476). Author: Andrew Or Closes #2528 from andrewor14/yarn-memory-checks and squashes the following commits: c5400cd [Andrew Or] Simplify checks e30ffac [Andrew Or] Remove outdated memory checks --- .../apache/spark/deploy/yarn/ClientArguments.scala | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 201b742736c6e..26dbd6237c6b8 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -69,16 +69,9 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf) * This is intended to be called only after the provided arguments have been parsed. 
*/ private def validateArgs(): Unit = { - // TODO: memory checks are outdated (SPARK-3476) - Map[Boolean, String]( - (numExecutors <= 0) -> "You must specify at least 1 executor!", - (amMemory <= amMemoryOverhead) -> s"AM memory must be > $amMemoryOverhead MB", - (executorMemory <= executorMemoryOverhead) -> - s"Executor memory must be > $executorMemoryOverhead MB" - ).foreach { case (errorCondition, errorMessage) => - if (errorCondition) { - throw new IllegalArgumentException(errorMessage + "\n" + getUsageMessage()) - } + if (numExecutors <= 0) { + throw new IllegalArgumentException( + "You must specify at least 1 executor!\n" + getUsageMessage()) } } From 0ec2d2e8f0c0dc61a7ed6377898846661d2424cd Mon Sep 17 00:00:00 2001 From: Daoyuan Wang Date: Fri, 26 Sep 2014 12:04:37 -0700 Subject: [PATCH 12/50] [SPARK-3531][SQL]select null from table would throw a MatchError Author: Daoyuan Wang Closes #2396 from adrian-wang/selectnull and squashes the following commits: 2458229 [Daoyuan Wang] rebase solution --- .../scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala | 1 + .../select null from table-0-5bb53cca754cc8afe9cd22feb8c586d1 | 1 + .../org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 3 +++ 3 files changed, 5 insertions(+) create mode 100644 sql/hive/src/test/resources/golden/select null from table-0-5bb53cca754cc8afe9cd22feb8c586d1 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 6b4399e852c7b..9a0b9b46ac4ee 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -244,6 +244,7 @@ object HiveMetastoreTypes extends RegexParsers { case BooleanType => "boolean" case DecimalType => "decimal" case TimestampType => "timestamp" + case NullType => "void" } } diff --git a/sql/hive/src/test/resources/golden/select null from table-0-5bb53cca754cc8afe9cd22feb8c586d1 b/sql/hive/src/test/resources/golden/select null from table-0-5bb53cca754cc8afe9cd22feb8c586d1 new file mode 100644 index 0000000000000..7951defec192a --- /dev/null +++ b/sql/hive/src/test/resources/golden/select null from table-0-5bb53cca754cc8afe9cd22feb8c586d1 @@ -0,0 +1 @@ +NULL diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 426f5fcee6157..2f876cafaf218 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -330,6 +330,9 @@ class HiveQuerySuite extends HiveComparisonTest { createQueryTest("timestamp cast #8", "SELECT CAST(CAST(-1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") + createQueryTest("select null from table", + "SELECT null FROM src LIMIT 1") + test("implement identity function using case statement") { val actual = sql("SELECT (CASE key WHEN key THEN key END) FROM src") .map { case Row(i: Int) => i } From 7364fa5a176da69e425bca0e3e137ee73275c78c Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Fri, 26 Sep 2014 12:06:01 -0700 Subject: [PATCH 13/50] [SPARK-3393] [SQL] Align the log4j configuration for Spark & SparkSQLCLI User may be confused for the HQL logging & configurations, we'd better provide a default templates. Both files are copied from Hive. 
Author: Cheng Hao Closes #2263 from chenghao-intel/hive_template and squashes the following commits: 53bffa9 [Cheng Hao] Remove the hive-log4j.properties initialization --- .../hive/thriftserver/SparkSQLCLIDriver.scala | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index b092f42372171..7ba4564602ecd 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -73,18 +73,6 @@ private[hive] object SparkSQLCLIDriver { System.exit(1) } - // NOTE: It is critical to do this here so that log4j is reinitialized - // before any of the other core hive classes are loaded - var logInitFailed = false - var logInitDetailMessage: String = null - try { - logInitDetailMessage = LogUtils.initHiveLog4j() - } catch { - case e: LogInitializationException => - logInitFailed = true - logInitDetailMessage = e.getMessage - } - val sessionState = new CliSessionState(new HiveConf(classOf[SessionState])) sessionState.in = System.in @@ -100,11 +88,6 @@ private[hive] object SparkSQLCLIDriver { System.exit(2) } - if (!sessionState.getIsSilent) { - if (logInitFailed) System.err.println(logInitDetailMessage) - else SessionState.getConsole.printInfo(logInitDetailMessage) - } - // Set all properties specified via command line. val conf: HiveConf = sessionState.getConf sessionState.cmdProperties.entrySet().foreach { item: java.util.Map.Entry[Object, Object] => From f872e4fb80b8429800daa9c44c0cac620c1ff303 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 26 Sep 2014 14:47:14 -0700 Subject: [PATCH 14/50] Revert "[SPARK-3478] [PySpark] Profile the Python tasks" This reverts commit 1aa549ba9839565274a12c52fa1075b424f138a6. --- docs/configuration.md | 19 ----------------- python/pyspark/accumulators.py | 15 ------------- python/pyspark/context.py | 39 +--------------------------------- python/pyspark/rdd.py | 10 ++------- python/pyspark/sql.py | 2 +- python/pyspark/tests.py | 30 -------------------------- python/pyspark/worker.py | 19 +++-------------- 7 files changed, 7 insertions(+), 127 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index 791b6f2aa3261..a6dd7245e1552 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -206,25 +206,6 @@ Apart from these, the following properties are also available, and may be useful used during aggregation goes above this amount, it will spill the data into disks. - - spark.python.profile - false - - Enable profiling in Python worker, the profile result will show up by `sc.show_profiles()`, - or it will be displayed before the driver exiting. It also can be dumped into disk by - `sc.dump_profiles(path)`. If some of the profile results had been displayed maually, - they will not be displayed automatically before driver exiting. - - - - spark.python.profile.dump - (none) - - The directory which is used to dump the profile result before driver exiting. - The results will be dumped as separated file for each RDD. They can be loaded - by ptats.Stats(). If this is specified, the profile result will not be displayed - automatically. 
- spark.python.worker.reuse true diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index b8cdbbe3cf2b6..ccbca67656c8d 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -215,21 +215,6 @@ def addInPlace(self, value1, value2): COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) -class PStatsParam(AccumulatorParam): - """PStatsParam is used to merge pstats.Stats""" - - @staticmethod - def zero(value): - return None - - @staticmethod - def addInPlace(value1, value2): - if value1 is None: - return value2 - value1.add(value2) - return value1 - - class _UpdateRequestHandler(SocketServer.StreamRequestHandler): """ diff --git a/python/pyspark/context.py b/python/pyspark/context.py index abeda19b77d8b..8e7b00469e246 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -20,7 +20,6 @@ import sys from threading import Lock from tempfile import NamedTemporaryFile -import atexit from pyspark import accumulators from pyspark.accumulators import Accumulator @@ -31,6 +30,7 @@ from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ PairDeserializer, CompressedSerializer from pyspark.storagelevel import StorageLevel +from pyspark import rdd from pyspark.rdd import RDD from pyspark.traceback_utils import CallSite, first_spark_call @@ -192,9 +192,6 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self._temp_dir = \ self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath() - # profiling stats collected for each PythonRDD - self._profile_stats = [] - def _initialize_context(self, jconf): """ Initialize SparkContext in function to allow subclass specific initialization @@ -795,40 +792,6 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal) return list(mappedRDD._collect_iterator_through_file(it)) - def _add_profile(self, id, profileAcc): - if not self._profile_stats: - dump_path = self._conf.get("spark.python.profile.dump") - if dump_path: - atexit.register(self.dump_profiles, dump_path) - else: - atexit.register(self.show_profiles) - - self._profile_stats.append([id, profileAcc, False]) - - def show_profiles(self): - """ Print the profile stats to stdout """ - for i, (id, acc, showed) in enumerate(self._profile_stats): - stats = acc.value - if not showed and stats: - print "=" * 60 - print "Profile of RDD" % id - print "=" * 60 - stats.sort_stats("tottime", "cumtime").print_stats() - # mark it as showed - self._profile_stats[i][2] = True - - def dump_profiles(self, path): - """ Dump the profile stats into directory `path` - """ - if not os.path.exists(path): - os.makedirs(path) - for id, acc, _ in self._profile_stats: - stats = acc.value - if stats: - p = os.path.join(path, "rdd_%d.pstats" % id) - stats.dump_stats(p) - self._profile_stats = [] - def _test(): import atexit diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 8ed89e2f9769f..680140d72d03c 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -15,6 +15,7 @@ # limitations under the License. 
# +from base64 import standard_b64encode as b64enc import copy from collections import defaultdict from itertools import chain, ifilter, imap @@ -31,7 +32,6 @@ from random import Random from math import sqrt, log, isinf, isnan -from pyspark.accumulators import PStatsParam from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ PickleSerializer, pack_long, AutoBatchedSerializer @@ -2080,9 +2080,7 @@ def _jrdd(self): return self._jrdd_val if self._bypass_serializer: self._jrdd_deserializer = NoOpSerializer() - enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true" - profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None - command = (self.func, profileStats, self._prev_jrdd_deserializer, + command = (self.func, self._prev_jrdd_deserializer, self._jrdd_deserializer) # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() @@ -2104,10 +2102,6 @@ def _jrdd(self): self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator) self._jrdd_val = python_rdd.asJavaRDD() - - if enable_profile: - self._id = self._jrdd_val.id() - self.ctx._add_profile(self._id, profileStats) return self._jrdd_val def id(self): diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index ee5bda8bb43d5..653195ea438cf 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -974,7 +974,7 @@ def registerFunction(self, name, f, returnType=StringType()): [Row(c0=4)] """ func = lambda _, it: imap(lambda x: f(*x), it) - command = (func, None, + command = (func, BatchedSerializer(PickleSerializer(), 1024), BatchedSerializer(PickleSerializer(), 1024)) ser = CloudPickleSerializer() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index e6002afa9c70d..d1bb2033b7a16 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -632,36 +632,6 @@ def test_distinct(self): self.assertEquals(result.count(), 3) -class TestProfiler(PySparkTestCase): - - def setUp(self): - self._old_sys_path = list(sys.path) - class_name = self.__class__.__name__ - conf = SparkConf().set("spark.python.profile", "true") - self.sc = SparkContext('local[4]', class_name, batchSize=2, conf=conf) - - def test_profiler(self): - - def heavy_foo(x): - for i in range(1 << 20): - x = 1 - rdd = self.sc.parallelize(range(100)) - rdd.foreach(heavy_foo) - profiles = self.sc._profile_stats - self.assertEqual(1, len(profiles)) - id, acc, _ = profiles[0] - stats = acc.value - self.assertTrue(stats is not None) - width, stat_list = stats.get_print_list([]) - func_names = [func_name for fname, n, func_name in stat_list] - self.assertTrue("heavy_foo" in func_names) - - self.sc.show_profiles() - d = tempfile.gettempdir() - self.sc.dump_profiles(d) - self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) - - class TestSQL(PySparkTestCase): def setUp(self): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 8257dddfee1c3..c1f6e3e4a1f40 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,8 +23,6 @@ import time import socket import traceback -import cProfile -import pstats from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry @@ -92,21 +90,10 @@ def main(infile, outfile): command = pickleSer._read_with_length(infile) if isinstance(command, Broadcast): command = pickleSer.loads(command.value) - (func, stats, deserializer, serializer) = command + (func, deserializer, serializer) = command 
init_time = time.time() - - def process(): - iterator = deserializer.load_stream(infile) - serializer.dump_stream(func(split_index, iterator), outfile) - - if stats: - p = cProfile.Profile() - p.runcall(process) - st = pstats.Stats(p) - st.stream = None # make it picklable - stats.add(st.strip_dirs()) - else: - process() + iterator = deserializer.load_stream(infile) + serializer.dump_stream(func(split_index, iterator), outfile) except Exception: try: write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) From 5e34855cf04145cc3b7bae996c2a6e668f144a11 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Fri, 26 Sep 2014 21:29:54 -0700 Subject: [PATCH 15/50] [SPARK-3543] Write TaskContext in Java and expose it through a static accessor. Author: Prashant Sharma Author: Shashank Sharma Closes #2425 from ScrapCodes/SPARK-3543/withTaskContext and squashes the following commits: 8ae414c [Shashank Sharma] CR ee8bd00 [Prashant Sharma] Added internal API in docs comments. ddb8cbe [Prashant Sharma] Moved setting the thread local to where TaskContext is instantiated. a7d5e23 [Prashant Sharma] Added doc comments. edf945e [Prashant Sharma] Code review git add -A f716fd1 [Prashant Sharma] introduced thread local for getting the task context. 333c7d6 [Prashant Sharma] Translated Task context from scala to java. --- .../java/org/apache/spark/TaskContext.java | 274 ++++++++++++++++++ .../scala/org/apache/spark/TaskContext.scala | 126 -------- .../main/scala/org/apache/spark/rdd/RDD.scala | 1 + .../apache/spark/scheduler/DAGScheduler.scala | 4 +- .../org/apache/spark/scheduler/Task.scala | 6 +- .../java/org/apache/spark/JavaAPISuite.java | 2 +- .../org/apache/spark/CacheManagerSuite.scala | 2 +- 7 files changed, 284 insertions(+), 131 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/TaskContext.java delete mode 100644 core/src/main/scala/org/apache/spark/TaskContext.scala diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java new file mode 100644 index 0000000000000..09b8ce02bd3d8 --- /dev/null +++ b/core/src/main/java/org/apache/spark/TaskContext.java @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import scala.Function0; +import scala.Function1; +import scala.Unit; +import scala.collection.JavaConversions; + +import org.apache.spark.annotation.DeveloperApi; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.util.TaskCompletionListener; +import org.apache.spark.util.TaskCompletionListenerException; + +/** +* :: DeveloperApi :: +* Contextual information about a task which can be read or mutated during execution. +*/ +@DeveloperApi +public class TaskContext implements Serializable { + + private int stageId; + private int partitionId; + private long attemptId; + private boolean runningLocally; + private TaskMetrics taskMetrics; + + /** + * :: DeveloperApi :: + * Contextual information about a task which can be read or mutated during execution. + * + * @param stageId stage id + * @param partitionId index of the partition + * @param attemptId the number of attempts to execute this task + * @param runningLocally whether the task is running locally in the driver JVM + * @param taskMetrics performance metrics of the task + */ + @DeveloperApi + public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean runningLocally, + TaskMetrics taskMetrics) { + this.attemptId = attemptId; + this.partitionId = partitionId; + this.runningLocally = runningLocally; + this.stageId = stageId; + this.taskMetrics = taskMetrics; + } + + + /** + * :: DeveloperApi :: + * Contextual information about a task which can be read or mutated during execution. + * + * @param stageId stage id + * @param partitionId index of the partition + * @param attemptId the number of attempts to execute this task + * @param runningLocally whether the task is running locally in the driver JVM + */ + @DeveloperApi + public TaskContext(Integer stageId, Integer partitionId, Long attemptId, + Boolean runningLocally) { + this.attemptId = attemptId; + this.partitionId = partitionId; + this.runningLocally = runningLocally; + this.stageId = stageId; + this.taskMetrics = TaskMetrics.empty(); + } + + + /** + * :: DeveloperApi :: + * Contextual information about a task which can be read or mutated during execution. + * + * @param stageId stage id + * @param partitionId index of the partition + * @param attemptId the number of attempts to execute this task + */ + @DeveloperApi + public TaskContext(Integer stageId, Integer partitionId, Long attemptId) { + this.attemptId = attemptId; + this.partitionId = partitionId; + this.runningLocally = false; + this.stageId = stageId; + this.taskMetrics = TaskMetrics.empty(); + } + + private static ThreadLocal taskContext = + new ThreadLocal(); + + /** + * :: Internal API :: + * This is spark internal API, not intended to be called from user programs. + */ + public static void setTaskContext(TaskContext tc) { + taskContext.set(tc); + } + + public static TaskContext get() { + return taskContext.get(); + } + + /** + * :: Internal API :: + */ + public static void remove() { + taskContext.remove(); + } + + // List of callback functions to execute when the task completes. + private transient List onCompleteCallbacks = + new ArrayList(); + + // Whether the corresponding task has been killed. + private volatile Boolean interrupted = false; + + // Whether the task has completed. + private volatile Boolean completed = false; + + /** + * Checks whether the task has completed. 
+ */ + public Boolean isCompleted() { + return completed; + } + + /** + * Checks whether the task has been killed. + */ + public Boolean isInterrupted() { + return interrupted; + } + + /** + * Add a (Java friendly) listener to be executed on task completion. + * This will be called in all situation - success, failure, or cancellation. + *

+ * An example use is for HadoopRDD to register a callback to close the input stream. + */ + public TaskContext addTaskCompletionListener(TaskCompletionListener listener) { + onCompleteCallbacks.add(listener); + return this; + } + + /** + * Add a listener in the form of a Scala closure to be executed on task completion. + * This will be called in all situations - success, failure, or cancellation. + *

+ * An example use is for HadoopRDD to register a callback to close the input stream. + */ + public TaskContext addTaskCompletionListener(final Function1 f) { + onCompleteCallbacks.add(new TaskCompletionListener() { + @Override + public void onTaskCompletion(TaskContext context) { + f.apply(context); + } + }); + return this; + } + + /** + * Add a callback function to be executed on task completion. An example use + * is for HadoopRDD to register a callback to close the input stream. + * Will be called in any situation - success, failure, or cancellation. + * + * Deprecated: use addTaskCompletionListener + * + * @param f Callback function. + */ + @Deprecated + public void addOnCompleteCallback(final Function0 f) { + onCompleteCallbacks.add(new TaskCompletionListener() { + @Override + public void onTaskCompletion(TaskContext context) { + f.apply(); + } + }); + } + + /** + * ::Internal API:: + * Marks the task as completed and triggers the listeners. + */ + public void markTaskCompleted() throws TaskCompletionListenerException { + completed = true; + List errorMsgs = new ArrayList(2); + // Process complete callbacks in the reverse order of registration + List revlist = + new ArrayList(onCompleteCallbacks); + Collections.reverse(revlist); + for (TaskCompletionListener tcl: revlist) { + try { + tcl.onTaskCompletion(this); + } catch (Throwable e) { + errorMsgs.add(e.getMessage()); + } + } + + if (!errorMsgs.isEmpty()) { + throw new TaskCompletionListenerException(JavaConversions.asScalaBuffer(errorMsgs)); + } + } + + /** + * ::Internal API:: + * Marks the task for interruption, i.e. cancellation. + */ + public void markInterrupted() { + interrupted = true; + } + + @Deprecated + /** Deprecated: use getStageId() */ + public int stageId() { + return stageId; + } + + @Deprecated + /** Deprecated: use getPartitionId() */ + public int partitionId() { + return partitionId; + } + + @Deprecated + /** Deprecated: use getAttemptId() */ + public long attemptId() { + return attemptId; + } + + @Deprecated + /** Deprecated: use getRunningLocally() */ + public boolean runningLocally() { + return runningLocally; + } + + public boolean getRunningLocally() { + return runningLocally; + } + + public int getStageId() { + return stageId; + } + + public int getPartitionId() { + return partitionId; + } + + public long getAttemptId() { + return attemptId; + } + + /** ::Internal API:: */ + public TaskMetrics taskMetrics() { + return taskMetrics; + } +} diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala deleted file mode 100644 index 51b3e4d5e0936..0000000000000 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener} - - -/** - * :: DeveloperApi :: - * Contextual information about a task which can be read or mutated during execution. - * - * @param stageId stage id - * @param partitionId index of the partition - * @param attemptId the number of attempts to execute this task - * @param runningLocally whether the task is running locally in the driver JVM - * @param taskMetrics performance metrics of the task - */ -@DeveloperApi -class TaskContext( - val stageId: Int, - val partitionId: Int, - val attemptId: Long, - val runningLocally: Boolean = false, - private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty) - extends Serializable with Logging { - - @deprecated("use partitionId", "0.8.1") - def splitId = partitionId - - // List of callback functions to execute when the task completes. - @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] - - // Whether the corresponding task has been killed. - @volatile private var interrupted: Boolean = false - - // Whether the task has completed. - @volatile private var completed: Boolean = false - - /** Checks whether the task has completed. */ - def isCompleted: Boolean = completed - - /** Checks whether the task has been killed. */ - def isInterrupted: Boolean = interrupted - - // TODO: Also track whether the task has completed successfully or with exception. - - /** - * Add a (Java friendly) listener to be executed on task completion. - * This will be called in all situation - success, failure, or cancellation. - * - * An example use is for HadoopRDD to register a callback to close the input stream. - */ - def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { - onCompleteCallbacks += listener - this - } - - /** - * Add a listener in the form of a Scala closure to be executed on task completion. - * This will be called in all situation - success, failure, or cancellation. - * - * An example use is for HadoopRDD to register a callback to close the input stream. - */ - def addTaskCompletionListener(f: TaskContext => Unit): this.type = { - onCompleteCallbacks += new TaskCompletionListener { - override def onTaskCompletion(context: TaskContext): Unit = f(context) - } - this - } - - /** - * Add a callback function to be executed on task completion. An example use - * is for HadoopRDD to register a callback to close the input stream. - * Will be called in any situation - success, failure, or cancellation. - * @param f Callback function. - */ - @deprecated("use addTaskCompletionListener", "1.1.0") - def addOnCompleteCallback(f: () => Unit) { - onCompleteCallbacks += new TaskCompletionListener { - override def onTaskCompletion(context: TaskContext): Unit = f() - } - } - - /** Marks the task as completed and triggers the listeners. 
*/ - private[spark] def markTaskCompleted(): Unit = { - completed = true - val errorMsgs = new ArrayBuffer[String](2) - // Process complete callbacks in the reverse order of registration - onCompleteCallbacks.reverse.foreach { listener => - try { - listener.onTaskCompletion(this) - } catch { - case e: Throwable => - errorMsgs += e.getMessage - logError("Error in TaskCompletionListener", e) - } - } - if (errorMsgs.nonEmpty) { - throw new TaskCompletionListenerException(errorMsgs) - } - } - - /** Marks the task for interruption, i.e. cancellation. */ - private[spark] def markInterrupted(): Unit = { - interrupted = true - } -} diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 0e90caa5c9ca7..ba712c9d7776f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -619,6 +619,7 @@ abstract class RDD[T: ClassTag]( * should be `false` unless this is a pair RDD and the input function doesn't modify the keys. */ @DeveloperApi + @deprecated("use TaskContext.get", "1.2.0") def mapPartitionsWithContext[U: ClassTag]( f: (TaskContext, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b2774dfc47553..32cf29ed140e6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -634,12 +634,14 @@ class DAGScheduler( val rdd = job.finalStage.rdd val split = rdd.partitions(job.partitions(0)) val taskContext = - new TaskContext(job.finalStage.id, job.partitions(0), 0, runningLocally = true) + new TaskContext(job.finalStage.id, job.partitions(0), 0, true) + TaskContext.setTaskContext(taskContext) try { val result = job.func(taskContext, rdd.iterator(split, taskContext)) job.listener.taskSucceeded(0, result) } finally { taskContext.markTaskCompleted() + TaskContext.remove() } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 6aa0cca06878d..bf73f6f7bd0e1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -45,7 +45,8 @@ import org.apache.spark.util.Utils private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { final def run(attemptId: Long): T = { - context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false) + context = new TaskContext(stageId, partitionId, attemptId, false) + TaskContext.setTaskContext(context) context.taskMetrics.hostname = Utils.localHostName() taskThread = Thread.currentThread() if (_killed) { @@ -92,7 +93,8 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex if (interruptThread && taskThread != null) { taskThread.interrupt() } - } + TaskContext.remove() + } } /** diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index b8c23d524e00b..4a078435447e5 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -776,7 +776,7 @@ public void persist() { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - 
TaskContext context = new TaskContext(0, 0, 0, false, new TaskMetrics()); + TaskContext context = new TaskContext(0, 0, 0L, false, new TaskMetrics()); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 90dcadcffd091..d735010d7c9d5 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -94,7 +94,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0, runningLocally = true) + val context = new TaskContext(0, 0, 0, true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } From a3feaf04dc35069b80233fe7cccd62fc3072fc1f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 26 Sep 2014 21:44:10 -0700 Subject: [PATCH 16/50] Close #2194. From e976ca236f3c5578d8d7639b788774b1053b65f7 Mon Sep 17 00:00:00 2001 From: Sarah Gerweck Date: Fri, 26 Sep 2014 22:21:50 -0700 Subject: [PATCH 17/50] Slaves file is now a template. Change 0dc868e removed the `conf/slaves` file and made it a template like most of the other configuration files. This means you can no longer run `make-distribution.sh` unless you manually create a slaves file to be statically bundled in your distribution, which seems at odds with making it a template file. Author: Sarah Gerweck Closes #2549 from sarahgerweck/noMoreSlaves and squashes the following commits: d11d99a [Sarah Gerweck] Slaves file is now a template. --- make-distribution.sh | 1 - 1 file changed, 1 deletion(-) diff --git a/make-distribution.sh b/make-distribution.sh index 884659954a491..0bc839e1dbe4d 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -201,7 +201,6 @@ fi # Copy other things mkdir "$DISTDIR"/conf cp "$FWDIR"/conf/*.template "$DISTDIR"/conf -cp "$FWDIR"/conf/slaves "$DISTDIR"/conf cp "$FWDIR/README.md" "$DISTDIR" cp -r "$FWDIR/bin" "$DISTDIR" cp -r "$FWDIR/python" "$DISTDIR" From 0cdcdd2c9df98fb64d9d16ebace992fbba9c16b4 Mon Sep 17 00:00:00 2001 From: wangfei Date: Fri, 26 Sep 2014 22:23:49 -0700 Subject: [PATCH 18/50] [Build]remove spark-staging-1030 Since 1.1.0 has published, remove spark-staging-1030. 
Author: wangfei Closes #2532 from scwf/patch-2 and squashes the following commits: bc9e00b [wangfei] remove spark-staging-1030 --- pom.xml | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/pom.xml b/pom.xml index f3de097b9cb32..70cb9729ff6d3 100644 --- a/pom.xml +++ b/pom.xml @@ -222,18 +222,6 @@ false - - - spark-staging-1030 - Spark 1.1.0 Staging (1030) - https://repository.apache.org/content/repositories/orgapachespark-1030/ - - true - - - false - - From f0eea76d941c487763febbd9162600f89cedbd5c Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 26 Sep 2014 22:24:34 -0700 Subject: [PATCH 19/50] [SQL][DOCS] Clarify that the server is for JDBC and ODBC Author: Michael Armbrust Closes #2527 from marmbrus/patch-1 and squashes the following commits: a0f9f1c [Michael Armbrust] [SQL][DOCS] Clarify that the server is for JDBC and ODBC --- docs/sql-programming-guide.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index c1f80544bf0af..65249808fae3e 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -872,12 +872,12 @@ that these options will be deprecated in future release as more optimizations ar Spark SQL also supports interfaces for running SQL queries directly without the need to write any code. -## Running the Thrift JDBC server +## Running the Thrift JDBC/ODBC server -The Thrift JDBC server implemented here corresponds to the [`HiveServer2`](https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) +The Thrift JDBC/ODBC server implemented here corresponds to the [`HiveServer2`](https://cwiki.apache.org/confluence/display/Hive/Setting+Up+HiveServer2) in Hive 0.12. You can test the JDBC server with the beeline script that comes with either Spark or Hive 0.12. -To start the JDBC server, run the following in the Spark directory: +To start the JDBC/ODBC server, run the following in the Spark directory: ./sbin/start-thriftserver.sh @@ -906,11 +906,11 @@ or system properties: ``` {% endhighlight %} -Now you can use beeline to test the Thrift JDBC server: +Now you can use beeline to test the Thrift JDBC/ODBC server: ./bin/beeline -Connect to the JDBC server in beeline with: +Connect to the JDBC/ODBC server in beeline with: beeline> !connect jdbc:hive2://localhost:10000 From d8a9d1d442dd5612f82edaf2a780579c4d43dcfd Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 26 Sep 2014 22:30:12 -0700 Subject: [PATCH 20/50] [SPARK-3675][SQL] Allow starting a JDBC server on an existing context Author: Michael Armbrust Closes #2515 from marmbrus/jdbcExistingContext and squashes the following commits: 7866fad [Michael Armbrust] Allows starting a JDBC server on an existing context. 
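A minimal usage sketch of the hook added below, assuming an application that already holds a running SparkContext named `sc`; the cached table name is illustrative only:

    import org.apache.spark.sql.hive.HiveContext
    import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2

    // Reuse the application's existing context rather than launching a separate server process;
    // anything registered or cached in this HiveContext becomes visible over JDBC/ODBC.
    val hiveContext = new HiveContext(sc)
    hiveContext.sql("CACHE TABLE src")
    HiveThriftServer2.startWithContext(hiveContext)
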
--- .../sql/hive/thriftserver/HiveThriftServer2.scala | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index cadf7aaf42157..3d468d804622c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -26,6 +26,7 @@ import org.apache.hive.service.cli.thrift.ThriftBinaryCLIService import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} import org.apache.spark.Logging +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ @@ -33,9 +34,21 @@ import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a * `HiveThriftServer2` thrift server. */ -private[hive] object HiveThriftServer2 extends Logging { +object HiveThriftServer2 extends Logging { var LOG = LogFactory.getLog(classOf[HiveServer2]) + /** + * :: DeveloperApi :: + * Starts a new thrift server with the given context. + */ + @DeveloperApi + def startWithContext(sqlContext: HiveContext): Unit = { + val server = new HiveThriftServer2(sqlContext) + server.init(sqlContext.hiveconf) + server.start() + } + + def main(args: Array[String]) { val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") From 9e8ced7847d84d63f0da08b15623d558a2407583 Mon Sep 17 00:00:00 2001 From: Jeff Steinmetz Date: Fri, 26 Sep 2014 23:00:40 -0700 Subject: [PATCH 21/50] stop, start and destroy require the EC2_REGION i.e ./spark-ec2 --region=us-west-1 stop yourclustername Author: Jeff Steinmetz Closes #2473 from jeffsteinmetz/master and squashes the following commits: 7491f2c [Jeff Steinmetz] fix case in EC2 cluster setup documentation bd3d777 [Jeff Steinmetz] standardized ec2 documenation to use sample args 2bf4a57 [Jeff Steinmetz] standardized ec2 documenation to use sample args 68d8372 [Jeff Steinmetz] standardized ec2 documenation to use sample args d2ab6e2 [Jeff Steinmetz] standardized ec2 documenation to use sample args 520e6dc [Jeff Steinmetz] standardized ec2 documenation to use sample args 37fc876 [Jeff Steinmetz] stop, start and destroy require the EC2_REGION --- docs/ec2-scripts.md | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md index b2ca6a9b48f32..530798f2b8022 100644 --- a/docs/ec2-scripts.md +++ b/docs/ec2-scripts.md @@ -48,6 +48,15 @@ by looking for the "Name" tag of the instance in the Amazon EC2 Console. key pair, `` is the number of slave nodes to launch (try 1 at first), and `` is the name to give to your cluster. + + For example: + + ```bash + export AWS_SECRET_ACCESS_KEY=AaBbCcDdEeFGgHhIiJjKkLlMmNnOoPpQqRrSsTtU +export AWS_ACCESS_KEY_ID=ABCDEFG1234567890123 +./spark-ec2 --key-pair=awskey --identity-file=awskey.pem --region=us-west-1 --zone=us-west-1a --spark-version=1.1.0 launch my-spark-cluster + ``` + - After everything launches, check that the cluster scheduler is up and sees all the slaves by going to its web UI, which will be printed at the end of the script (typically `http://:8080`). 
@@ -55,27 +64,27 @@ by looking for the "Name" tag of the instance in the Amazon EC2 Console. You can also run `./spark-ec2 --help` to see more usage options. The following options are worth pointing out: -- `--instance-type=` can be used to specify an EC2 +- `--instance-type=` can be used to specify an EC2 instance type to use. For now, the script only supports 64-bit instance types, and the default type is `m1.large` (which has 2 cores and 7.5 GB RAM). Refer to the Amazon pages about [EC2 instance types](http://aws.amazon.com/ec2/instance-types) and [EC2 pricing](http://aws.amazon.com/ec2/#pricing) for information about other instance types. -- `--region=` specifies an EC2 region in which to launch +- `--region=` specifies an EC2 region in which to launch instances. The default region is `us-east-1`. -- `--zone=` can be used to specify an EC2 availability zone +- `--zone=` can be used to specify an EC2 availability zone to launch instances in. Sometimes, you will get an error because there is not enough capacity in one zone, and you should try to launch in another. -- `--ebs-vol-size=GB` will attach an EBS volume with a given amount +- `--ebs-vol-size=` will attach an EBS volume with a given amount of space to each node so that you can have a persistent HDFS cluster on your nodes across cluster restarts (see below). -- `--spot-price=PRICE` will launch the worker nodes as +- `--spot-price=` will launch the worker nodes as [Spot Instances](http://aws.amazon.com/ec2/spot-instances/), bidding for the given maximum price (in dollars). -- `--spark-version=VERSION` will pre-load the cluster with the - specified version of Spark. VERSION can be a version number +- `--spark-version=` will pre-load the cluster with the + specified version of Spark. The `` can be a version number (e.g. "0.7.3") or a specific git hash. By default, a recent version will be used. - If one of your launches fails due to e.g. not having the right @@ -137,11 +146,11 @@ cost you any EC2 cycles, but ***will*** continue to cost money for EBS storage. - To stop one of your clusters, go into the `ec2` directory and run -`./spark-ec2 stop `. +`./spark-ec2 --region= stop `. - To restart it later, run -`./spark-ec2 -i start `. +`./spark-ec2 -i --region= start `. - To ultimately destroy the cluster and stop consuming EBS space, run -`./spark-ec2 destroy ` as described in the previous +`./spark-ec2 --region= destroy ` as described in the previous section. # Limitations From 2d972fd84ac54a89e416442508a6d4eaeff452c1 Mon Sep 17 00:00:00 2001 From: Erik Erlandson Date: Fri, 26 Sep 2014 23:15:10 -0700 Subject: [PATCH 22/50] [SPARK-1021] Defer the data-driven computation of partition bounds in so... ...rtByKey() until evaluation. Author: Erik Erlandson Closes #1689 from erikerlandson/spark-1021-pr and squashes the following commits: 50b6da6 [Erik Erlandson] use standard getIteratorSize in countAsync 4e334a9 [Erik Erlandson] exception mystery fixed by fixing bug in ComplexFutureAction b88b5d4 [Erik Erlandson] tweak async actions to use ComplexFutureAction[T] so they handle RangePartitioner sampling job properly b2b20e8 [Erik Erlandson] Fix bug in exception passing with ComplexFutureAction[T] ca8913e [Erik Erlandson] RangePartition sampling job -> FutureAction 7143f97 [Erik Erlandson] [SPARK-1021] modify range bounds variable to be thread safe ac67195 [Erik Erlandson] [SPARK-1021] Defer the data-driven computation of partition bounds in sortByKey() until evaluation. 
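A small sketch of the observable effect, assuming an existing SparkContext `sc`: building the sorted RDD alone should no longer submit the RangePartitioner's sampling job; that work is deferred until an action evaluates the RDD.

    import org.apache.spark.SparkContext._
    import scala.util.Random

    val pairs = sc.parallelize(1 to 10000).map(i => (Random.nextInt(1000), i))
    val sorted = pairs.sortByKey()  // before this patch, constructing the partitioner already ran a sampling job
    sorted.take(5)                  // with the patch, the bounds are computed here, at evaluation time
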
--- .../scala/org/apache/spark/FutureAction.scala | 7 +- .../scala/org/apache/spark/Partitioner.scala | 29 +++++++-- .../apache/spark/rdd/AsyncRDDActions.scala | 64 +++++++++++-------- 3 files changed, 66 insertions(+), 34 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 75ea535f2f57b..c277c3a47d421 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -208,7 +208,7 @@ class ComplexFutureAction[T] extends FutureAction[T] { processPartition: Iterator[T] => U, partitions: Seq[Int], resultHandler: (Int, U) => Unit, - resultFunc: => R) { + resultFunc: => R): R = { // If the action hasn't been cancelled yet, submit the job. The check and the submitJob // command need to be in an atomic block. val job = this.synchronized { @@ -223,7 +223,10 @@ class ComplexFutureAction[T] extends FutureAction[T] { // cancel the job and stop the execution. This is not in a synchronized block because // Await.ready eventually waits on the monitor in FutureJob.jobWaiter. try { - Await.ready(job, Duration.Inf) + Await.ready(job, Duration.Inf).value.get match { + case scala.util.Failure(e) => throw e + case scala.util.Success(v) => v + } } catch { case e: InterruptedException => job.cancel() diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 37053bb6f37ad..d40b152d221c5 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -29,6 +29,10 @@ import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.{CollectionsUtils, Utils} import org.apache.spark.util.random.{XORShiftRandom, SamplingUtils} +import org.apache.spark.SparkContext.rddToAsyncRDDActions +import scala.concurrent.Await +import scala.concurrent.duration.Duration + /** * An object that defines how the elements in a key-value pair RDD are partitioned by key. * Maps each key to a partition ID, from 0 to `numPartitions - 1`. @@ -113,8 +117,12 @@ class RangePartitioner[K : Ordering : ClassTag, V]( private var ordering = implicitly[Ordering[K]] // An array of upper bounds for the first (partitions - 1) partitions - private var rangeBounds: Array[K] = { - if (partitions <= 1) { + @volatile private var valRB: Array[K] = null + + private def rangeBounds: Array[K] = this.synchronized { + if (valRB != null) return valRB + + valRB = if (partitions <= 1) { Array.empty } else { // This is the sample size we need to have roughly balanced output partitions, capped at 1M. 
@@ -152,6 +160,8 @@ class RangePartitioner[K : Ordering : ClassTag, V]( RangePartitioner.determineBounds(candidates, partitions) } } + + valRB } def numPartitions = rangeBounds.length + 1 @@ -222,7 +232,8 @@ class RangePartitioner[K : Ordering : ClassTag, V]( } @throws(classOf[IOException]) - private def readObject(in: ObjectInputStream) { + private def readObject(in: ObjectInputStream): Unit = this.synchronized { + if (valRB != null) return val sfactory = SparkEnv.get.serializer sfactory match { case js: JavaSerializer => in.defaultReadObject() @@ -234,7 +245,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( val ser = sfactory.newInstance() Utils.deserializeViaNestedStream(in, ser) { ds => implicit val classTag = ds.readObject[ClassTag[Array[K]]]() - rangeBounds = ds.readObject[Array[K]]() + valRB = ds.readObject[Array[K]]() } } } @@ -254,12 +265,18 @@ private[spark] object RangePartitioner { sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = { val shift = rdd.id // val classTagK = classTag[K] // to avoid serializing the entire partitioner object - val sketched = rdd.mapPartitionsWithIndex { (idx, iter) => + // use collectAsync here to run this job as a future, which is cancellable + val sketchFuture = rdd.mapPartitionsWithIndex { (idx, iter) => val seed = byteswap32(idx ^ (shift << 16)) val (sample, n) = SamplingUtils.reservoirSampleAndCount( iter, sampleSizePerPartition, seed) Iterator((idx, n, sample)) - }.collect() + }.collectAsync() + // We do need the future's value to continue any further + val sketched = Await.ready(sketchFuture, Duration.Inf).value.get match { + case scala.util.Success(v) => v.toArray + case scala.util.Failure(e) => throw e + } val numItems = sketched.map(_._2.toLong).sum (numItems, sketched) } diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index b62f3fbdc4a15..7a68b3afa8158 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -23,6 +23,7 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionContext.Implicits.global import scala.reflect.ClassTag +import org.apache.spark.util.Utils import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} import org.apache.spark.annotation.Experimental @@ -38,29 +39,30 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi * Returns a future for counting the number of elements in the RDD. */ def countAsync(): FutureAction[Long] = { - val totalCount = new AtomicLong - self.context.submitJob( - self, - (iter: Iterator[T]) => { - var result = 0L - while (iter.hasNext) { - result += 1L - iter.next() - } - result - }, - Range(0, self.partitions.size), - (index: Int, data: Long) => totalCount.addAndGet(data), - totalCount.get()) + val f = new ComplexFutureAction[Long] + f.run { + val totalCount = new AtomicLong + f.runJob(self, + (iter: Iterator[T]) => Utils.getIteratorSize(iter), + Range(0, self.partitions.size), + (index: Int, data: Long) => totalCount.addAndGet(data), + totalCount.get()) + } } /** * Returns a future for retrieving all elements of this RDD. 
*/ def collectAsync(): FutureAction[Seq[T]] = { - val results = new Array[Array[T]](self.partitions.size) - self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size), - (index, data) => results(index) = data, results.flatten.toSeq) + val f = new ComplexFutureAction[Seq[T]] + f.run { + val results = new Array[Array[T]](self.partitions.size) + f.runJob(self, + (iter: Iterator[T]) => iter.toArray, + Range(0, self.partitions.size), + (index: Int, data: Array[T]) => results(index) = data, + results.flatten.toSeq) + } } /** @@ -104,24 +106,34 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi } results.toSeq } - - f } /** * Applies a function f to all elements of this RDD. */ - def foreachAsync(f: T => Unit): FutureAction[Unit] = { - val cleanF = self.context.clean(f) - self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.size), - (index, data) => Unit, Unit) + def foreachAsync(expr: T => Unit): FutureAction[Unit] = { + val f = new ComplexFutureAction[Unit] + val exprClean = self.context.clean(expr) + f.run { + f.runJob(self, + (iter: Iterator[T]) => iter.foreach(exprClean), + Range(0, self.partitions.size), + (index: Int, data: Unit) => Unit, + Unit) + } } /** * Applies a function f to each partition of this RDD. */ - def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = { - self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size), - (index, data) => Unit, Unit) + def foreachPartitionAsync(expr: Iterator[T] => Unit): FutureAction[Unit] = { + val f = new ComplexFutureAction[Unit] + f.run { + f.runJob(self, + expr, + Range(0, self.partitions.size), + (index: Int, data: Unit) => Unit, + Unit) + } } } From 436a7730b6e7067f74b3739a3a412490003f7c4c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 27 Sep 2014 00:57:26 -0700 Subject: [PATCH 23/50] Minor cleanup to tighten visibility and remove compilation warning. Author: Reynold Xin Closes #2555 from rxin/cleanup and squashes the following commits: 6add199 [Reynold Xin] Minor cleanup to tighten visibility and remove compilation warning. --- .../input/WholeTextFileRecordReader.scala | 24 +++++----- .../apache/spark/metrics/MetricsSystem.scala | 28 ++++++----- .../spark/metrics/MetricsSystemSuite.scala | 33 +++++++------ .../streaming/StreamingContextSuite.scala | 47 ++++++++++--------- 4 files changed, 70 insertions(+), 62 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala index c3dabd2e79995..3564ab2e2a162 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala @@ -36,33 +36,31 @@ private[spark] class WholeTextFileRecordReader( index: Integer) extends RecordReader[String, String] { - private val path = split.getPath(index) - private val fs = path.getFileSystem(context.getConfiguration) + private[this] val path = split.getPath(index) + private[this] val fs = path.getFileSystem(context.getConfiguration) // True means the current file has been processed, then skip it. 
- private var processed = false + private[this] var processed = false - private val key = path.toString - private var value: String = null + private[this] val key = path.toString + private[this] var value: String = null - override def initialize(split: InputSplit, context: TaskAttemptContext) = {} + override def initialize(split: InputSplit, context: TaskAttemptContext): Unit = {} - override def close() = {} + override def close(): Unit = {} - override def getProgress = if (processed) 1.0f else 0.0f + override def getProgress: Float = if (processed) 1.0f else 0.0f - override def getCurrentKey = key + override def getCurrentKey: String = key - override def getCurrentValue = value + override def getCurrentValue: String = value - override def nextKeyValue = { + override def nextKeyValue(): Boolean = { if (!processed) { val fileIn = fs.open(path) val innerBuffer = ByteStreams.toByteArray(fileIn) - value = new Text(innerBuffer).toString Closeables.close(fileIn, false) - processed = true true } else { diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 6ef817d0e587e..fd316a89a1a10 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -63,15 +63,18 @@ import org.apache.spark.metrics.source.Source * * [options] is the specific property of this source or sink. */ -private[spark] class MetricsSystem private (val instance: String, - conf: SparkConf, securityMgr: SecurityManager) extends Logging { +private[spark] class MetricsSystem private ( + val instance: String, + conf: SparkConf, + securityMgr: SecurityManager) + extends Logging { - val confFile = conf.get("spark.metrics.conf", null) - val metricsConfig = new MetricsConfig(Option(confFile)) + private[this] val confFile = conf.get("spark.metrics.conf", null) + private[this] val metricsConfig = new MetricsConfig(Option(confFile)) - val sinks = new mutable.ArrayBuffer[Sink] - val sources = new mutable.ArrayBuffer[Source] - val registry = new MetricRegistry() + private val sinks = new mutable.ArrayBuffer[Sink] + private val sources = new mutable.ArrayBuffer[Source] + private val registry = new MetricRegistry() // Treat MetricsServlet as a special sink as it should be exposed to add handlers to web ui private var metricsServlet: Option[MetricsServlet] = None @@ -91,7 +94,7 @@ private[spark] class MetricsSystem private (val instance: String, sinks.foreach(_.stop) } - def report(): Unit = { + def report() { sinks.foreach(_.report()) } @@ -155,8 +158,8 @@ private[spark] object MetricsSystem { val SINK_REGEX = "^sink\\.(.+)\\.(.+)".r val SOURCE_REGEX = "^source\\.(.+)\\.(.+)".r - val MINIMAL_POLL_UNIT = TimeUnit.SECONDS - val MINIMAL_POLL_PERIOD = 1 + private[this] val MINIMAL_POLL_UNIT = TimeUnit.SECONDS + private[this] val MINIMAL_POLL_PERIOD = 1 def checkMinimalPollingPeriod(pollUnit: TimeUnit, pollPeriod: Int) { val period = MINIMAL_POLL_UNIT.convert(pollPeriod, pollUnit) @@ -166,7 +169,8 @@ private[spark] object MetricsSystem { } } - def createMetricsSystem(instance: String, conf: SparkConf, - securityMgr: SecurityManager): MetricsSystem = + def createMetricsSystem( + instance: String, conf: SparkConf, securityMgr: SecurityManager): MetricsSystem = { new MetricsSystem(instance, conf, securityMgr) + } } diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala 
index 96a5a1231813e..e42b181194727 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -17,42 +17,47 @@ package org.apache.spark.metrics -import org.scalatest.{BeforeAndAfter, FunSuite} +import org.apache.spark.metrics.source.Source +import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester} + import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.master.MasterSource -class MetricsSystemSuite extends FunSuite with BeforeAndAfter { +import scala.collection.mutable.ArrayBuffer + + +class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester{ var filePath: String = _ var conf: SparkConf = null var securityMgr: SecurityManager = null before { - filePath = getClass.getClassLoader.getResource("test_metrics_system.properties").getFile() + filePath = getClass.getClassLoader.getResource("test_metrics_system.properties").getFile conf = new SparkConf(false).set("spark.metrics.conf", filePath) securityMgr = new SecurityManager(conf) } test("MetricsSystem with default config") { val metricsSystem = MetricsSystem.createMetricsSystem("default", conf, securityMgr) - val sources = metricsSystem.sources - val sinks = metricsSystem.sinks + val sources = PrivateMethod[ArrayBuffer[Source]]('sources) + val sinks = PrivateMethod[ArrayBuffer[Source]]('sinks) - assert(sources.length === 0) - assert(sinks.length === 0) - assert(!metricsSystem.getServletHandlers.isEmpty) + assert(metricsSystem.invokePrivate(sources()).length === 0) + assert(metricsSystem.invokePrivate(sinks()).length === 0) + assert(metricsSystem.getServletHandlers.nonEmpty) } test("MetricsSystem with sources add") { val metricsSystem = MetricsSystem.createMetricsSystem("test", conf, securityMgr) - val sources = metricsSystem.sources - val sinks = metricsSystem.sinks + val sources = PrivateMethod[ArrayBuffer[Source]]('sources) + val sinks = PrivateMethod[ArrayBuffer[Source]]('sinks) - assert(sources.length === 0) - assert(sinks.length === 1) - assert(!metricsSystem.getServletHandlers.isEmpty) + assert(metricsSystem.invokePrivate(sources()).length === 0) + assert(metricsSystem.invokePrivate(sinks()).length === 1) + assert(metricsSystem.getServletHandlers.nonEmpty) val source = new MasterSource(null) metricsSystem.registerSource(source) - assert(sources.length === 1) + assert(metricsSystem.invokePrivate(sources()).length === 1) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index ebf83748ffa28..655cec1573f58 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -19,18 +19,18 @@ package org.apache.spark.streaming import java.util.concurrent.atomic.AtomicInteger -import scala.language.postfixOps +import org.scalatest.{Assertions, BeforeAndAfter, FunSuite} +import org.scalatest.concurrent.Timeouts +import org.scalatest.concurrent.Eventually._ +import org.scalatest.exceptions.TestFailedDueToTimeoutException +import org.scalatest.time.SpanSugar._ import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.util.Utils -import 
org.scalatest.{Assertions, BeforeAndAfter, FunSuite} -import org.scalatest.concurrent.Timeouts -import org.scalatest.concurrent.Eventually._ -import org.scalatest.exceptions.TestFailedDueToTimeoutException -import org.scalatest.time.SpanSugar._ + class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts with Logging { @@ -68,7 +68,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("from no conf + spark home + env") { ssc = new StreamingContext(master, appName, batchDuration, sparkHome, Nil, Map(envPair)) - assert(ssc.conf.getExecutorEnv.exists(_ == envPair)) + assert(ssc.conf.getExecutorEnv.contains(envPair)) } test("from conf with settings") { @@ -94,7 +94,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) myConf.set("spark.cleaner.ttl", "10") val ssc1 = new StreamingContext(myConf, batchDuration) - addInputStream(ssc1).register + addInputStream(ssc1).register() ssc1.start() val cp = new Checkpoint(ssc1, Time(1000)) assert(cp.sparkConfPairs.toMap.getOrElse("spark.cleaner.ttl", "-1") === "10") @@ -107,7 +107,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("start and stop state check") { ssc = new StreamingContext(master, appName, batchDuration) - addInputStream(ssc).register + addInputStream(ssc).register() assert(ssc.state === ssc.StreamingContextState.Initialized) ssc.start() @@ -118,7 +118,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("start multiple times") { ssc = new StreamingContext(master, appName, batchDuration) - addInputStream(ssc).register + addInputStream(ssc).register() ssc.start() intercept[SparkException] { ssc.start() @@ -127,7 +127,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("stop multiple times") { ssc = new StreamingContext(master, appName, batchDuration) - addInputStream(ssc).register + addInputStream(ssc).register() ssc.start() ssc.stop() ssc.stop() @@ -135,7 +135,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("stop before start and start after stop") { ssc = new StreamingContext(master, appName, batchDuration) - addInputStream(ssc).register + addInputStream(ssc).register() ssc.stop() // stop before start should not throw exception ssc.start() ssc.stop() @@ -147,12 +147,12 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("stop only streaming context") { ssc = new StreamingContext(master, appName, batchDuration) sc = ssc.sparkContext - addInputStream(ssc).register + addInputStream(ssc).register() ssc.start() - ssc.stop(false) + ssc.stop(stopSparkContext = false) assert(sc.makeRDD(1 to 100).collect().size === 100) ssc = new StreamingContext(sc, batchDuration) - addInputStream(ssc).register + addInputStream(ssc).register() ssc.start() ssc.stop() } @@ -167,11 +167,11 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w var runningCount = 0 TestReceiver.counter.set(1) val input = ssc.receiverStream(new TestReceiver) - input.count.foreachRDD(rdd => { + input.count().foreachRDD { rdd => val count = rdd.first() runningCount += count.toInt logInfo("Count = " + count + ", Running count = " + runningCount) - }) + } ssc.start() ssc.awaitTermination(500) ssc.stop(stopSparkContext = false, stopGracefully = true) @@ -191,7 +191,7 @@ class StreamingContextSuite extends 
FunSuite with BeforeAndAfter with Timeouts w test("awaitTermination") { ssc = new StreamingContext(master, appName, batchDuration) val inputStream = addInputStream(ssc) - inputStream.map(x => x).register + inputStream.map(x => x).register() // test whether start() blocks indefinitely or not failAfter(2000 millis) { @@ -215,7 +215,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w // test whether wait exits if context is stopped failAfter(10000 millis) { // 10 seconds because spark takes a long time to shutdown new Thread() { - override def run { + override def run() { Thread.sleep(500) ssc.stop() } @@ -239,8 +239,9 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("awaitTermination with error in task") { ssc = new StreamingContext(master, appName, batchDuration) val inputStream = addInputStream(ssc) - inputStream.map(x => { throw new TestException("error in map task"); x}) - .foreachRDD(_.count) + inputStream + .map { x => throw new TestException("error in map task"); x } + .foreachRDD(_.count()) val exception = intercept[Exception] { ssc.start() @@ -252,7 +253,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w test("awaitTermination with error in job generation") { ssc = new StreamingContext(master, appName, batchDuration) val inputStream = addInputStream(ssc) - inputStream.transform(rdd => { throw new TestException("error in transform"); rdd }).register + inputStream.transform { rdd => throw new TestException("error in transform"); rdd }.register() val exception = intercept[TestException] { ssc.start() ssc.awaitTermination(5000) @@ -265,7 +266,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w } def addInputStream(s: StreamingContext): DStream[Int] = { - val input = (1 to 100).map(i => (1 to i)) + val input = (1 to 100).map(i => 1 to i) val inputStream = new TestInputStream(s, input, 1) inputStream } From 66107f46f374f83729cd79ab260eb59fa123c041 Mon Sep 17 00:00:00 2001 From: CrazyJvm Date: Sat, 27 Sep 2014 09:41:04 -0700 Subject: [PATCH 24/50] Docs : use "--total-executor-cores" rather than "--cores" after spark-shell Author: CrazyJvm Closes #2540 from CrazyJvm/standalone-core and squashes the following commits: 66d9fc6 [CrazyJvm] use "--total-executor-cores" rather than "--cores" after spark-shell --- docs/spark-standalone.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 58103fab20819..a3028aa86dc45 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -247,7 +247,7 @@ To run an interactive Spark shell against the cluster, run the following command ./bin/spark-shell --master spark://IP:PORT -You can also pass an option `--cores ` to control the number of cores that spark-shell uses on the cluster. +You can also pass an option `--total-executor-cores ` to control the number of cores that spark-shell uses on the cluster. 
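For example, a launch of the shell against a standalone master that caps the total cores it may claim across the cluster might look like the following (the master URL and the value 4 are placeholders, not values taken from this patch):

    ./bin/spark-shell --master spark://IP:PORT --total-executor-cores 4
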
# Launching Compiled Spark Applications From 0800881051df8029afb22a4ec17970e316a85855 Mon Sep 17 00:00:00 2001 From: w00228970 Date: Sat, 27 Sep 2014 12:06:06 -0700 Subject: [PATCH 25/50] [SPARK-3676][SQL] Fix hive test suite failure due to diffs in JDK 1.6/1.7 This is a bug in JDK6: http://bugs.java.com/bugdatabase/view_bug.do?bug_id=4428022 this is because jdk get different result to operate ```double```, ```System.out.println(1/500d)``` in different jdk get different result jdk 1.6.0(_31) ---- 0.0020 jdk 1.7.0(_05) ---- 0.002 this leads to HiveQuerySuite failed when generate golden answer in jdk 1.7 and run tests in jdk 1.6, result did not match Author: w00228970 Closes #2517 from scwf/HiveQuerySuite and squashes the following commits: 0cb5e8d [w00228970] delete golden answer of division-0 and timestamp cast #1 1df3964 [w00228970] Jdk version leads to different query output for Double, this make HiveQuerySuite failed --- .../division-0-63b19f8a22471c8ba0415c1d3bc276f7 | 1 - ...amp cast #1-0-69fc614ccea92bbe39f4decc299edcc6 | 1 - .../spark/sql/hive/execution/HiveQuerySuite.scala | 15 +++++++++++---- 3 files changed, 11 insertions(+), 6 deletions(-) delete mode 100644 sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7 delete mode 100644 sql/hive/src/test/resources/golden/timestamp cast #1-0-69fc614ccea92bbe39f4decc299edcc6 diff --git a/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7 b/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7 deleted file mode 100644 index 7b7a9175114ce..0000000000000 --- a/sql/hive/src/test/resources/golden/division-0-63b19f8a22471c8ba0415c1d3bc276f7 +++ /dev/null @@ -1 +0,0 @@ -2.0 0.5 0.3333333333333333 0.002 diff --git a/sql/hive/src/test/resources/golden/timestamp cast #1-0-69fc614ccea92bbe39f4decc299edcc6 b/sql/hive/src/test/resources/golden/timestamp cast #1-0-69fc614ccea92bbe39f4decc299edcc6 deleted file mode 100644 index 8ebf695ba7d20..0000000000000 --- a/sql/hive/src/test/resources/golden/timestamp cast #1-0-69fc614ccea92bbe39f4decc299edcc6 +++ /dev/null @@ -1 +0,0 @@ -0.001 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2f876cafaf218..2da8a6fac3d99 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -135,8 +135,12 @@ class HiveQuerySuite extends HiveComparisonTest { createQueryTest("div", "SELECT 1 DIV 2, 1 div 2, 1 dIv 2, 100 DIV 51, 100 DIV 49 FROM src LIMIT 1") - createQueryTest("division", - "SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1") + // Jdk version leads to different query output for double, so not use createQueryTest here + test("division") { + val res = sql("SELECT 2 / 1, 1 / 2, 1 / 3, 1 / COUNT(*) FROM src LIMIT 1").collect().head + Seq(2.0, 0.5, 0.3333333333333333, 0.002).zip(res).foreach( x => + assert(x._1 == x._2.asInstanceOf[Double])) + } createQueryTest("modulus", "SELECT 11 % 10, IF((101.1 % 100.0) BETWEEN 1.01 AND 1.11, \"true\", \"false\"), (101 / 2) % 10 FROM src LIMIT 1") @@ -306,8 +310,11 @@ class HiveQuerySuite extends HiveComparisonTest { createQueryTest("case statements WITHOUT key #4", "SELECT (CASE WHEN key > 2 THEN 3 WHEN 2 > key THEN 2 ELSE 0 END) FROM src WHERE key < 15") - createQueryTest("timestamp cast #1", - "SELECT CAST(CAST(1 AS TIMESTAMP) AS DOUBLE) 
FROM src LIMIT 1") + // Jdk version leads to different query output for double, so not use createQueryTest here + test("timestamp cast #1") { + val res = sql("SELECT CAST(CAST(1 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1").collect().head + assert(0.001 == res.getDouble(0)) + } createQueryTest("timestamp cast #2", "SELECT CAST(CAST(1.2 AS TIMESTAMP) AS DOUBLE) FROM src LIMIT 1") From f0c7e19550d46f81a0a3ff272bbf66ce4bafead6 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sat, 27 Sep 2014 12:10:16 -0700 Subject: [PATCH 26/50] [SPARK-3680][SQL] Fix bug caused by eager typing of HiveGenericUDFs Typing of UDFs should be lazy as it is often not valid to call `dataType` on an expression until after all of its children are `resolved`. Author: Michael Armbrust Closes #2525 from marmbrus/concatBug and squashes the following commits: 5b8efe7 [Michael Armbrust] fix bug with eager typing of udfs --- .../org/apache/spark/sql/hive/hiveUdfs.scala | 2 +- .../spark/sql/parquet/ParquetMetastoreSuite.scala | 15 +++++++++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 68944ed4ef21d..732e4976f6843 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -151,7 +151,7 @@ private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq override def get(): AnyRef = wrap(func()) } - val dataType: DataType = inspectorToDataType(returnInspector) + lazy val dataType: DataType = inspectorToDataType(returnInspector) override def eval(input: Row): Any = { returnInspector // Make sure initialized. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala index e380280f301c1..86adbbf3ad2d8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.parquet import java.io.File +import org.apache.spark.sql.catalyst.expressions.Row import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.QueryTest @@ -142,15 +143,21 @@ class ParquetMetastoreSuite extends QueryTest with BeforeAndAfterAll { test("sum") { checkAnswer( sql("SELECT SUM(intField) FROM partitioned_parquet WHERE intField IN (1,2,3) AND p = 1"), - 1 + 2 + 3 - ) + 1 + 2 + 3) + } + + test("hive udfs") { + checkAnswer( + sql("SELECT concat(stringField, stringField) FROM partitioned_parquet"), + sql("SELECT stringField FROM partitioned_parquet").map { + case Row(s: String) => Row(s + s) + }.collect().toSeq) } test("non-part select(*)") { checkAnswer( sql("SELECT COUNT(*) FROM normal_parquet"), - 10 - ) + 10) } test("conversion is working") { From 0d8cdf0ede908f6c488a075170f1563815009e29 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 27 Sep 2014 12:21:37 -0700 Subject: [PATCH 27/50] [SPARK-3681] [SQL] [PySpark] fix serialization of List and Map in SchemaRDD Currently, the schema of object in ArrayType or MapType is attached lazily, it will have better performance but introduce issues while serialization or accessing nested objects. This patch will apply schema to the objects of ArrayType or MapType immediately when accessing them, will be a little bit slower, but much robust. 
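A minimal Python sketch of the approach described above (helper names such as `convert_list` and `convert_map` are illustrative, not the names used in `pyspark/sql.py`): nested values are converted into plain lists and dicts as soon as they are accessed, so the results pickle and nest cleanly, at the cost of doing the conversion up front.

```
# Sketch only: eager conversion of nested ArrayType / MapType values.
def convert_list(element_cls, values):
    # Apply the element schema immediately and return a plain, picklable list.
    if values is None:
        return None
    return [element_cls(v) for v in values]

def convert_map(value_cls, mapping):
    # Same idea for MapType: materialize an ordinary dict up front.
    if mapping is None:
        return None
    return dict((k, value_cls(v)) for k, v in mapping.items())
```
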
Author: Davies Liu Closes #2526 from davies/nested and squashes the following commits: 2399ae5 [Davies Liu] fix serialization of List and Map in SchemaRDD --- python/pyspark/sql.py | 40 +++++++++++++--------------------------- python/pyspark/tests.py | 21 +++++++++++++++++++++ 2 files changed, 34 insertions(+), 27 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 653195ea438cf..f71d24c470dc9 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -838,43 +838,29 @@ def _create_cls(dataType): >>> obj = _create_cls(schema)(row) >>> pickle.loads(pickle.dumps(obj)) Row(a=[1], b={'key': Row(c=1, d=2.0)}) + >>> pickle.loads(pickle.dumps(obj.a)) + [1] + >>> pickle.loads(pickle.dumps(obj.b)) + {'key': Row(c=1, d=2.0)} """ if isinstance(dataType, ArrayType): cls = _create_cls(dataType.elementType) - class List(list): - - def __getitem__(self, i): - # create object with datetype - return _create_object(cls, list.__getitem__(self, i)) - - def __repr__(self): - # call collect __repr__ for nested objects - return "[%s]" % (", ".join(repr(self[i]) - for i in range(len(self)))) - - def __reduce__(self): - return list.__reduce__(self) + def List(l): + if l is None: + return + return [_create_object(cls, v) for v in l] return List elif isinstance(dataType, MapType): - vcls = _create_cls(dataType.valueType) - - class Dict(dict): - - def __getitem__(self, k): - # create object with datetype - return _create_object(vcls, dict.__getitem__(self, k)) - - def __repr__(self): - # call collect __repr__ for nested objects - return "{%s}" % (", ".join("%r: %r" % (k, self[k]) - for k in self)) + cls = _create_cls(dataType.valueType) - def __reduce__(self): - return dict.__reduce__(self) + def Dict(d): + if d is None: + return + return dict((k, _create_object(cls, v)) for k, v in d.items()) return Dict diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index d1bb2033b7a16..29df754c6fd29 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -698,6 +698,27 @@ def test_apply_schema_to_row(self): srdd3 = self.sqlCtx.applySchema(rdd, srdd.schema()) self.assertEqual(10, srdd3.count()) + def test_serialize_nested_array_and_map(self): + d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] + rdd = self.sc.parallelize(d) + srdd = self.sqlCtx.inferSchema(rdd) + row = srdd.first() + self.assertEqual(1, len(row.l)) + self.assertEqual(1, row.l[0].a) + self.assertEqual("2", row.d["key"].d) + + l = srdd.map(lambda x: x.l).first() + self.assertEqual(1, len(l)) + self.assertEqual('s', l[0].b) + + d = srdd.map(lambda x: x.d).first() + self.assertEqual(1, len(d)) + self.assertEqual(1.0, d["key"].c) + + row = srdd.map(lambda x: x.d["key"]).first() + self.assertEqual(1.0, row.c) + self.assertEqual("2", row.d) + class TestIO(PySparkTestCase): From 5b922bb458e863f5be0ae68167de882743f70b86 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 27 Sep 2014 14:46:00 -0700 Subject: [PATCH 28/50] [SPARK-3543] Clean up Java TaskContext implementation. This addresses some minor issues in https://github.com/apache/spark/pull/2425 Author: Reynold Xin Closes #2557 from rxin/TaskContext and squashes the following commits: a51e5f6 [Reynold Xin] [SPARK-3543] Clean up Java TaskContext implementation. 
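The shape this cleanup converges on is a per-thread context holder with explicit set/get/unset, cleared in a finally block around the task body. A hedged Scala sketch of that pattern (object and method names are illustrative, not the actual Spark API):

```
// Per-thread context holder; the owner clears it in a finally block,
// which after this patch happens once around the task body rather than in each task type.
object ContextHolder {
  private val current = new ThreadLocal[AnyRef]

  def set(ctx: AnyRef): Unit = current.set(ctx)
  def get(): AnyRef = current.get()
  def unset(): Unit = current.remove()
}
```
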
--- .../java/org/apache/spark/TaskContext.java | 33 ++++++++----------- .../apache/spark/scheduler/DAGScheduler.scala | 2 +- .../apache/spark/scheduler/ResultTask.scala | 6 +--- .../spark/scheduler/ShuffleMapTask.scala | 2 -- .../org/apache/spark/scheduler/Task.scala | 8 +++-- 5 files changed, 22 insertions(+), 29 deletions(-) diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java index 09b8ce02bd3d8..4e6d708af0ea7 100644 --- a/core/src/main/java/org/apache/spark/TaskContext.java +++ b/core/src/main/java/org/apache/spark/TaskContext.java @@ -56,7 +56,7 @@ public class TaskContext implements Serializable { * @param taskMetrics performance metrics of the task */ @DeveloperApi - public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean runningLocally, + public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally, TaskMetrics taskMetrics) { this.attemptId = attemptId; this.partitionId = partitionId; @@ -65,7 +65,6 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean this.taskMetrics = taskMetrics; } - /** * :: DeveloperApi :: * Contextual information about a task which can be read or mutated during execution. @@ -76,8 +75,7 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean * @param runningLocally whether the task is running locally in the driver JVM */ @DeveloperApi - public TaskContext(Integer stageId, Integer partitionId, Long attemptId, - Boolean runningLocally) { + public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally) { this.attemptId = attemptId; this.partitionId = partitionId; this.runningLocally = runningLocally; @@ -85,7 +83,6 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId, this.taskMetrics = TaskMetrics.empty(); } - /** * :: DeveloperApi :: * Contextual information about a task which can be read or mutated during execution. @@ -95,7 +92,7 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId, * @param attemptId the number of attempts to execute this task */ @DeveloperApi - public TaskContext(Integer stageId, Integer partitionId, Long attemptId) { + public TaskContext(int stageId, int partitionId, long attemptId) { this.attemptId = attemptId; this.partitionId = partitionId; this.runningLocally = false; @@ -107,9 +104,9 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId) { new ThreadLocal(); /** - * :: Internal API :: - * This is spark internal API, not intended to be called from user programs. - */ + * :: Internal API :: + * This is spark internal API, not intended to be called from user programs. + */ public static void setTaskContext(TaskContext tc) { taskContext.set(tc); } @@ -118,10 +115,8 @@ public static TaskContext get() { return taskContext.get(); } - /** - * :: Internal API :: - */ - public static void remove() { + /** :: Internal API :: */ + public static void unset() { taskContext.remove(); } @@ -130,22 +125,22 @@ public static void remove() { new ArrayList(); // Whether the corresponding task has been killed. - private volatile Boolean interrupted = false; + private volatile boolean interrupted = false; // Whether the task has completed. - private volatile Boolean completed = false; + private volatile boolean completed = false; /** * Checks whether the task has completed. 
*/ - public Boolean isCompleted() { + public boolean isCompleted() { return completed; } /** * Checks whether the task has been killed. */ - public Boolean isInterrupted() { + public boolean isInterrupted() { return interrupted; } @@ -246,12 +241,12 @@ public long attemptId() { } @Deprecated - /** Deprecated: use getRunningLocally() */ + /** Deprecated: use isRunningLocally() */ public boolean runningLocally() { return runningLocally; } - public boolean getRunningLocally() { + public boolean isRunningLocally() { return runningLocally; } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 32cf29ed140e6..70c235dffff70 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -641,7 +641,7 @@ class DAGScheduler( job.listener.taskSucceeded(0, result) } finally { taskContext.markTaskCompleted() - TaskContext.remove() + TaskContext.unset() } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 2ccbd8edeb028..4a9ff918afe25 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -58,11 +58,7 @@ private[spark] class ResultTask[T, U]( ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) metrics = Some(context.taskMetrics) - try { - func(context, rdd.iterator(partition, context)) - } finally { - context.markTaskCompleted() - } + func(context, rdd.iterator(partition, context)) } // This is only callable on the driver side. diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index a98ee118254a3..79709089c0da4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -78,8 +78,6 @@ private[spark] class ShuffleMapTask( log.debug("Could not stop writer", e) } throw e - } finally { - context.markTaskCompleted() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index bf73f6f7bd0e1..c6e47c84a0cb2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -52,7 +52,12 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex if (_killed) { kill(interruptThread = false) } - runTask(context) + try { + runTask(context) + } finally { + context.markTaskCompleted() + TaskContext.unset() + } } def runTask(context: TaskContext): T @@ -93,7 +98,6 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex if (interruptThread && taskThread != null) { taskThread.interrupt() } - TaskContext.remove() } } From 248232936e1bead7f102e59eb8faf3126c582d9d Mon Sep 17 00:00:00 2001 From: Uri Laserson Date: Sat, 27 Sep 2014 21:48:05 -0700 Subject: [PATCH 29/50] [SPARK-3389] Add Converter for ease of Parquet reading in PySpark https://issues.apache.org/jira/browse/SPARK-3389 Author: Uri Laserson Closes #2256 from laserson/SPARK-3389 and squashes the following commits: 0ed363e [Uri Laserson] PEP8'd the python file 0b4b380 [Uri Laserson] Moved converter to examples and added python example 
eecf4dc [Uri Laserson] [SPARK-3389] Add Converter for ease of Parquet reading in PySpark --- .../src/main/python/parquet_inputformat.py | 59 ++++++++++++++ examples/src/main/resources/full_user.avsc | 1 + examples/src/main/resources/users.parquet | Bin 0 -> 615 bytes .../pythonconverters/AvroConverters.scala | 76 +++++++++++------- 4 files changed, 106 insertions(+), 30 deletions(-) create mode 100644 examples/src/main/python/parquet_inputformat.py create mode 100644 examples/src/main/resources/full_user.avsc create mode 100644 examples/src/main/resources/users.parquet diff --git a/examples/src/main/python/parquet_inputformat.py b/examples/src/main/python/parquet_inputformat.py new file mode 100644 index 0000000000000..c9b08f878a1e6 --- /dev/null +++ b/examples/src/main/python/parquet_inputformat.py @@ -0,0 +1,59 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +from pyspark import SparkContext + +""" +Read data file users.parquet in local Spark distro: + +$ cd $SPARK_HOME +$ export AVRO_PARQUET_JARS=/path/to/parquet-avro-1.5.0.jar +$ ./bin/spark-submit --driver-class-path /path/to/example/jar \\ + --jars $AVRO_PARQUET_JARS \\ + ./examples/src/main/python/parquet_inputformat.py \\ + examples/src/main/resources/users.parquet +<...lots of log output...> +{u'favorite_color': None, u'name': u'Alyssa', u'favorite_numbers': [3, 9, 15, 20]} +{u'favorite_color': u'red', u'name': u'Ben', u'favorite_numbers': []} +<...more log output...> +""" +if __name__ == "__main__": + if len(sys.argv) != 2: + print >> sys.stderr, """ + Usage: parquet_inputformat.py + + Run with example jar: + ./bin/spark-submit --driver-class-path /path/to/example/jar \\ + /path/to/examples/parquet_inputformat.py + Assumes you have Parquet data stored in . 
+ """ + exit(-1) + + path = sys.argv[1] + sc = SparkContext(appName="ParquetInputFormat") + + parquet_rdd = sc.newAPIHadoopFile( + path, + 'parquet.avro.AvroParquetInputFormat', + 'java.lang.Void', + 'org.apache.avro.generic.IndexedRecord', + valueConverter='org.apache.spark.examples.pythonconverters.IndexedRecordToJavaConverter') + output = parquet_rdd.map(lambda x: x[1]).collect() + for k in output: + print k diff --git a/examples/src/main/resources/full_user.avsc b/examples/src/main/resources/full_user.avsc new file mode 100644 index 0000000000000..04e7ba2dca4f6 --- /dev/null +++ b/examples/src/main/resources/full_user.avsc @@ -0,0 +1 @@ +{"type": "record", "namespace": "example.avro", "name": "User", "fields": [{"type": "string", "name": "name"}, {"type": ["string", "null"], "name": "favorite_color"}, {"type": {"items": "int", "type": "array"}, "name": "favorite_numbers"}]} \ No newline at end of file diff --git a/examples/src/main/resources/users.parquet b/examples/src/main/resources/users.parquet new file mode 100644 index 0000000000000000000000000000000000000000..aa527338c43a8400fd56e549cb28aa1e6a9ccccf GIT binary patch literal 615 zcmZuv%WA?v6dhv>skOF(GbAMx8A#|N4V6|9aiOIPms03PD`l!<8_27ZC>8M^`h9*v zzoIu$LutDhxO2}r_i<*1{f8z-m}1MuG6X7C5vuhRgizmG#W5>FbjJgL&hf>LqokaZ zYYC8|l;VQV0B_^Ajmr=y801Dh!>c8wcbamJ;GDv#!@-jNG^p_p=0_fP*iwYfW6TA} zaK%KL95A1oK&zONR-LnDDBOfUPeU&hkZvLEEKddt|AmVfOQ~2gWv#@7U@Jsq-O#(1 zYT$})sz~1z#S)RpJsDWAfHh39mWmYpcax0PB|U2hw9kS8^O``r{L^;V4CrMtA|s$8 zM79MYBi+!Bv%TW!8~2&EEv#v>ia701!Ka~^QJbb)!ad!5e~TkFO;bOe0ch@WZx++e zczw`hQu|ObPJ|o0(v6+txjmU@P-546O!ri1zVJLc`A@QUG#BNAXU0Mr-ol4zs2e17 jvzcs=rbSG=FL-k0i^dXO!wrK*)46qS&=)u|gfI3D{z{mb literal 0 HcmV?d00001 diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala index 1b25983a38453..a11890d6f2b1c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/AvroConverters.scala @@ -30,21 +30,28 @@ import org.apache.spark.api.python.Converter import org.apache.spark.SparkException -/** - * Implementation of [[org.apache.spark.api.python.Converter]] that converts - * an Avro Record wrapped in an AvroKey (or AvroValue) to a Java Map. It tries - * to work with all 3 Avro data mappings (Generic, Specific and Reflect). 
- */ -class AvroWrapperToJavaConverter extends Converter[Any, Any] { - override def convert(obj: Any): Any = { +object AvroConversionUtil extends Serializable { + def fromAvro(obj: Any, schema: Schema): Any = { if (obj == null) { return null } - obj.asInstanceOf[AvroWrapper[_]].datum() match { - case null => null - case record: IndexedRecord => unpackRecord(record) - case other => throw new SparkException( - s"Unsupported top-level Avro data type ${other.getClass.getName}") + schema.getType match { + case UNION => unpackUnion(obj, schema) + case ARRAY => unpackArray(obj, schema) + case FIXED => unpackFixed(obj, schema) + case MAP => unpackMap(obj, schema) + case BYTES => unpackBytes(obj) + case RECORD => unpackRecord(obj) + case STRING => obj.toString + case ENUM => obj.toString + case NULL => obj + case BOOLEAN => obj + case DOUBLE => obj + case FLOAT => obj + case INT => obj + case LONG => obj + case other => throw new SparkException( + s"Unknown Avro schema type ${other.getName}") } } @@ -103,28 +110,37 @@ class AvroWrapperToJavaConverter extends Converter[Any, Any] { "Unions may only consist of a concrete type and null") } } +} - def fromAvro(obj: Any, schema: Schema): Any = { +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts + * an Avro IndexedRecord (e.g., derived from AvroParquetInputFormat) to a Java Map. + */ +class IndexedRecordToJavaConverter extends Converter[IndexedRecord, JMap[String, Any]]{ + override def convert(record: IndexedRecord): JMap[String, Any] = { + if (record == null) { + return null + } + val map = new java.util.HashMap[String, Any] + AvroConversionUtil.unpackRecord(record) + } +} + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts + * an Avro Record wrapped in an AvroKey (or AvroValue) to a Java Map. It tries + * to work with all 3 Avro data mappings (Generic, Specific and Reflect). + */ +class AvroWrapperToJavaConverter extends Converter[Any, Any] { + override def convert(obj: Any): Any = { if (obj == null) { return null } - schema.getType match { - case UNION => unpackUnion(obj, schema) - case ARRAY => unpackArray(obj, schema) - case FIXED => unpackFixed(obj, schema) - case MAP => unpackMap(obj, schema) - case BYTES => unpackBytes(obj) - case RECORD => unpackRecord(obj) - case STRING => obj.toString - case ENUM => obj.toString - case NULL => obj - case BOOLEAN => obj - case DOUBLE => obj - case FLOAT => obj - case INT => obj - case LONG => obj - case other => throw new SparkException( - s"Unknown Avro schema type ${other.getName}") + obj.asInstanceOf[AvroWrapper[_]].datum() match { + case null => null + case record: IndexedRecord => AvroConversionUtil.unpackRecord(record) + case other => throw new SparkException( + s"Unsupported top-level Avro data type ${other.getClass.getName}") } } } From 9966d1a8aaed3d8cfed93855959705ea3c677215 Mon Sep 17 00:00:00 2001 From: Dale Date: Sat, 27 Sep 2014 22:08:10 -0700 Subject: [PATCH 30/50] SPARK-CORE [SPARK-3651] Group common CoarseGrainedSchedulerBackend variables together from [SPARK-3651] In CoarseGrainedSchedulerBackend, we have: private val executorActor = new HashMap[String, ActorRef] private val executorAddress = new HashMap[String, Address] private val executorHost = new HashMap[String, String] private val freeCores = new HashMap[String, Int] private val totalCores = new HashMap[String, Int] We only ever put / remove stuff from these maps together. 
It would simplify the code if we consolidate these all into one map as we have done in JobProgressListener in https://issues.apache.org/jira/browse/SPARK-2299. Author: Dale Closes #2533 from tigerquoll/SPARK-3651 and squashes the following commits: d1be0a9 [Dale] [SPARK-3651] implemented suggested changes. Changed a reference from executorInfo to executorData to be consistent with other usages 6890663 [Dale] [SPARK-3651] implemented suggested changes 7d671cf [Dale] [SPARK-3651] Grouped variables under a ExecutorDataObject, and reference them via a map entry as they are all retrieved under the same key --- .../CoarseGrainedSchedulerBackend.scala | 68 ++++++++----------- .../scheduler/cluster/ExecutorData.scala | 38 +++++++++++ 2 files changed, 68 insertions(+), 38 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 9a0cb1c6c6ccd..59e15edc75f5a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -62,15 +62,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A val createTime = System.currentTimeMillis() class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive { - override protected def log = CoarseGrainedSchedulerBackend.this.log - - private val executorActor = new HashMap[String, ActorRef] - private val executorAddress = new HashMap[String, Address] - private val executorHost = new HashMap[String, String] - private val freeCores = new HashMap[String, Int] - private val totalCores = new HashMap[String, Int] private val addressToExecutorId = new HashMap[Address, String] + private val executorDataMap = new HashMap[String, ExecutorData] override def preStart() { // Listen for remote client disconnection events, since they don't go through Akka's watch() @@ -85,16 +79,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A def receiveWithLogging = { case RegisterExecutor(executorId, hostPort, cores) => Utils.checkHostPort(hostPort, "Host port expected " + hostPort) - if (executorActor.contains(executorId)) { + if (executorDataMap.contains(executorId)) { sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) } else { logInfo("Registered executor: " + sender + " with ID " + executorId) sender ! 
RegisteredExecutor - executorActor(executorId) = sender - executorHost(executorId) = Utils.parseHostPort(hostPort)._1 - totalCores(executorId) = cores - freeCores(executorId) = cores - executorAddress(executorId) = sender.path.address + executorDataMap.put(executorId, new ExecutorData(sender, sender.path.address, + Utils.parseHostPort(hostPort)._1, cores, cores)) + addressToExecutorId(sender.path.address) = executorId totalCoreCount.addAndGet(cores) totalRegisteredExecutors.addAndGet(1) @@ -104,13 +96,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A case StatusUpdate(executorId, taskId, state, data) => scheduler.statusUpdate(taskId, state, data.value) if (TaskState.isFinished(state)) { - if (executorActor.contains(executorId)) { - freeCores(executorId) += scheduler.CPUS_PER_TASK - makeOffers(executorId) - } else { - // Ignoring the update since we don't know about the executor. - val msg = "Ignored task status update (%d state %s) from unknown executor %s with ID %s" - logWarning(msg.format(taskId, state, sender, executorId)) + executorDataMap.get(executorId) match { + case Some(executorInfo) => + executorInfo.freeCores += scheduler.CPUS_PER_TASK + makeOffers(executorId) + case None => + // Ignoring the update since we don't know about the executor. + logWarning(s"Ignored task status update ($taskId state $state) " + + "from unknown executor $sender with ID $executorId") } } @@ -118,7 +111,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A makeOffers() case KillTask(taskId, executorId, interruptThread) => - executorActor(executorId) ! KillTask(taskId, executorId, interruptThread) + executorDataMap(executorId).executorActor ! KillTask(taskId, executorId, interruptThread) case StopDriver => sender ! true @@ -126,8 +119,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A case StopExecutors => logInfo("Asking each executor to shut down") - for (executor <- executorActor.values) { - executor ! StopExecutor + for ((_, executorData) <- executorDataMap) { + executorData.executorActor ! StopExecutor } sender ! true @@ -138,6 +131,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A case AddWebUIFilter(filterName, filterParams, proxyBase) => addWebUIFilter(filterName, filterParams, proxyBase) sender ! true + case DisassociatedEvent(_, address, _) => addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disassociated")) @@ -149,13 +143,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A // Make fake resource offers on all executors def makeOffers() { launchTasks(scheduler.resourceOffers( - executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))})) + executorDataMap.map {case (id, executorData) => + new WorkerOffer(id, executorData.executorHost, executorData.freeCores)}.toSeq)) } // Make fake resource offers on just one executor def makeOffers(executorId: String) { + val executorData = executorDataMap(executorId) launchTasks(scheduler.resourceOffers( - Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId))))) + Seq(new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)))) } // Launch tasks returned by a set of resource offers @@ -179,25 +175,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A } } else { - freeCores(task.executorId) -= scheduler.CPUS_PER_TASK - executorActor(task.executorId) ! 
LaunchTask(new SerializableBuffer(serializedTask)) + val executorData = executorDataMap(task.executorId) + executorData.freeCores -= scheduler.CPUS_PER_TASK + executorData.executorActor ! LaunchTask(new SerializableBuffer(serializedTask)) } } } // Remove a disconnected slave from the cluster def removeExecutor(executorId: String, reason: String) { - if (executorActor.contains(executorId)) { - logInfo("Executor " + executorId + " disconnected, so removing it") - val numCores = totalCores(executorId) - executorActor -= executorId - executorHost -= executorId - addressToExecutorId -= executorAddress(executorId) - executorAddress -= executorId - totalCores -= executorId - freeCores -= executorId - totalCoreCount.addAndGet(-numCores) - scheduler.executorLost(executorId, SlaveLost(reason)) + executorDataMap.get(executorId) match { + case Some(executorInfo) => + executorDataMap -= executorId + totalCoreCount.addAndGet(-executorInfo.totalCores) + scheduler.executorLost(executorId, SlaveLost(reason)) + case None => logError(s"Asked to remove non existant executor $executorId") } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala new file mode 100644 index 0000000000000..74a92985b6629 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import akka.actor.{Address, ActorRef} + +/** + * Grouping of data that is accessed by a CourseGrainedScheduler. This class + * is stored in a Map keyed by an executorID + * + * @param executorActor The actorRef representing this executor + * @param executorAddress The network address of this executor + * @param executorHost The hostname that this executor is running on + * @param freeCores The current number of cores available for work on the executor + * @param totalCores The total number of cores available to the executor + */ +private[cluster] class ExecutorData( + val executorActor: ActorRef, + val executorAddress: Address, + val executorHost: String , + var freeCores: Int, + val totalCores: Int +) From 66e1c40c67f40dc4a5519812bc84877751933e7a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 27 Sep 2014 22:18:02 -0700 Subject: [PATCH 31/50] Minor fix for the previous commit. 
--- .../scheduler/cluster/CoarseGrainedSchedulerBackend.scala | 6 +++--- .../org/apache/spark/scheduler/cluster/ExecutorData.scala | 5 ++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 59e15edc75f5a..89089e7d6f8a8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -142,9 +142,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A // Make fake resource offers on all executors def makeOffers() { - launchTasks(scheduler.resourceOffers( - executorDataMap.map {case (id, executorData) => - new WorkerOffer(id, executorData.executorHost, executorData.freeCores)}.toSeq)) + launchTasks(scheduler.resourceOffers(executorDataMap.map { case (id, executorData) => + new WorkerOffer(id, executorData.executorHost, executorData.freeCores) + }.toSeq)) } // Make fake resource offers on just one executor diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index 74a92985b6629..b71bd5783d6df 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -20,10 +20,9 @@ package org.apache.spark.scheduler.cluster import akka.actor.{Address, ActorRef} /** - * Grouping of data that is accessed by a CourseGrainedScheduler. This class - * is stored in a Map keyed by an executorID + * Grouping of data for an executor used by CoarseGrainedSchedulerBackend. * - * @param executorActor The actorRef representing this executor + * @param executorActor The ActorRef representing this executor * @param executorAddress The network address of this executor * @param executorHost The hostname that this executor is running on * @param freeCores The current number of cores available for work on the executor From 6918012d0f4841c5422b5827879a952428ec3a62 Mon Sep 17 00:00:00 2001 From: William Benton Date: Sun, 28 Sep 2014 01:01:27 -0700 Subject: [PATCH 32/50] SPARK-3699: SQL and Hive console tasks now clean up appropriately The sbt tasks sql/console and hive/console will now `stop()` the `SparkContext` upon exit. Previously, they left an ugly stack trace when quitting. Author: William Benton Closes #2547 from willb/consoleCleanup and squashes the following commits: d5e431f [William Benton] SQL and Hive console tasks now clean up. 
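For context, sbt's `console` task supports a `cleanupCommands` setting holding code to run when the REPL exits; the change below uses it to stop the context so that quitting the shell no longer dumps a shutdown stack trace. A minimal sketch of the setting (the value shown is the one this patch adds):

```
// In the sbt build definition: code evaluated when the console REPL exits.
cleanupCommands in console := "sparkContext.stop()"
```
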
--- project/SparkBuild.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 12ac82293df76..01a5b20e7c51d 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -221,7 +221,8 @@ object SQL { |import org.apache.spark.sql.catalyst.util._ |import org.apache.spark.sql.execution |import org.apache.spark.sql.test.TestSQLContext._ - |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin + |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin, + cleanupCommands in console := "sparkContext.stop()" ) } @@ -249,7 +250,8 @@ object Hive { |import org.apache.spark.sql.execution |import org.apache.spark.sql.hive._ |import org.apache.spark.sql.hive.test.TestHive._ - |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin + |import org.apache.spark.sql.parquet.ParquetTestData""".stripMargin, + cleanupCommands in console := "sparkContext.stop()" ) } From 1f13a40ccd5a869aec62788a1e345dc24fa648c8 Mon Sep 17 00:00:00 2001 From: WangTaoTheTonic Date: Sun, 28 Sep 2014 18:30:13 -0700 Subject: [PATCH 33/50] [SPARK-3715][Docs]minor typo https://issues.apache.org/jira/browse/SPARK-3715 Author: WangTaoTheTonic Closes #2567 from WangTaoTheTonic/minortypo and squashes the following commits: 9cc3f7a [WangTaoTheTonic] minor typo --- docs/sql-programming-guide.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 65249808fae3e..818fd5ab80af8 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -732,7 +732,7 @@ Configuration of Hive is done by placing your `hive-site.xml` file in `conf/`.

When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and -adds support for finding tables in in the MetaStore and writing queries using HiveQL. Users who do +adds support for finding tables in the MetaStore and writing queries using HiveQL. Users who do not have an existing Hive deployment can still create a HiveContext. When not configured by the hive-site.xml, the context automatically creates `metastore_db` and `warehouse` in the current directory. @@ -753,7 +753,7 @@ sqlContext.sql("FROM src SELECT key, value").collect().foreach(println)
When working with Hive one must construct a `JavaHiveContext`, which inherits from `JavaSQLContext`, and -adds support for finding tables in in the MetaStore and writing queries using HiveQL. In addition to +adds support for finding tables in the MetaStore and writing queries using HiveQL. In addition to the `sql` method a `JavaHiveContext` also provides an `hql` methods, which allows queries to be expressed in HiveQL. @@ -774,7 +774,7 @@ Row[] results = sqlContext.sql("FROM src SELECT key, value").collect();
When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and -adds support for finding tables in in the MetaStore and writing queries using HiveQL. In addition to +adds support for finding tables in the MetaStore and writing queries using HiveQL. In addition to the `sql` method a `HiveContext` also provides an `hql` methods, which allows queries to be expressed in HiveQL. From 8e874185ed9efae8a1dc6b61d56ff401d72bb087 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 28 Sep 2014 18:33:11 -0700 Subject: [PATCH 34/50] Revert "[SPARK-1021] Defer the data-driven computation of partition bounds in so..." This reverts commit 2d972fd84ac54a89e416442508a6d4eaeff452c1. The commit was hanging correlationoptimizer14. --- .../scala/org/apache/spark/FutureAction.scala | 7 +- .../scala/org/apache/spark/Partitioner.scala | 29 ++------- .../apache/spark/rdd/AsyncRDDActions.scala | 64 ++++++++----------- 3 files changed, 34 insertions(+), 66 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index c277c3a47d421..75ea535f2f57b 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -208,7 +208,7 @@ class ComplexFutureAction[T] extends FutureAction[T] { processPartition: Iterator[T] => U, partitions: Seq[Int], resultHandler: (Int, U) => Unit, - resultFunc: => R): R = { + resultFunc: => R) { // If the action hasn't been cancelled yet, submit the job. The check and the submitJob // command need to be in an atomic block. val job = this.synchronized { @@ -223,10 +223,7 @@ class ComplexFutureAction[T] extends FutureAction[T] { // cancel the job and stop the execution. This is not in a synchronized block because // Await.ready eventually waits on the monitor in FutureJob.jobWaiter. try { - Await.ready(job, Duration.Inf).value.get match { - case scala.util.Failure(e) => throw e - case scala.util.Success(v) => v - } + Await.ready(job, Duration.Inf) } catch { case e: InterruptedException => job.cancel() diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index d40b152d221c5..37053bb6f37ad 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -29,10 +29,6 @@ import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.{CollectionsUtils, Utils} import org.apache.spark.util.random.{XORShiftRandom, SamplingUtils} -import org.apache.spark.SparkContext.rddToAsyncRDDActions -import scala.concurrent.Await -import scala.concurrent.duration.Duration - /** * An object that defines how the elements in a key-value pair RDD are partitioned by key. * Maps each key to a partition ID, from 0 to `numPartitions - 1`. @@ -117,12 +113,8 @@ class RangePartitioner[K : Ordering : ClassTag, V]( private var ordering = implicitly[Ordering[K]] // An array of upper bounds for the first (partitions - 1) partitions - @volatile private var valRB: Array[K] = null - - private def rangeBounds: Array[K] = this.synchronized { - if (valRB != null) return valRB - - valRB = if (partitions <= 1) { + private var rangeBounds: Array[K] = { + if (partitions <= 1) { Array.empty } else { // This is the sample size we need to have roughly balanced output partitions, capped at 1M. 
@@ -160,8 +152,6 @@ class RangePartitioner[K : Ordering : ClassTag, V]( RangePartitioner.determineBounds(candidates, partitions) } } - - valRB } def numPartitions = rangeBounds.length + 1 @@ -232,8 +222,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( } @throws(classOf[IOException]) - private def readObject(in: ObjectInputStream): Unit = this.synchronized { - if (valRB != null) return + private def readObject(in: ObjectInputStream) { val sfactory = SparkEnv.get.serializer sfactory match { case js: JavaSerializer => in.defaultReadObject() @@ -245,7 +234,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( val ser = sfactory.newInstance() Utils.deserializeViaNestedStream(in, ser) { ds => implicit val classTag = ds.readObject[ClassTag[Array[K]]]() - valRB = ds.readObject[Array[K]]() + rangeBounds = ds.readObject[Array[K]]() } } } @@ -265,18 +254,12 @@ private[spark] object RangePartitioner { sampleSizePerPartition: Int): (Long, Array[(Int, Int, Array[K])]) = { val shift = rdd.id // val classTagK = classTag[K] // to avoid serializing the entire partitioner object - // use collectAsync here to run this job as a future, which is cancellable - val sketchFuture = rdd.mapPartitionsWithIndex { (idx, iter) => + val sketched = rdd.mapPartitionsWithIndex { (idx, iter) => val seed = byteswap32(idx ^ (shift << 16)) val (sample, n) = SamplingUtils.reservoirSampleAndCount( iter, sampleSizePerPartition, seed) Iterator((idx, n, sample)) - }.collectAsync() - // We do need the future's value to continue any further - val sketched = Await.ready(sketchFuture, Duration.Inf).value.get match { - case scala.util.Success(v) => v.toArray - case scala.util.Failure(e) => throw e - } + }.collect() val numItems = sketched.map(_._2.toLong).sum (numItems, sketched) } diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 7a68b3afa8158..b62f3fbdc4a15 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -23,7 +23,6 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionContext.Implicits.global import scala.reflect.ClassTag -import org.apache.spark.util.Utils import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} import org.apache.spark.annotation.Experimental @@ -39,30 +38,29 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi * Returns a future for counting the number of elements in the RDD. */ def countAsync(): FutureAction[Long] = { - val f = new ComplexFutureAction[Long] - f.run { - val totalCount = new AtomicLong - f.runJob(self, - (iter: Iterator[T]) => Utils.getIteratorSize(iter), - Range(0, self.partitions.size), - (index: Int, data: Long) => totalCount.addAndGet(data), - totalCount.get()) - } + val totalCount = new AtomicLong + self.context.submitJob( + self, + (iter: Iterator[T]) => { + var result = 0L + while (iter.hasNext) { + result += 1L + iter.next() + } + result + }, + Range(0, self.partitions.size), + (index: Int, data: Long) => totalCount.addAndGet(data), + totalCount.get()) } /** * Returns a future for retrieving all elements of this RDD. 
*/ def collectAsync(): FutureAction[Seq[T]] = { - val f = new ComplexFutureAction[Seq[T]] - f.run { - val results = new Array[Array[T]](self.partitions.size) - f.runJob(self, - (iter: Iterator[T]) => iter.toArray, - Range(0, self.partitions.size), - (index: Int, data: Array[T]) => results(index) = data, - results.flatten.toSeq) - } + val results = new Array[Array[T]](self.partitions.size) + self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size), + (index, data) => results(index) = data, results.flatten.toSeq) } /** @@ -106,34 +104,24 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi } results.toSeq } + + f } /** * Applies a function f to all elements of this RDD. */ - def foreachAsync(expr: T => Unit): FutureAction[Unit] = { - val f = new ComplexFutureAction[Unit] - val exprClean = self.context.clean(expr) - f.run { - f.runJob(self, - (iter: Iterator[T]) => iter.foreach(exprClean), - Range(0, self.partitions.size), - (index: Int, data: Unit) => Unit, - Unit) - } + def foreachAsync(f: T => Unit): FutureAction[Unit] = { + val cleanF = self.context.clean(f) + self.context.submitJob[T, Unit, Unit](self, _.foreach(cleanF), Range(0, self.partitions.size), + (index, data) => Unit, Unit) } /** * Applies a function f to each partition of this RDD. */ - def foreachPartitionAsync(expr: Iterator[T] => Unit): FutureAction[Unit] = { - val f = new ComplexFutureAction[Unit] - f.run { - f.runJob(self, - expr, - Range(0, self.partitions.size), - (index: Int, data: Unit) => Unit, - Unit) - } + def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = { + self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size), + (index, data) => Unit, Unit) } } From 25164a89dd32eef58d9b6823ae259439f796e81a Mon Sep 17 00:00:00 2001 From: Jim Lim Date: Sun, 28 Sep 2014 19:04:24 -0700 Subject: [PATCH 35/50] SPARK-2761 refactor #maybeSpill into Spillable Moved `#maybeSpill` in ExternalSorter and EAOM into `Spillable`. 
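For orientation, here is a minimal, self-contained sketch of the calling pattern this refactor standardizes: a collection mixes in a `Spillable`-style trait, asks `maybeSpill` on every insert, and starts a fresh in-memory collection when a spill happens. The trait below is a simplified stand-in, not the real one (the actual `Spillable` added by this patch acquires memory from `SparkEnv.get.shuffleMemoryManager` and tracks `memoryBytesSpilled`); the fixed `threshold`, the `TinySink` class, and the in-memory 'spill' target are illustrative only.

```scala
import scala.collection.mutable.ArrayBuffer

// Simplified stand-in for the Spillable contract (illustrative only).
trait SpillablePattern[C] {
  protected var elementsRead: Long          // maintained by the concrete collection
  protected def spill(collection: C): Unit  // concrete collection writes its contents out

  // Returns true if `collection` was spilled; the caller should then start a fresh one.
  protected def maybeSpill(collection: C, currentMemory: Long, threshold: Long): Boolean = {
    if (elementsRead % 32 == 0 && currentMemory >= threshold) {
      spill(collection)
      true
    } else {
      false
    }
  }
}

// Hypothetical collection using the pattern, mirroring ExternalAppendOnlyMap.insertAll below.
class TinySink extends SpillablePattern[ArrayBuffer[String]] {
  protected var elementsRead: Long = 0L
  private var buffer = new ArrayBuffer[String]
  private val spilled = new ArrayBuffer[Seq[String]]  // stands in for on-disk spill files

  override protected def spill(collection: ArrayBuffer[String]): Unit = {
    spilled += collection.toVector
  }

  def insert(value: String): Unit = {
    buffer += value
    elementsRead += 1
    if (maybeSpill(buffer, currentMemory = buffer.size.toLong, threshold = 1024L)) {
      buffer = new ArrayBuffer[String]      // spilled: start a fresh in-memory collection
      elementsRead = 0L
    }
  }
}
```

In the real trait, `maybeSpill` first tries to double the collection's memory claim through the shuffle memory manager and only spills when too little memory is granted, as the new `Spillable.scala` further below shows.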
Author: Jim Lim Closes #2416 from jimjh/SPARK-2761 and squashes the following commits: cf8be9a [Jim Lim] SPARK-2761 fix documentation, reorder code f94d522 [Jim Lim] SPARK-2761 refactor Spillable to simplify sig e75a24e [Jim Lim] SPARK-2761 use protected over protected[this] 7270e0d [Jim Lim] SPARK-2761 refactor #maybeSpill into Spillable --- .../collection/ExternalAppendOnlyMap.scala | 46 ++------ .../util/collection/ExternalSorter.scala | 68 +++-------- .../spark/util/collection/Spillable.scala | 111 ++++++++++++++++++ 3 files changed, 133 insertions(+), 92 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/collection/Spillable.scala diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 8a015c1d26a96..0c088da46aa5e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -66,23 +66,19 @@ class ExternalAppendOnlyMap[K, V, C]( mergeCombiners: (C, C) => C, serializer: Serializer = SparkEnv.get.serializer, blockManager: BlockManager = SparkEnv.get.blockManager) - extends Iterable[(K, C)] with Serializable with Logging { + extends Iterable[(K, C)] + with Serializable + with Logging + with Spillable[SizeTracker] { private var currentMap = new SizeTrackingAppendOnlyMap[K, C] private val spilledMaps = new ArrayBuffer[DiskMapIterator] private val sparkConf = SparkEnv.get.conf private val diskBlockManager = blockManager.diskBlockManager - private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager // Number of pairs inserted since last spill; note that we count them even if a value is merged // with a previous key in case we're doing something like groupBy where the result grows - private var elementsRead = 0L - - // Number of in-memory pairs inserted before tracking the map's shuffle memory usage - private val trackMemoryThreshold = 1000 - - // How much of the shared memory pool this collection has claimed - private var myMemoryThreshold = 0L + protected[this] var elementsRead = 0L /** * Size of object batches when reading/writing from serializers. 
@@ -95,11 +91,7 @@ class ExternalAppendOnlyMap[K, V, C]( */ private val serializerBatchSize = sparkConf.getLong("spark.shuffle.spill.batchSize", 10000) - // How many times we have spilled so far - private var spillCount = 0 - // Number of bytes spilled in total - private var _memoryBytesSpilled = 0L private var _diskBytesSpilled = 0L private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024 @@ -136,19 +128,8 @@ class ExternalAppendOnlyMap[K, V, C]( while (entries.hasNext) { curEntry = entries.next() - if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 && - currentMap.estimateSize() >= myMemoryThreshold) - { - // Claim up to double our current memory from the shuffle memory pool - val currentMemory = currentMap.estimateSize() - val amountToRequest = 2 * currentMemory - myMemoryThreshold - val granted = shuffleMemoryManager.tryToAcquire(amountToRequest) - myMemoryThreshold += granted - if (myMemoryThreshold <= currentMemory) { - // We were granted too little memory to grow further (either tryToAcquire returned 0, - // or we already had more memory than myMemoryThreshold); spill the current collection - spill(currentMemory) // Will also release memory back to ShuffleMemoryManager - } + if (maybeSpill(currentMap, currentMap.estimateSize())) { + currentMap = new SizeTrackingAppendOnlyMap[K, C] } currentMap.changeValue(curEntry._1, update) elementsRead += 1 @@ -171,11 +152,7 @@ class ExternalAppendOnlyMap[K, V, C]( /** * Sort the existing contents of the in-memory map and spill them to a temporary file on disk. */ - private def spill(mapSize: Long): Unit = { - spillCount += 1 - val threadId = Thread.currentThread().getId - logInfo("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)" - .format(threadId, mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) + override protected[this] def spill(collection: SizeTracker): Unit = { val (blockId, file) = diskBlockManager.createTempBlock() curWriteMetrics = new ShuffleWriteMetrics() var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize, @@ -231,18 +208,11 @@ class ExternalAppendOnlyMap[K, V, C]( } } - currentMap = new SizeTrackingAppendOnlyMap[K, C] spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes)) - // Release our memory back to the shuffle pool so that other threads can grab it - shuffleMemoryManager.release(myMemoryThreshold) - myMemoryThreshold = 0L - elementsRead = 0 - _memoryBytesSpilled += mapSize } - def memoryBytesSpilled: Long = _memoryBytesSpilled def diskBytesSpilled: Long = _diskBytesSpilled /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 782b979e2e93d..0a152cb97ad9e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -79,14 +79,14 @@ private[spark] class ExternalSorter[K, V, C]( aggregator: Option[Aggregator[K, V, C]] = None, partitioner: Option[Partitioner] = None, ordering: Option[Ordering[K]] = None, - serializer: Option[Serializer] = None) extends Logging { + serializer: Option[Serializer] = None) + extends Logging with Spillable[SizeTrackingPairCollection[(Int, K), C]] { private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1) private val shouldPartition = numPartitions > 1 private val blockManager = SparkEnv.get.blockManager private val diskBlockManager = 
blockManager.diskBlockManager - private val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager private val ser = Serializer.getSerializer(serializer) private val serInstance = ser.newInstance() @@ -115,22 +115,14 @@ private[spark] class ExternalSorter[K, V, C]( // Number of pairs read from input since last spill; note that we count them even if a value is // merged with a previous key in case we're doing something like groupBy where the result grows - private var elementsRead = 0L - - // What threshold of elementsRead we start estimating map size at. - private val trackMemoryThreshold = 1000 + protected[this] var elementsRead = 0L // Total spilling statistics - private var spillCount = 0 - private var _memoryBytesSpilled = 0L private var _diskBytesSpilled = 0L // Write metrics for current spill private var curWriteMetrics: ShuffleWriteMetrics = _ - // How much of the shared memory pool this collection has claimed - private var myMemoryThreshold = 0L - // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't need // local aggregation and sorting, write numPartitions files directly and just concatenate them // at the end. This avoids doing serialization and deserialization twice to merge together the @@ -209,7 +201,7 @@ private[spark] class ExternalSorter[K, V, C]( elementsRead += 1 kv = records.next() map.changeValue((getPartition(kv._1), kv._1), update) - maybeSpill(usingMap = true) + maybeSpillCollection(usingMap = true) } } else { // Stick values into our buffer @@ -217,7 +209,7 @@ private[spark] class ExternalSorter[K, V, C]( elementsRead += 1 val kv = records.next() buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C]) - maybeSpill(usingMap = false) + maybeSpillCollection(usingMap = false) } } } @@ -227,61 +219,31 @@ private[spark] class ExternalSorter[K, V, C]( * * @param usingMap whether we're using a map or buffer as our current in-memory collection */ - private def maybeSpill(usingMap: Boolean): Unit = { + private def maybeSpillCollection(usingMap: Boolean): Unit = { if (!spillingEnabled) { return } - val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer - - // TODO: factor this out of both here and ExternalAppendOnlyMap - if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 && - collection.estimateSize() >= myMemoryThreshold) - { - // Claim up to double our current memory from the shuffle memory pool - val currentMemory = collection.estimateSize() - val amountToRequest = 2 * currentMemory - myMemoryThreshold - val granted = shuffleMemoryManager.tryToAcquire(amountToRequest) - myMemoryThreshold += granted - if (myMemoryThreshold <= currentMemory) { - // We were granted too little memory to grow further (either tryToAcquire returned 0, - // or we already had more memory than myMemoryThreshold); spill the current collection - spill(currentMemory, usingMap) // Will also release memory back to ShuffleMemoryManager + if (usingMap) { + if (maybeSpill(map, map.estimateSize())) { + map = new SizeTrackingAppendOnlyMap[(Int, K), C] + } + } else { + if (maybeSpill(buffer, buffer.estimateSize())) { + buffer = new SizeTrackingPairBuffer[(Int, K), C] } } } /** * Spill the current in-memory collection to disk, adding a new file to spills, and clear it. 
- * - * @param usingMap whether we're using a map or buffer as our current in-memory collection */ - private def spill(memorySize: Long, usingMap: Boolean): Unit = { - val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer - val memorySize = collection.estimateSize() - - spillCount += 1 - val threadId = Thread.currentThread().getId - logInfo("Thread %d spilling in-memory batch of %d MB to disk (%d spill%s so far)" - .format(threadId, memorySize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else "")) - + override protected[this] def spill(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = { if (bypassMergeSort) { spillToPartitionFiles(collection) } else { spillToMergeableFile(collection) } - - if (usingMap) { - map = new SizeTrackingAppendOnlyMap[(Int, K), C] - } else { - buffer = new SizeTrackingPairBuffer[(Int, K), C] - } - - // Release our memory back to the shuffle pool so that other threads can grab it - shuffleMemoryManager.release(myMemoryThreshold) - myMemoryThreshold = 0 - - _memoryBytesSpilled += memorySize } /** @@ -804,8 +766,6 @@ private[spark] class ExternalSorter[K, V, C]( } } - def memoryBytesSpilled: Long = _memoryBytesSpilled - def diskBytesSpilled: Long = _diskBytesSpilled /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala new file mode 100644 index 0000000000000..d7dccd4af8c6e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import org.apache.spark.Logging +import org.apache.spark.SparkEnv + +/** + * Spills contents of an in-memory collection to disk when the memory threshold + * has been exceeded. + */ +private[spark] trait Spillable[C] { + + this: Logging => + + /** + * Spills the current in-memory collection to disk, and releases the memory. 
+ * + * @param collection collection to spill to disk + */ + protected def spill(collection: C): Unit + + // Number of elements read from input since last spill + protected var elementsRead: Long + + // Memory manager that can be used to acquire/release memory + private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager + + // What threshold of elementsRead we start estimating collection size at + private[this] val trackMemoryThreshold = 1000 + + // How much of the shared memory pool this collection has claimed + private[this] var myMemoryThreshold = 0L + + // Number of bytes spilled in total + private[this] var _memoryBytesSpilled = 0L + + // Number of spills + private[this] var _spillCount = 0 + + /** + * Spills the current in-memory collection to disk if needed. Attempts to acquire more + * memory before spilling. + * + * @param collection collection to spill to disk + * @param currentMemory estimated size of the collection in bytes + * @return true if `collection` was spilled to disk; false otherwise + */ + protected def maybeSpill(collection: C, currentMemory: Long): Boolean = { + if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 && + currentMemory >= myMemoryThreshold) { + // Claim up to double our current memory from the shuffle memory pool + val amountToRequest = 2 * currentMemory - myMemoryThreshold + val granted = shuffleMemoryManager.tryToAcquire(amountToRequest) + myMemoryThreshold += granted + if (myMemoryThreshold <= currentMemory) { + // We were granted too little memory to grow further (either tryToAcquire returned 0, + // or we already had more memory than myMemoryThreshold); spill the current collection + _spillCount += 1 + logSpillage(currentMemory) + + spill(collection) + + // Keep track of spills, and release memory + _memoryBytesSpilled += currentMemory + releaseMemoryForThisThread() + return true + } + } + false + } + + /** + * @return number of bytes spilled in total + */ + def memoryBytesSpilled: Long = _memoryBytesSpilled + + /** + * Release our memory back to the shuffle pool so that other threads can grab it. + */ + private def releaseMemoryForThisThread(): Unit = { + shuffleMemoryManager.release(myMemoryThreshold) + myMemoryThreshold = 0L + } + + /** + * Prints a standard log message detailing spillage. + * + * @param size number of bytes spilled + */ + @inline private def logSpillage(size: Long) { + val threadId = Thread.currentThread().getId + logInfo("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)" + .format(threadId, size / (1024 * 1024), _spillCount, if (_spillCount > 1) "s" else "")) + } +} From f350cd307045c2c02e713225d8f1247f18ba123e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 28 Sep 2014 20:32:54 -0700 Subject: [PATCH 36/50] [SPARK-3543] TaskContext remaining cleanup work. Author: Reynold Xin Closes #2560 from rxin/TaskContext and squashes the following commits: 9eff95a [Reynold Xin] [SPARK-3543] remaining cleanup work. 
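The diffs below replace `mapPartitionsWithContext` and the old `stageId`/`partitionId`/`attemptId` accessors with `TaskContext.get()` and the renamed `get*` methods. A minimal sketch of the resulting user-facing pattern, assuming `sc` is an existing SparkContext and using a throwaway RDD:

```scala
import org.apache.spark.TaskContext

// Tag each element with the id of the partition that processed it. This uses the
// TaskContext.get() pattern adopted in this patch instead of mapPartitionsWithContext.
val tagged = sc.parallelize(1 to 100, 4).mapPartitions { iter =>
  val context = TaskContext.get()         // the running task's context
  val partition = context.getPartitionId  // renamed accessor used by this cleanup
  iter.map(value => (partition, value))
}
tagged.collect()
```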
--- core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala | 2 +- .../main/scala/org/apache/spark/rdd/PairRDDFunctions.scala | 3 ++- .../apache/spark/util/JavaTaskCompletionListenerImpl.java | 7 +++---- .../serializer/ProactiveClosureSerializationSuite.scala | 6 +----- .../apache/spark/sql/parquet/ParquetTableOperations.scala | 4 ++-- 5 files changed, 9 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 036dcc49664ef..21d0cc7b5cbaa 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -194,7 +194,7 @@ class HadoopRDD[K, V]( val jobConf = getJobConf() val inputFormat = getInputFormat(jobConf) HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime), - context.stageId, theSplit.index, context.attemptId.toInt, jobConf) + context.getStageId, theSplit.index, context.getAttemptId.toInt, jobConf) reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL) // Register an on-task-completion callback to close the input stream. diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 7f578bc5dac39..67833743f3a98 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -86,7 +86,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) } val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners) if (self.partitioner == Some(partitioner)) { - self.mapPartitionsWithContext((context, iter) => { + self.mapPartitions(iter => { + val context = TaskContext.get() new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context)) }, preservesPartitioning = true) } else { diff --git a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java index af34cdb03e4d1..0944bf8cd5c71 100644 --- a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java +++ b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java @@ -30,10 +30,9 @@ public class JavaTaskCompletionListenerImpl implements TaskCompletionListener { public void onTaskCompletion(TaskContext context) { context.isCompleted(); context.isInterrupted(); - context.stageId(); - context.partitionId(); - context.runningLocally(); - context.taskMetrics(); + context.getStageId(); + context.getPartitionId(); + context.isRunningLocally(); context.addTaskCompletionListener(this); } } diff --git a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala index aad6599589420..d037e2c19a64d 100644 --- a/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/ProactiveClosureSerializationSuite.scala @@ -50,8 +50,7 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex "flatMap" -> xflatMap _, "filter" -> xfilter _, "mapPartitions" -> xmapPartitions _, - "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _, - "mapPartitionsWithContext" -> xmapPartitionsWithContext _)) { + "mapPartitionsWithIndex" -> xmapPartitionsWithIndex _)) { val (name, 
xf) = transformation test(s"$name transformations throw proactive serialization exceptions") { @@ -78,8 +77,5 @@ class ProactiveClosureSerializationSuite extends FunSuite with SharedSparkContex private def xmapPartitionsWithIndex(x: RDD[String], uc: UnserializableClass): RDD[String] = x.mapPartitionsWithIndex((_, it) => it.map(y=>uc.op(y))) - - private def xmapPartitionsWithContext(x: RDD[String], uc: UnserializableClass): RDD[String] = - x.mapPartitionsWithContext((_, it) => it.map(y=>uc.op(y))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala index d39e31a7fa195..ffb732347d30a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala @@ -289,9 +289,9 @@ case class InsertIntoParquetTable( def writeShard(context: TaskContext, iter: Iterator[Row]): Int = { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. - val attemptNumber = (context.attemptId % Int.MaxValue).toInt + val attemptNumber = (context.getAttemptId % Int.MaxValue).toInt /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, + val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.getPartitionId, attemptNumber) val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId) val format = new AppendingParquetOutputFormat(taskIdOffset)

From 0dc2b6361d61b7d94cba3dc83da2abb7e08ba6fe Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Sun, 28 Sep 2014 21:44:50 -0700
Subject: [PATCH 37/50] [SPARK-1545] [mllib] Add Random Forests

This PR adds RandomForest to MLlib. The implementation is basic, and future performance optimizations will be important. (Note: RFs = Random Forests.)

# Overview

## RandomForest
* trains multiple trees at once to reduce the number of passes over the data
* allows feature subsets at each node
* uses a queue of nodes instead of fixed groups for each level

This implementation is based on an implementation by manishamde and the [Alpine Labs Sequoia Forest](https://github.com/AlpineNow/SparkML2) by codedeft (in particular, the TreePoint, BaggedPoint, and node queue implementations). Thank you for your inputs!

## Testing

Correctness: This has been tested for correctness with the test suites and with DecisionTreeRunner on example datasets.

Performance: This has been performance tested using [this branch of spark-perf](https://github.com/jkbradley/spark-perf/tree/rfs). Results below.

### Regression tests for DecisionTree

Summary: For training 1 tree, there are small regressions, especially from feature subsampling.

In the table below, each row is a single (random) dataset. The 2 different sets of result columns are for 2 different RF implementations:
* (numTrees): This is from an earlier commit, after implementing RandomForest to train multiple trees at once. It does not include any code for feature subsampling.
* (feature subsets): This is from this current PR's code, after implementing feature subsampling.

These tests were to identify regressions in DecisionTree, so they are training 1 tree with all of the features (i.e., no feature subsampling).

These were run on an EC2 cluster with 15 workers, training 1 tree with maxDepth = 5 (= 6 levels).
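Before the detailed numbers, a minimal sketch of how the new entry points described above might be invoked; the authoritative wiring is the DecisionTreeRunner diff further down. `sc`, the input path, and every parameter value here are illustrative placeholders, not recommendations.

```scala
import org.apache.spark.mllib.tree.RandomForest
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.impurity.Gini
import org.apache.spark.mllib.util.MLUtils

// Assumes an existing SparkContext `sc`; the path is a placeholder for any LIBSVM file.
val data = MLUtils.loadLibSVMFile(sc, "data/sample_libsvm_data.txt")
val Array(training, test) = data.randomSplit(Array(0.8, 0.2))

// Strategy values are illustrative, not tuned.
val strategy = new Strategy(algo = Algo.Classification, impurity = Gini, maxDepth = 5,
  numClassesForClassification = 2)

// Train 10 trees, considering sqrt(numFeatures) features per node; the last argument is the seed.
val model = RandomForest.trainClassifier(training, strategy, 10, "sqrt", 12345)

// The forest predicts by majority vote over its trees (mean for regression).
val accuracy =
  test.filter(lp => model.predict(lp.features) == lp.label).count().toDouble / test.count()
println(s"Test accuracy = $accuracy")
```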
Speedup values < 1 indicate slowdowns from the old DecisionTree implementation.

numInstances | numFeatures | runtime (sec) | speedup | runtime (sec) | speedup
---- | ---- | ---- | ---- | ---- | ----
 | | (numTrees) | (numTrees) | (feature subsets) | (feature subsets)
20000 | 100 | 4.051 | 1.044433473 | 4.478 | 0.9448414471
20000 | 500 | 8.472 | 1.104461756 | 9.315 | 1.004508857
20000 | 1500 | 19.354 | 1.05854087 | 20.863 | 0.9819776638
20000 | 3500 | 43.674 | 1.072033704 | 45.887 | 1.020332556
200000 | 100 | 4.196 | 1.171830315 | 4.848 | 1.014232673
200000 | 500 | 8.926 | 1.082791844 | 9.771 | 0.989151571
200000 | 1500 | 20.58 | 1.068415938 | 22.134 | 0.9934038131
200000 | 3500 | 48.043 | 1.075203464 | 52.249 | 0.9886505005
2000000 | 100 | 4.944 | 1.01355178 | 5.796 | 0.8645617667
2000000 | 500 | 11.11 | 1.016831683 | 12.482 | 0.9050632911
2000000 | 1500 | 31.144 | 1.017852556 | 35.274 | 0.8986789136
2000000 | 3500 | 79.981 | 1.085382778 | 101.105 | 0.8586123337
20000000 | 100 | 8.304 | 0.9270231214 | 9.073 | 0.8484514494
20000000 | 500 | 28.174 | 1.083268262 | 34.236 | 0.8914592826
20000000 | 1500 | 143.97 | 0.9579634646 | 159.275 | 0.8659111599

### Tests for forests

I have run other tests with numTrees=10 and with sqrt(numFeatures), and those indicate that multi-model training and feature subsets can speed up training for forests, especially when training deeper trees.

# Details on specific classes

## Changes to DecisionTree
* Main train() method is now in RandomForest.
* findBestSplits() is no longer needed. (It split levels into groups, but we now use a queue of nodes.)
* Many small changes to support RFs. (Note: These methods should be moved to RandomForest.scala in a later PR, but are in DecisionTree.scala to make code comparison easier.)

## RandomForest
* Main train() method is from old DecisionTree.
* selectNodesToSplit: Note that it selects nodes and feature subsets jointly to track memory usage.

## RandomForestModel
* Stores an Array[DecisionTreeModel]
* Prediction:
  * For classification, most common label. For regression, mean.
  * We could support other methods later.

## examples/.../DecisionTreeRunner
* This now takes numTrees and featureSubsetStrategy, to support RFs.

## DTStatsAggregator
* 2 types of functionality (w/ and w/o subsampling features): These require different indexing methods. (We could treat both as subsampling, but this is less efficient.)
* DTStatsAggregator is now abstract, and 2 child classes implement these 2 types of functionality.

## impurities
* These now take instance weights.

## Node
* Some vals changed to vars.
* This is unfortunately a public API change (DeveloperApi). This could be avoided by creating a LearningNode struct, but would be awkward.

## RandomForestSuite
Please let me know if there are missing tests!

## BaggedPoint
This wraps TreePoint and holds bootstrap weights/counts.

# Design decisions

* BaggedPoint: BaggedPoint is separate from TreePoint since it may be useful for other bagging algorithms later on.
* RandomForest public API: What options should be easily supported by the train* methods? Should ALL options be in the Java-friendly constructors? Should there be a constructor taking Strategy?
* Feature subsampling options: What options should be supported? scikit-learn supports the same options, except for "onethird." One option would be to allow users to specify fractions ("0.1"): the current options could be supported, and any unrecognized values would be parsed as Doubles in [0,1].
* Splits and bins are computed before bootstrapping, so all trees use the same discretization. * One queue, instead of one queue per tree. CC: mengxr manishamde codedeft chouqin Please let me know if you have suggestions---thanks! Author: Joseph K. Bradley Author: qiping.lqp Author: chouqin Closes #2435 from jkbradley/rfs-new and squashes the following commits: c694174 [Joseph K. Bradley] Fixed typo cc59d78 [Joseph K. Bradley] fixed imports e25909f [Joseph K. Bradley] Simplified node group maps. Specifically, created NodeIndexInfo to store node index in agg and feature subsets, and no longer create extra maps in findBestSplits fbe9a1e [Joseph K. Bradley] Changed default featureSubsetStrategy to be sqrt for classification, onethird for regression. Updated docs with references. ef7c293 [Joseph K. Bradley] Updates based on code review. Most substantial changes: * Simplified DTStatsAggregator * Made RandomForestModel.trees public * Added test for regression to RandomForestSuite 593b13c [Joseph K. Bradley] Fixed bug in metadata for computing log2(num features). Now it checks >= 1. a1a08df [Joseph K. Bradley] Removed old comments 866e766 [Joseph K. Bradley] Changed RandomForestSuite randomized tests to use multiple fixed random seeds. ff8bb96 [Joseph K. Bradley] removed usage of null from RandomForest and replaced with Option bf1a4c5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new 6b79c07 [Joseph K. Bradley] Added RandomForestSuite, and fixed small bugs, style issues. d7753d4 [Joseph K. Bradley] Added numTrees and featureSubsetStrategy to DecisionTreeRunner (to support RandomForest). Fixed bugs so that RandomForest now runs. 746d43c [Joseph K. Bradley] Implemented feature subsampling. Tested DecisionTree but not RandomForest. 6309d1d [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new. Added RandomForestModel.toString b7ae594 [Joseph K. Bradley] Updated docs. Small fix for bug which does not cause errors: No longer allocate unused child nodes for leaf nodes. 121c74e [Joseph K. Bradley] Basic random forests are implemented. Random features per node not yet implemented. Test suite not implemented. 325d18a [Joseph K. Bradley] Merge branch 'chouqin-dt-preprune' into rfs-new 4ef9bf1 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new 61b2e72 [Joseph K. Bradley] Added max of 10GB for maxMemoryInMB in Strategy. a95e7c8 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune 6da8571 [Joseph K. Bradley] RFs partly implemented, not done yet eddd1eb [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into rfs-new 5c4ac33 [Joseph K. Bradley] Added check in Strategy to make sure minInstancesPerNode >= 1 0dd4d87 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160 95c479d [Joseph K. Bradley] * Fixed typo in tree suite test "do not choose split that does not satisfy min instance per node requirements" * small style fixes e2628b6 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into chouqin-dt-preprune 19b01af [Joseph K. Bradley] Merge remote-tracking branch 'chouqin/dt-preprune' into chouqin-dt-preprune f1d11d1 [chouqin] fix typo c7ebaf1 [chouqin] fix typo 39f9b60 [chouqin] change edge `minInstancesPerNode` to 2 and add one more test c6e2dfc [Joseph K. Bradley] Added minInstancesPerNode and minInfoGain parameters to DecisionTreeRunner.scala and to Python API in tree.py 306120f [Joseph K. 
Bradley] Fixed typo in DecisionTreeModel.scala doc eaa1dcf [Joseph K. Bradley] Added topNode doc in DecisionTree and scalastyle fix d4d7864 [Joseph K. Bradley] Marked Node.build as deprecated d4dbb99 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-spark-3160 1a8f0ad [Joseph K. Bradley] Eliminated pre-allocated nodes array in main train() method. * Nodes are constructed and added to the tree structure as needed during training. 0278a11 [chouqin] remove `noSplit` and set `Predict` private to tree d593ec7 [chouqin] fix docs and change minInstancesPerNode to 1 2ab763b [Joseph K. Bradley] Simplifications to DecisionTree code: efcc736 [qiping.lqp] fix bug 10b8012 [qiping.lqp] fix style 6728fad [qiping.lqp] minor fix: remove empty lines bb465ca [qiping.lqp] Merge branch 'master' of https://github.com/apache/spark into dt-preprune cadd569 [qiping.lqp] add api docs 46b891f [qiping.lqp] fix bug e72c7e4 [qiping.lqp] add comments 845c6fa [qiping.lqp] fix style f195e83 [qiping.lqp] fix style 987cbf4 [qiping.lqp] fix bug ff34845 [qiping.lqp] separate calculation of predict of node from calculation of info gain ac42378 [qiping.lqp] add min info gain and min instances per node parameters in decision tree --- .../examples/mllib/DecisionTreeRunner.scala | 76 ++- .../spark/mllib/tree/DecisionTree.scala | 457 ++++++------------ .../spark/mllib/tree/RandomForest.scala | 451 +++++++++++++++++ .../spark/mllib/tree/impl/BaggedPoint.scala | 80 +++ .../mllib/tree/impl/DTStatsAggregator.scala | 219 +++++++-- .../tree/impl/DecisionTreeMetadata.scala | 47 +- .../spark/mllib/tree/impurity/Entropy.scala | 4 +- .../spark/mllib/tree/impurity/Gini.scala | 4 +- .../spark/mllib/tree/impurity/Impurity.scala | 2 +- .../spark/mllib/tree/impurity/Variance.scala | 8 +- .../apache/spark/mllib/tree/model/Node.scala | 13 +- .../mllib/tree/model/RandomForestModel.scala | 105 ++++ .../spark/mllib/tree/DecisionTreeSuite.scala | 210 ++++---- .../spark/mllib/tree/RandomForestSuite.scala | 245 ++++++++++ 14 files changed, 1410 insertions(+), 511 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index 4683e6eb966be..96fb068e9e126 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -21,16 +21,18 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{DecisionTree, impurity} +import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity} import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.model.DecisionTreeModel +import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel} import org.apache.spark.mllib.util.MLUtils import 
org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils /** - * An example runner for decision tree. Run with + * An example runner for decision trees and random forests. Run with * {{{ * ./bin/run-example org.apache.spark.examples.mllib.DecisionTreeRunner [options] * }}} @@ -57,6 +59,8 @@ object DecisionTreeRunner { maxBins: Int = 32, minInstancesPerNode: Int = 1, minInfoGain: Double = 0.0, + numTrees: Int = 1, + featureSubsetStrategy: String = "auto", fracTest: Double = 0.2) def main(args: Array[String]) { @@ -79,11 +83,20 @@ object DecisionTreeRunner { .action((x, c) => c.copy(maxBins = x)) opt[Int]("minInstancesPerNode") .text(s"min number of instances required at child nodes to create the parent split," + - s" default: ${defaultParams.minInstancesPerNode}") + s" default: ${defaultParams.minInstancesPerNode}") .action((x, c) => c.copy(minInstancesPerNode = x)) opt[Double]("minInfoGain") .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}") .action((x, c) => c.copy(minInfoGain = x)) + opt[Int]("numTrees") + .text(s"number of trees (1 = decision tree, 2+ = random forest)," + + s" default: ${defaultParams.numTrees}") + .action((x, c) => c.copy(numTrees = x)) + opt[String]("featureSubsetStrategy") + .text(s"feature subset sampling strategy" + + s" (${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}}), " + + s"default: ${defaultParams.featureSubsetStrategy}") + .action((x, c) => c.copy(featureSubsetStrategy = x)) opt[Double]("fracTest") .text(s"fraction of data to hold out for testing, default: ${defaultParams.fracTest}") .action((x, c) => c.copy(fracTest = x)) @@ -191,18 +204,35 @@ object DecisionTreeRunner { numClassesForClassification = numClasses, minInstancesPerNode = params.minInstancesPerNode, minInfoGain = params.minInfoGain) - val model = DecisionTree.train(training, strategy) - - println(model) - - if (params.algo == Classification) { - val accuracy = accuracyScore(model, test) - println(s"Test accuracy = $accuracy") - } - - if (params.algo == Regression) { - val mse = meanSquaredError(model, test) - println(s"Test mean squared error = $mse") + if (params.numTrees == 1) { + val model = DecisionTree.train(training, strategy) + println(model) + if (params.algo == Classification) { + val accuracy = + new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision + println(s"Test accuracy = $accuracy") + } + if (params.algo == Regression) { + val mse = meanSquaredError(model, test) + println(s"Test mean squared error = $mse") + } + } else { + val randomSeed = Utils.random.nextInt() + if (params.algo == Classification) { + val model = RandomForest.trainClassifier(training, strategy, params.numTrees, + params.featureSubsetStrategy, randomSeed) + println(model) + val accuracy = + new MulticlassMetrics(test.map(lp => (model.predict(lp.features), lp.label))).precision + println(s"Test accuracy = $accuracy") + } + if (params.algo == Regression) { + val model = RandomForest.trainRegressor(training, strategy, params.numTrees, + params.featureSubsetStrategy, randomSeed) + println(model) + val mse = meanSquaredError(model, test) + println(s"Test mean squared error = $mse") + } } sc.stop() @@ -211,9 +241,7 @@ object DecisionTreeRunner { /** * Calculates the classifier accuracy. 
*/ - private def accuracyScore( - model: DecisionTreeModel, - data: RDD[LabeledPoint]): Double = { + private def accuracyScore(model: DecisionTreeModel, data: RDD[LabeledPoint]): Double = { val correctCount = data.filter(y => model.predict(y.features) == y.label).count() val count = data.count() correctCount.toDouble / count @@ -228,4 +256,14 @@ object DecisionTreeRunner { err * err }.mean() } + + /** + * Calculates the mean squared error for regression. + */ + private def meanSquaredError(tree: RandomForestModel, data: RDD[LabeledPoint]): Double = { + data.map { y => + val err = tree.predict(y.features) - y.label + err * err + }.mean() + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index c7f2576c822b1..b7dc373ebd9cc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -18,12 +18,14 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ +import scala.collection.mutable import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.Logging import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo import org.apache.spark.mllib.tree.configuration.Strategy import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ @@ -33,7 +35,6 @@ import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity} import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.model._ import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.XORShiftRandom @@ -56,99 +57,10 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo * @return DecisionTreeModel that can be used for prediction */ def train(input: RDD[LabeledPoint]): DecisionTreeModel = { - - val timer = new TimeTracker() - - timer.start("total") - - timer.start("init") - - val retaggedInput = input.retag(classOf[LabeledPoint]) - val metadata = DecisionTreeMetadata.buildMetadata(retaggedInput, strategy) - logDebug("algo = " + strategy.algo) - logDebug("maxBins = " + metadata.maxBins) - - // Find the splits and the corresponding bins (interval between the splits) using a sample - // of the input data. - timer.start("findSplitsBins") - val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata) - timer.stop("findSplitsBins") - logDebug("numBins: feature: number of bins") - logDebug(Range(0, metadata.numFeatures).map { featureIndex => - s"\t$featureIndex\t${metadata.numBins(featureIndex)}" - }.mkString("\n")) - - // Bin feature values (TreePoint representation). - // Cache input RDD for speedup during multiple passes. 
- val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) - .persist(StorageLevel.MEMORY_AND_DISK) - - // depth of the decision tree - val maxDepth = strategy.maxDepth - require(maxDepth <= 30, - s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.") - - // Calculate level for single group construction - - // Max memory usage for aggregates - val maxMemoryUsage = strategy.maxMemoryInMB * 1024L * 1024L - logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") - // TODO: Calculate memory usage more precisely. - val numElementsPerNode = DecisionTree.getElementsPerNode(metadata) - - logDebug("numElementsPerNode = " + numElementsPerNode) - val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array - val maxNumberOfNodesPerGroup = math.max(maxMemoryUsage / arraySizePerNode, 1) - logDebug("maxNumberOfNodesPerGroup = " + maxNumberOfNodesPerGroup) - // nodes at a level is 2^level. level is zero indexed. - val maxLevelForSingleGroup = math.max( - (math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0) - logDebug("max level for single group = " + maxLevelForSingleGroup) - - timer.stop("init") - - /* - * The main idea here is to perform level-wise training of the decision tree nodes thus - * reducing the passes over the data from l to log2(l) where l is the total number of nodes. - * Each data sample is handled by a particular node at that level (or it reaches a leaf - * beforehand and is not used in later levels. - */ - - var topNode: Node = null // set on first iteration - var level = 0 - var break = false - while (level <= maxDepth && !break) { - logDebug("#####################################") - logDebug("level = " + level) - logDebug("#####################################") - - // Find best split for all nodes at a level. - timer.start("findBestSplits") - val (tmpTopNode: Node, doneTraining: Boolean) = DecisionTree.findBestSplits(treeInput, - metadata, level, topNode, splits, bins, maxLevelForSingleGroup, timer) - timer.stop("findBestSplits") - - if (level == 0) { - topNode = tmpTopNode - } - if (doneTraining) { - break = true - logDebug("done training") - } - - level += 1 - } - - logDebug("#####################################") - logDebug("Extracting tree model") - logDebug("#####################################") - - timer.stop("total") - - logInfo("Internal timing for DecisionTree:") - logInfo(s"$timer") - - new DecisionTreeModel(topNode, strategy.algo) + // Note: random seed will not be used since numTrees = 1. + val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0) + val rfModel = rf.train(input) + rfModel.trees(0) } } @@ -352,58 +264,10 @@ object DecisionTree extends Serializable with Logging { impurity, maxDepth, maxBins) } - /** - * Returns an array of optimal splits for all nodes at a given level. Splits the task into - * multiple groups if the level-wise training task could lead to memory overflow. - * - * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] - * @param metadata Learning and dataset metadata - * @param level Level of the tree - * @param topNode Root node of the tree (or invalid node when training first level). - * @param splits possible splits for all features, indexed (numFeatures)(numSplits) - * @param bins possible bins for all features, indexed (numFeatures)(numBins) - * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation. 
- * @return (root, doneTraining) where: - * root = Root node (which is newly created on the first iteration), - * doneTraining = true if no more internal nodes were created. - */ - private[tree] def findBestSplits( - input: RDD[TreePoint], - metadata: DecisionTreeMetadata, - level: Int, - topNode: Node, - splits: Array[Array[Split]], - bins: Array[Array[Bin]], - maxLevelForSingleGroup: Int, - timer: TimeTracker = new TimeTracker): (Node, Boolean) = { - - // split into groups to avoid memory overflow during aggregation - if (level > maxLevelForSingleGroup) { - // When information for all nodes at a given level cannot be stored in memory, - // the nodes are divided into multiple groups at each level with the number of groups - // increasing exponentially per level. For example, if maxLevelForSingleGroup is 10, - // numGroups is equal to 2 at level 11 and 4 at level 12, respectively. - val numGroups = 1 << level - maxLevelForSingleGroup - logDebug("numGroups = " + numGroups) - // Iterate over each group of nodes at a level. - var groupIndex = 0 - var doneTraining = true - while (groupIndex < numGroups) { - val (_, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level, - topNode, splits, bins, timer, numGroups, groupIndex) - doneTraining = doneTraining && doneTrainingGroup - groupIndex += 1 - } - (topNode, doneTraining) // Not first iteration, so topNode was already set. - } else { - findBestSplitsPerGroup(input, metadata, level, topNode, splits, bins, timer) - } - } - /** * Get the node index corresponding to this data point. - * This function mimics prediction, passing an example from the root node down to a node - * at the current level being trained; that node's index is returned. + * This function mimics prediction, passing an example from the root node down to a leaf + * or unsplit node; that node's index is returned. * * @param node Node in tree from which to classify the given data point. * @param binnedFeatures Binned feature vector for data point. @@ -413,14 +277,15 @@ object DecisionTree extends Serializable with Logging { * Otherwise, last node reachable in tree matching this example. * Note: This is the global node index, i.e., the index used in the tree. * This index is different from the index used during training a particular - * set of nodes in a (level, group). + * group of nodes on one call to [[findBestSplits()]]. */ private def predictNodeIndex( node: Node, binnedFeatures: Array[Int], bins: Array[Array[Bin]], unorderedFeatures: Set[Int]): Int = { - if (node.isLeaf) { + if (node.isLeaf || node.split.isEmpty) { + // Node is either leaf, or has not yet been split. node.id } else { val featureIndex = node.split.get.feature @@ -465,43 +330,60 @@ object DecisionTree extends Serializable with Logging { * @param agg Array storing aggregate calculation, with a set of sufficient statistics for * each (node, feature, bin). * @param treePoint Data point being aggregated. - * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). + * @param nodeIndex Node corresponding to treePoint. agg is indexed in [0, numNodes). * @param bins possible bins for all features, indexed (numFeatures)(numBins) * @param unorderedFeatures Set of indices of unordered features. + * @param instanceWeight Weight (importance) of instance in dataset. */ private def mixedBinSeqOp( agg: DTStatsAggregator, treePoint: TreePoint, nodeIndex: Int, bins: Array[Array[Bin]], - unorderedFeatures: Set[Int]): Unit = { - // Iterate over all features. 
- val numFeatures = treePoint.binnedFeatures.size + unorderedFeatures: Set[Int], + instanceWeight: Double, + featuresForNode: Option[Array[Int]]): Unit = { + val numFeaturesPerNode = if (featuresForNode.nonEmpty) { + // Use subsampled features + featuresForNode.get.size + } else { + // Use all features + agg.metadata.numFeatures + } val nodeOffset = agg.getNodeOffset(nodeIndex) - var featureIndex = 0 - while (featureIndex < numFeatures) { + // Iterate over features. + var featureIndexIdx = 0 + while (featureIndexIdx < numFeaturesPerNode) { + val featureIndex = if (featuresForNode.nonEmpty) { + featuresForNode.get.apply(featureIndexIdx) + } else { + featureIndexIdx + } if (unorderedFeatures.contains(featureIndex)) { // Unordered feature val featureValue = treePoint.binnedFeatures(featureIndex) val (leftNodeFeatureOffset, rightNodeFeatureOffset) = - agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex) + agg.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx) // Update the left or right bin for each split. - val numSplits = agg.numSplits(featureIndex) + val numSplits = agg.metadata.numSplits(featureIndex) var splitIndex = 0 while (splitIndex < numSplits) { if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) { - agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label) + agg.nodeFeatureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, + instanceWeight) } else { - agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label) + agg.nodeFeatureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, + instanceWeight) } splitIndex += 1 } } else { // Ordered feature val binIndex = treePoint.binnedFeatures(featureIndex) - agg.nodeUpdate(nodeOffset, featureIndex, binIndex, treePoint.label) + agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, treePoint.label, + instanceWeight) } - featureIndex += 1 + featureIndexIdx += 1 } } @@ -513,66 +395,77 @@ object DecisionTree extends Serializable with Logging { * @param agg Array storing aggregate calculation, with a set of sufficient statistics for * each (node, feature, bin). * @param treePoint Data point being aggregated. - * @param nodeIndex Node corresponding to treePoint. Indexed from 0 at start of (level, group). - * @return agg + * @param nodeIndex Node corresponding to treePoint. agg is indexed in [0, numNodes). + * @param instanceWeight Weight (importance) of instance in dataset. */ private def orderedBinSeqOp( agg: DTStatsAggregator, treePoint: TreePoint, - nodeIndex: Int): Unit = { + nodeIndex: Int, + instanceWeight: Double, + featuresForNode: Option[Array[Int]]): Unit = { val label = treePoint.label val nodeOffset = agg.getNodeOffset(nodeIndex) - // Iterate over all features. - val numFeatures = agg.numFeatures - var featureIndex = 0 - while (featureIndex < numFeatures) { - val binIndex = treePoint.binnedFeatures(featureIndex) - agg.nodeUpdate(nodeOffset, featureIndex, binIndex, label) - featureIndex += 1 + // Iterate over features. 
+ if (featuresForNode.nonEmpty) { + // Use subsampled features + var featureIndexIdx = 0 + while (featureIndexIdx < featuresForNode.get.size) { + val binIndex = treePoint.binnedFeatures(featuresForNode.get.apply(featureIndexIdx)) + agg.nodeUpdate(nodeOffset, nodeIndex, featureIndexIdx, binIndex, label, instanceWeight) + featureIndexIdx += 1 + } + } else { + // Use all features + val numFeatures = agg.metadata.numFeatures + var featureIndex = 0 + while (featureIndex < numFeatures) { + val binIndex = treePoint.binnedFeatures(featureIndex) + agg.nodeUpdate(nodeOffset, nodeIndex, featureIndex, binIndex, label, instanceWeight) + featureIndex += 1 + } } } /** - * Returns an array of optimal splits for a group of nodes at a given level + * Given a group of nodes, this finds the best split for each node. * * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] * @param metadata Learning and dataset metadata - * @param level Level of the tree - * @param topNode Root node of the tree (or invalid node when training first level). + * @param topNodes Root node for each tree. Used for matching instances with nodes. + * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree + * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo, + * where nodeIndexInfo stores the index in the group and the + * feature subsets (if using feature subsets). * @param splits possible splits for all features, indexed (numFeatures)(numSplits) * @param bins possible bins for all features, indexed (numFeatures)(numBins) - * @param numGroups total number of node groups at the current level. Default value is set to 1. - * @param groupIndex index of the node group being processed. Default value is set to 0. - * @return (root, doneTraining) where: - * root = Root node (which is newly created on the first iteration), - * doneTraining = true if no more internal nodes were created. + * @param nodeQueue Queue of nodes to split, with values (treeIndex, node). + * Updated with new non-leaf nodes which are created. */ - private def findBestSplitsPerGroup( - input: RDD[TreePoint], + private[tree] def findBestSplits( + input: RDD[BaggedPoint[TreePoint]], metadata: DecisionTreeMetadata, - level: Int, - topNode: Node, + topNodes: Array[Node], + nodesForGroup: Map[Int, Array[Node]], + treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]], splits: Array[Array[Split]], bins: Array[Array[Bin]], - timer: TimeTracker, - numGroups: Int = 1, - groupIndex: Int = 0): (Node, Boolean) = { + nodeQueue: mutable.Queue[(Int, Node)], + timer: TimeTracker = new TimeTracker): Unit = { /* * The high-level descriptions of the best split optimizations are noted here. * - * *Level-wise training* - * We perform bin calculations for all nodes at the given level to avoid making multiple - * passes over the data. Thus, for a slightly increased computation and storage cost we save - * several iterations over the data especially at higher levels of the decision tree. + * *Group-wise training* + * We perform bin calculations for groups of nodes to reduce the number of + * passes over the data. Each iteration requires more computation and storage, + * but saves several iterations over the data. * * *Bin-wise computation* * We use a bin-wise best split computation strategy instead of a straightforward best split * computation strategy. Instead of analyzing each sample for contribution to the left/right * child node impurity of every split, we first categorize each feature of a sample into a - * bin. 
Each bin is an interval between a low and high split. Since each split, and thus bin, - * is ordered (read ordering for categorical variables in the findSplitsBins method), - * we exploit this structure to calculate aggregates for bins and then use these aggregates + * bin. We exploit this structure to calculate aggregates for bins and then use these aggregates * to calculate information gain for each split. * * *Aggregation over partitions* @@ -582,42 +475,15 @@ object DecisionTree extends Serializable with Logging { * drastically reduce the communication overhead. */ - // Common calculations for multiple nested methods: - - // numNodes: Number of nodes in this (level of tree, group), - // where nodes at deeper (larger) levels may be divided into groups. - val numNodes = Node.maxNodesInLevel(level) / numGroups + // numNodes: Number of nodes in this group + val numNodes = nodesForGroup.values.map(_.size).sum logDebug("numNodes = " + numNodes) - logDebug("numFeatures = " + metadata.numFeatures) logDebug("numClasses = " + metadata.numClasses) logDebug("isMulticlass = " + metadata.isMulticlass) logDebug("isMulticlassWithCategoricalFeatures = " + metadata.isMulticlassWithCategoricalFeatures) - // shift when more than one group is used at deep tree level - val groupShift = numNodes * groupIndex - - // Used for treePointToNodeIndex to get an index for this (level, group). - // - Node.startIndexInLevel(level) gives the global index offset for nodes at this level. - // - groupShift corrects for groups in this level before the current group. - val globalNodeIndexOffset = Node.startIndexInLevel(level) + groupShift - - /** - * Find the node index for the given example. - * Nodes are indexed from 0 at the start of this (level, group). - * If the example does not reach this level, returns a value < 0. - */ - def treePointToNodeIndex(treePoint: TreePoint): Int = { - if (level == 0) { - 0 - } else { - val globalNodeIndex = - predictNodeIndex(topNode, treePoint.binnedFeatures, bins, metadata.unorderedFeatures) - globalNodeIndex - globalNodeIndexOffset - } - } - /** * Performs a sequential aggregation over a partition. * @@ -626,21 +492,27 @@ object DecisionTree extends Serializable with Logging { * * @param agg Array storing aggregate calculation, with a set of sufficient statistics for * each (node, feature, bin). - * @param treePoint Data point being aggregated. + * @param baggedPoint Data point being aggregated. * @return agg */ def binSeqOp( agg: DTStatsAggregator, - treePoint: TreePoint): DTStatsAggregator = { - val nodeIndex = treePointToNodeIndex(treePoint) - // If the example does not reach this level, then nodeIndex < 0. - // If the example reaches this level but is handled in a different group, - // then either nodeIndex < 0 (previous group) or nodeIndex >= numNodes (later group). - if (nodeIndex >= 0 && nodeIndex < numNodes) { - if (metadata.unorderedFeatures.isEmpty) { - orderedBinSeqOp(agg, treePoint, nodeIndex) - } else { - mixedBinSeqOp(agg, treePoint, nodeIndex, bins, metadata.unorderedFeatures) + baggedPoint: BaggedPoint[TreePoint]): DTStatsAggregator = { + treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => + val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, + bins, metadata.unorderedFeatures) + val nodeInfo = nodeIndexToInfo.getOrElse(nodeIndex, null) + // If the example does not reach a node in this group, then nodeIndex = null. 
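binSeqOp above locates, for every tree in the group, the node that the current bagged example falls into by walking down from that tree's root. A simplified sketch of that routing step, assuming continuous threshold splits and toy node types (the real predictNodeIndex works on binned features and also handles categorical and unordered splits):

    // Sketch: route an example down a partially built tree to the deepest existing node.
    case class SketchNode(id: Int,
                          splitFeature: Int = -1,
                          threshold: Double = 0.0,
                          left: Option[SketchNode] = None,
                          right: Option[SketchNode] = None)

    def routeToNode(node: SketchNode, features: Array[Double]): Int =
      if (node.left.isEmpty || node.right.isEmpty) {
        node.id                                   // reached the frontier of the tree
      } else if (features(node.splitFeature) <= node.threshold) {
        routeToNode(node.left.get, features)
      } else {
        routeToNode(node.right.get, features)
      }

    // Root (id 1) splits on feature 0 at 0.5; its children (ids 2 and 3) are leaves so far.
    val root = SketchNode(1, splitFeature = 0, threshold = 0.5,
      left = Some(SketchNode(2)), right = Some(SketchNode(3)))
    assert(routeToNode(root, Array(0.2)) == 2)
    assert(routeToNode(root, Array(0.9)) == 3)

If the node an example lands in is not part of the current group, its statistics are simply not updated in this pass.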
+ if (nodeInfo != null) { + val aggNodeIndex = nodeInfo.nodeIndexInGroup + val featuresForNode = nodeInfo.featureSubset + val instanceWeight = baggedPoint.subsampleWeights(treeIndex) + if (metadata.unorderedFeatures.isEmpty) { + orderedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, instanceWeight, featuresForNode) + } else { + mixedBinSeqOp(agg, baggedPoint.datum, aggNodeIndex, bins, metadata.unorderedFeatures, + instanceWeight, featuresForNode) + } } } agg @@ -649,71 +521,62 @@ object DecisionTree extends Serializable with Logging { // Calculate bin aggregates. timer.start("aggregation") val binAggregates: DTStatsAggregator = { - val initAgg = new DTStatsAggregator(metadata, numNodes) + val initAgg = if (metadata.subsamplingFeatures) { + new DTStatsAggregatorSubsampledFeatures(metadata, treeToNodeToIndexInfo) + } else { + new DTStatsAggregatorFixedFeatures(metadata, numNodes) + } input.treeAggregate(initAgg)(binSeqOp, DTStatsAggregator.binCombOp) } timer.stop("aggregation") - // Calculate best splits for all nodes at a given level + // Calculate best splits for all nodes in the group timer.start("chooseSplits") - // On the first iteration, we need to get and return the newly created root node. - var newTopNode: Node = topNode - - // Iterate over all nodes at this level - var nodeIndex = 0 - var internalNodeCount = 0 - while (nodeIndex < numNodes) { - val (split: Split, stats: InformationGainStats, predict: Predict) = - binsToBestSplit(binAggregates, nodeIndex, level, metadata, splits) - logDebug("best split = " + split) - - val globalNodeIndex = globalNodeIndexOffset + nodeIndex - // Extract info for this node at the current level. - val isLeaf = (stats.gain <= 0) || (level == metadata.maxDepth) - val node = - new Node(globalNodeIndex, predict.predict, isLeaf, Some(split), None, None, Some(stats)) - logDebug("Node = " + node) - - if (!isLeaf) { - internalNodeCount += 1 - } - if (level == 0) { - newTopNode = node - } else { - // Set parent. - val parentNode = Node.getNode(Node.parentIndex(globalNodeIndex), topNode) - if (Node.isLeftChild(globalNodeIndex)) { - parentNode.leftNode = Some(node) - } else { - parentNode.rightNode = Some(node) + // Iterate over all nodes in this group. + nodesForGroup.foreach { case (treeIndex, nodesForTree) => + nodesForTree.foreach { node => + val nodeIndex = node.id + val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex) + val aggNodeIndex = nodeInfo.nodeIndexInGroup + val featuresForNode = nodeInfo.featureSubset + val (split: Split, stats: InformationGainStats, predict: Predict) = + binsToBestSplit(binAggregates, aggNodeIndex, splits, featuresForNode) + logDebug("best split = " + split) + + // Extract info for this node. Create children if not leaf. 
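The aggregation step above follows the usual seqOp/combOp pattern: each partition folds its examples into a local statistics array, and the partial arrays are then merged pairwise. A reduced sketch of the same pattern on a plain RDD, with a two-class label-count array standing in for the DTStatsAggregator hierarchy:

    // Sketch: per-partition sequential update (seqOp) followed by pairwise merge (combOp).
    import org.apache.spark.{SparkConf, SparkContext}

    val sc = new SparkContext(new SparkConf().setAppName("agg-sketch").setMaster("local[2]"))
    val labels = sc.parallelize(Seq(0.0, 1.0, 1.0, 0.0, 1.0), numSlices = 2)

    val classCounts = labels.treeAggregate(new Array[Double](2))(
      seqOp = (agg, label) => { agg(label.toInt) += 1.0; agg },
      combOp = (a, b) => { var i = 0; while (i < a.length) { a(i) += b(i); i += 1 }; a })

    // classCounts: Array(2.0, 3.0)
    sc.stop()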
+ val isLeaf = (stats.gain <= 0) || (Node.indexToLevel(nodeIndex) == metadata.maxDepth) + assert(node.id == nodeIndex) + node.predict = predict.predict + node.isLeaf = isLeaf + node.stats = Some(stats) + logDebug("Node = " + node) + + if (!isLeaf) { + node.split = Some(split) + node.leftNode = Some(Node.emptyNode(Node.leftChildIndex(nodeIndex))) + node.rightNode = Some(Node.emptyNode(Node.rightChildIndex(nodeIndex))) + nodeQueue.enqueue((treeIndex, node.leftNode.get)) + nodeQueue.enqueue((treeIndex, node.rightNode.get)) + logDebug("leftChildIndex = " + node.leftNode.get.id + + ", impurity = " + stats.leftImpurity) + logDebug("rightChildIndex = " + node.rightNode.get.id + + ", impurity = " + stats.rightImpurity) } } - if (level < metadata.maxDepth) { - logDebug("leftChildIndex = " + Node.leftChildIndex(globalNodeIndex) + - ", impurity = " + stats.leftImpurity) - logDebug("rightChildIndex = " + Node.rightChildIndex(globalNodeIndex) + - ", impurity = " + stats.rightImpurity) - } - - nodeIndex += 1 } timer.stop("chooseSplits") - - val doneTraining = internalNodeCount == 0 - (newTopNode, doneTraining) } /** * Calculate the information gain for a given (feature, split) based upon left/right aggregates. * @param leftImpurityCalculator left node aggregates for this (feature, split) * @param rightImpurityCalculator right node aggregate for this (feature, split) - * @return information gain and statistics for all splits + * @return information gain and statistics for split */ private def calculateGainForSplit( leftImpurityCalculator: ImpurityCalculator, rightImpurityCalculator: ImpurityCalculator, - level: Int, metadata: DecisionTreeMetadata): InformationGainStats = { val leftCount = leftImpurityCalculator.count val rightCount = rightImpurityCalculator.count @@ -753,7 +616,7 @@ object DecisionTree extends Serializable with Logging { * Calculate predict value for current node, given stats of any split. * Note that this function is called only once for each node. * @param leftImpurityCalculator left node aggregates for a split - * @param rightImpurityCalculator right node aggregates for a node + * @param rightImpurityCalculator right node aggregates for a split * @return predict value for current node */ private def calculatePredict( @@ -770,27 +633,33 @@ object DecisionTree extends Serializable with Logging { /** * Find the best split for a node. * @param binAggregates Bin statistics. - * @param nodeIndex Index for node to split in this (level, group). - * @return tuple for best split: (Split, information gain) + * @param nodeIndex Index into aggregates for node to split in this group. + * @return tuple for best split: (Split, information gain, prediction at node) */ private def binsToBestSplit( binAggregates: DTStatsAggregator, nodeIndex: Int, - level: Int, - metadata: DecisionTreeMetadata, - splits: Array[Array[Split]]): (Split, InformationGainStats, Predict) = { + splits: Array[Array[Split]], + featuresForNode: Option[Array[Int]]): (Split, InformationGainStats, Predict) = { + + val metadata: DecisionTreeMetadata = binAggregates.metadata // calculate predict only once var predict: Option[Predict] = None // For each (feature, split), calculate the gain, and select the best (feature, split). 
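For reference, the quantity calculateGainForSplit produces is the impurity of the parent minus the count-weighted impurities of the two children. A self-contained sketch with Gini impurity and hard-coded class counts (illustrative only; the real code goes through ImpurityCalculator and additionally applies minInstancesPerNode and minInfoGain checks):

    // Sketch: information gain of one candidate split from left/right class counts.
    def gini(counts: Array[Double]): Double = {
      val total = counts.sum
      if (total == 0.0) 0.0 else 1.0 - counts.map(c => (c / total) * (c / total)).sum
    }

    def informationGain(left: Array[Double], right: Array[Double]): Double = {
      val parent = left.zip(right).map { case (l, r) => l + r }
      val total = parent.sum
      gini(parent) - (left.sum / total) * gini(left) - (right.sum / total) * gini(right)
    }

    // A perfectly separating split of 10 examples achieves the largest possible gain here:
    // informationGain(Array(5.0, 0.0), Array(0.0, 5.0)) == 0.5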
- val (bestSplit, bestSplitStats) = Range(0, metadata.numFeatures).map { featureIndex => + val (bestSplit, bestSplitStats) = Range(0, metadata.numFeaturesPerNode).map { featureIndexIdx => + val featureIndex = if (featuresForNode.nonEmpty) { + featuresForNode.get.apply(featureIndexIdx) + } else { + featureIndexIdx + } val numSplits = metadata.numSplits(featureIndex) if (metadata.isContinuous(featureIndex)) { // Cumulative sum (scanLeft) of bin statistics. // Afterwards, binAggregates for a bin is the sum of aggregates for // that bin + all preceding bins. - val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex) + val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx) var splitIndex = 0 while (splitIndex < numSplits) { binAggregates.mergeForNodeFeature(nodeFeatureOffset, splitIndex + 1, splitIndex) @@ -803,26 +672,26 @@ object DecisionTree extends Serializable with Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) rightChildStats.subtract(leftChildStats) predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata) + val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata) (splitIdx, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else if (metadata.isUnordered(featureIndex)) { // Unordered categorical feature val (leftChildOffset, rightChildOffset) = - binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndex) + binAggregates.getLeftRightNodeFeatureOffsets(nodeIndex, featureIndexIdx) val (bestFeatureSplitIndex, bestFeatureGainStats) = Range(0, numSplits).map { splitIndex => val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata) + val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata) (splitIndex, gainStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { // Ordered categorical feature - val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndex) + val nodeFeatureOffset = binAggregates.getNodeFeatureOffset(nodeIndex, featureIndexIdx) val numBins = metadata.numBins(featureIndex) /* Each bin is one category (feature value). @@ -887,7 +756,7 @@ object DecisionTree extends Serializable with Logging { binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) rightChildStats.subtract(leftChildStats) predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata) + val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, metadata) (splitIndex, gainStats) }.maxBy(_._2.gain) val categoriesForSplit = @@ -903,18 +772,6 @@ object DecisionTree extends Serializable with Logging { (bestSplit, bestSplitStats, predict.get) } - /** - * Get the number of values to be stored per node in the bin aggregates. 
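The continuous-feature branch above relies on a prefix-sum trick: after mergeForNodeFeature has accumulated bin statistics left to right, the statistics at split index i already describe the left child, and the right child is the node total minus that prefix. A toy one-class version of the same idea:

    // Sketch: prefix sums over bin counts give left/right totals for every threshold.
    val binCounts = Array(3.0, 1.0, 4.0, 2.0)            // examples per bin (toy, one class)
    val prefix = binCounts.scanLeft(0.0)(_ + _).tail     // prefix(i) = bins 0..i accumulated
    val total = prefix.last
    val leftRightPerSplit = (0 until binCounts.length - 1).map(i => (prefix(i), total - prefix(i)))
    // leftRightPerSplit: Vector((3.0,7.0), (4.0,6.0), (8.0,2.0))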
- */ - private def getElementsPerNode(metadata: DecisionTreeMetadata): Long = { - val totalBins = metadata.numBins.map(_.toLong).sum - if (metadata.isClassification) { - metadata.numClasses * totalBins - } else { - 3 * totalBins - } - } - /** * Returns splits and bins for decision tree calculation. * Continuous and categorical features are handled differently. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala new file mode 100644 index 0000000000000..7fa7725e79e46 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala @@ -0,0 +1,451 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker} +import org.apache.spark.mllib.tree.impurity.Impurities +import org.apache.spark.mllib.tree.model._ +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils + +/** + * :: Experimental :: + * A class which implements a random forest learning algorithm for classification and regression. + * It supports both continuous and categorical features. + * + * The settings for featureSubsetStrategy are based on the following references: + * - log2: tested in Breiman (2001) + * - sqrt: recommended by Breiman manual for random forests + * - The defaults of sqrt (classification) and onethird (regression) match the R randomForest + * package. + * @see [[http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf Breiman (2001)]] + * @see [[http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf Breiman manual for + * random forests]] + * + * @param strategy The configuration parameters for the random forest algorithm which specify + * the type of algorithm (classification, regression, etc.), feature type + * (continuous, categorical), depth of the tree, quantile calculation strategy, + * etc. + * @param numTrees If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. + * @param featureSubsetStrategy Number of features to consider for splits at each node. + * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". 
+ * If "auto" is set, this parameter is set based on numTrees: + * if numTrees == 1, set to "all"; + * if numTrees > 1 (forest) set to "sqrt" for classification and + * to "onethird" for regression. + * @param seed Random seed for bootstrapping and choosing feature subsets. + */ +@Experimental +private class RandomForest ( + private val strategy: Strategy, + private val numTrees: Int, + featureSubsetStrategy: String, + private val seed: Int) + extends Serializable with Logging { + + strategy.assertValid() + require(numTrees > 0, s"RandomForest requires numTrees > 0, but was given numTrees = $numTrees.") + require(RandomForest.supportedFeatureSubsetStrategies.contains(featureSubsetStrategy), + s"RandomForest given invalid featureSubsetStrategy: $featureSubsetStrategy." + + s" Supported values: ${RandomForest.supportedFeatureSubsetStrategies.mkString(", ")}.") + + /** + * Method to train a decision tree model over an RDD + * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] + * @return RandomForestModel that can be used for prediction + */ + def train(input: RDD[LabeledPoint]): RandomForestModel = { + + val timer = new TimeTracker() + + timer.start("total") + + timer.start("init") + + val retaggedInput = input.retag(classOf[LabeledPoint]) + val metadata = + DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy) + logDebug("algo = " + strategy.algo) + logDebug("numTrees = " + numTrees) + logDebug("seed = " + seed) + logDebug("maxBins = " + metadata.maxBins) + logDebug("featureSubsetStrategy = " + featureSubsetStrategy) + logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode) + + // Find the splits and the corresponding bins (interval between the splits) using a sample + // of the input data. + timer.start("findSplitsBins") + val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata) + timer.stop("findSplitsBins") + logDebug("numBins: feature: number of bins") + logDebug(Range(0, metadata.numFeatures).map { featureIndex => + s"\t$featureIndex\t${metadata.numBins(featureIndex)}" + }.mkString("\n")) + + // Bin feature values (TreePoint representation). + // Cache input RDD for speedup during multiple passes. + val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata) + val baggedInput = if (numTrees > 1) { + BaggedPoint.convertToBaggedRDD(treeInput, numTrees, seed) + } else { + BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) + }.persist(StorageLevel.MEMORY_AND_DISK) + + // depth of the decision tree + val maxDepth = strategy.maxDepth + require(maxDepth <= 30, + s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.") + + // Max memory usage for aggregates + // TODO: Calculate memory usage more precisely. + val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L + logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.") + val maxMemoryPerNode = { + val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { + // Find numFeaturesPerNode largest bins to get an upper bound on memory usage. + Some(metadata.numBins.zipWithIndex.sortBy(- _._1) + .take(metadata.numFeaturesPerNode).map(_._2)) + } else { + None + } + RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L + } + require(maxMemoryPerNode <= maxMemoryUsage, + s"RandomForest/DecisionTree given maxMemoryInMB = ${strategy.maxMemoryInMB}," + + " which is too small for the given features." 
+ + s" Minimum value = ${maxMemoryPerNode / (1024L * 1024L)}") + + timer.stop("init") + + /* + * The main idea here is to perform group-wise training of the decision tree nodes thus + * reducing the passes over the data from (# nodes) to (# nodes / maxNumberOfNodesPerGroup). + * Each data sample is handled by a particular node (or it reaches a leaf and is not used + * in lower levels). + */ + + // FIFO queue of nodes to train: (treeIndex, node) + val nodeQueue = new mutable.Queue[(Int, Node)]() + + val rng = new scala.util.Random() + rng.setSeed(seed) + + // Allocate and queue root nodes. + val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1)) + Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex)))) + + while (nodeQueue.nonEmpty) { + // Collect some nodes to split, and choose features for each node (if subsampling). + // Each group of nodes may come from one or multiple trees, and at multiple levels. + val (nodesForGroup, treeToNodeToIndexInfo) = + RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng) + // Sanity check (should never occur): + assert(nodesForGroup.size > 0, + s"RandomForest selected empty nodesForGroup. Error for unknown reason.") + + // Choose node splits, and enqueue new nodes as needed. + timer.start("findBestSplits") + DecisionTree.findBestSplits(baggedInput, + metadata, topNodes, nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue, timer) + timer.stop("findBestSplits") + } + + timer.stop("total") + + logInfo("Internal timing for DecisionTree:") + logInfo(s"$timer") + + val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo)) + RandomForestModel.build(trees) + } + +} + +object RandomForest extends Serializable with Logging { + + /** + * Method to train a decision tree model for binary or multiclass classification. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels should take values {0, 1, ..., numClasses-1}. + * @param strategy Parameters for training each tree in the forest. + * @param numTrees Number of trees in the random forest. + * @param featureSubsetStrategy Number of features to consider for splits at each node. + * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + * If "auto" is set, this parameter is set based on numTrees: + * if numTrees == 1, set to "all"; + * if numTrees > 1 (forest) set to "sqrt" for classification and + * to "onethird" for regression. + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction + */ + def trainClassifier( + input: RDD[LabeledPoint], + strategy: Strategy, + numTrees: Int, + featureSubsetStrategy: String, + seed: Int): RandomForestModel = { + require(strategy.algo == Classification, + s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}") + val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed) + rf.train(input) + } + + /** + * Method to train a decision tree model for binary or multiclass classification. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels should take values {0, 1, ..., numClasses-1}. + * @param numClassesForClassification number of classes for classification. + * @param categoricalFeaturesInfo Map storing arity of categorical features. 
+ * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param numTrees Number of trees in the random forest. + * @param featureSubsetStrategy Number of features to consider for splits at each node. + * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + * If "auto" is set, this parameter is set based on numTrees: + * if numTrees == 1, set to "all"; + * if numTrees > 1 (forest) set to "sqrt" for classification and + * to "onethird" for regression. + * @param impurity Criterion used for information gain calculation. + * Supported values: "gini" (recommended) or "entropy". + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * (suggested value: 4) + * @param maxBins maximum number of bins used for splitting features + * (suggested value: 100) + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction + */ + def trainClassifier( + input: RDD[LabeledPoint], + numClassesForClassification: Int, + categoricalFeaturesInfo: Map[Int, Int], + numTrees: Int, + featureSubsetStrategy: String, + impurity: String, + maxDepth: Int, + maxBins: Int, + seed: Int = Utils.random.nextInt()): RandomForestModel = { + val impurityType = Impurities.fromString(impurity) + val strategy = new Strategy(Classification, impurityType, maxDepth, + numClassesForClassification, maxBins, Sort, categoricalFeaturesInfo) + trainClassifier(input, strategy, numTrees, featureSubsetStrategy, seed) + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.RandomForest$#trainClassifier]] + */ + def trainClassifier( + input: JavaRDD[LabeledPoint], + numClassesForClassification: Int, + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], + numTrees: Int, + featureSubsetStrategy: String, + impurity: String, + maxDepth: Int, + maxBins: Int, + seed: Int): RandomForestModel = { + trainClassifier(input.rdd, numClassesForClassification, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed) + } + + /** + * Method to train a decision tree model for regression. + * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels are real numbers. + * @param strategy Parameters for training each tree in the forest. + * @param numTrees Number of trees in the random forest. + * @param featureSubsetStrategy Number of features to consider for splits at each node. + * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + * If "auto" is set, this parameter is set based on numTrees: + * if numTrees == 1, set to "all"; + * if numTrees > 1 (forest) set to "sqrt" for classification and + * to "onethird" for regression. + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction + */ + def trainRegressor( + input: RDD[LabeledPoint], + strategy: Strategy, + numTrees: Int, + featureSubsetStrategy: String, + seed: Int): RandomForestModel = { + require(strategy.algo == Regression, + s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}") + val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed) + rf.train(input) + } + + /** + * Method to train a decision tree model for regression. 
+ * + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * Labels are real numbers. + * @param categoricalFeaturesInfo Map storing arity of categorical features. + * E.g., an entry (n -> k) indicates that feature n is categorical + * with k categories indexed from 0: {0, 1, ..., k-1}. + * @param numTrees Number of trees in the random forest. + * @param featureSubsetStrategy Number of features to consider for splits at each node. + * Supported: "auto" (default), "all", "sqrt", "log2", "onethird". + * If "auto" is set, this parameter is set based on numTrees: + * if numTrees == 1, set to "all"; + * if numTrees > 1 (forest) set to "sqrt" for classification and + * to "onethird" for regression. + * @param impurity Criterion used for information gain calculation. + * Supported values: "variance". + * @param maxDepth Maximum depth of the tree. + * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes. + * (suggested value: 4) + * @param maxBins maximum number of bins used for splitting features + * (suggested value: 100) + * @param seed Random seed for bootstrapping and choosing feature subsets. + * @return RandomForestModel that can be used for prediction + */ + def trainRegressor( + input: RDD[LabeledPoint], + categoricalFeaturesInfo: Map[Int, Int], + numTrees: Int, + featureSubsetStrategy: String, + impurity: String, + maxDepth: Int, + maxBins: Int, + seed: Int = Utils.random.nextInt()): RandomForestModel = { + val impurityType = Impurities.fromString(impurity) + val strategy = new Strategy(Regression, impurityType, maxDepth, + 0, maxBins, Sort, categoricalFeaturesInfo) + trainRegressor(input, strategy, numTrees, featureSubsetStrategy, seed) + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.RandomForest$#trainRegressor]] + */ + def trainRegressor( + input: JavaRDD[LabeledPoint], + categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer], + numTrees: Int, + featureSubsetStrategy: String, + impurity: String, + maxDepth: Int, + maxBins: Int, + seed: Int): RandomForestModel = { + trainRegressor(input.rdd, + categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap, + numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed) + } + + /** + * List of supported feature subset sampling strategies. + */ + val supportedFeatureSubsetStrategies: Array[String] = + Array("auto", "all", "sqrt", "log2", "onethird") + + private[tree] class NodeIndexInfo( + val nodeIndexInGroup: Int, + val featureSubset: Option[Array[Int]]) extends Serializable + + /** + * Pull nodes off of the queue, and collect a group of nodes to be split on this iteration. + * This tracks the memory usage for aggregates and stops adding nodes when too much memory + * will be needed; this allows an adaptive number of nodes since different nodes may require + * different amounts of memory (if featureSubsetStrategy is not "all"). + * + * @param nodeQueue Queue of nodes to split. + * @param maxMemoryUsage Bound on size of aggregate statistics. + * @return (nodesForGroup, treeToNodeToIndexInfo). + * nodesForGroup holds the nodes to split: treeIndex --> nodes in tree. + * treeToNodeToIndexInfo holds indices selected features for each node: + * treeIndex --> (global) node index --> (node index in group, feature indices). + * The (global) node index is the index in the tree; the node index in group is the + * index in [0, numNodesInGroup) of the node in this group. 
+ * The feature indices are None if not subsampling features. + */ + private[tree] def selectNodesToSplit( + nodeQueue: mutable.Queue[(Int, Node)], + maxMemoryUsage: Long, + metadata: DecisionTreeMetadata, + rng: scala.util.Random): (Map[Int, Array[Node]], Map[Int, Map[Int, NodeIndexInfo]]) = { + // Collect some nodes to split: + // nodesForGroup(treeIndex) = nodes to split + val mutableNodesForGroup = new mutable.HashMap[Int, mutable.ArrayBuffer[Node]]() + val mutableTreeToNodeToIndexInfo = + new mutable.HashMap[Int, mutable.HashMap[Int, NodeIndexInfo]]() + var memUsage: Long = 0L + var numNodesInGroup = 0 + while (nodeQueue.nonEmpty && memUsage < maxMemoryUsage) { + val (treeIndex, node) = nodeQueue.head + // Choose subset of features for node (if subsampling). + val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) { + // TODO: Use more efficient subsampling? (use selection-and-rejection or reservoir) + Some(rng.shuffle(Range(0, metadata.numFeatures).toList) + .take(metadata.numFeaturesPerNode).toArray) + } else { + None + } + // Check if enough memory remains to add this node to the group. + val nodeMemUsage = RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L + if (memUsage + nodeMemUsage <= maxMemoryUsage) { + nodeQueue.dequeue() + mutableNodesForGroup.getOrElseUpdate(treeIndex, new mutable.ArrayBuffer[Node]()) += node + mutableTreeToNodeToIndexInfo + .getOrElseUpdate(treeIndex, new mutable.HashMap[Int, NodeIndexInfo]())(node.id) + = new NodeIndexInfo(numNodesInGroup, featureSubset) + } + numNodesInGroup += 1 + memUsage += nodeMemUsage + } + // Convert mutable maps to immutable ones. + val nodesForGroup: Map[Int, Array[Node]] = mutableNodesForGroup.mapValues(_.toArray).toMap + val treeToNodeToIndexInfo = mutableTreeToNodeToIndexInfo.mapValues(_.toMap).toMap + (nodesForGroup, treeToNodeToIndexInfo) + } + + /** + * Get the number of values to be stored for this node in the bin aggregates. + * @param featureSubset Indices of features which may be split at this node. + * If None, then use all features. + */ + private[tree] def aggregateSizeForNode( + metadata: DecisionTreeMetadata, + featureSubset: Option[Array[Int]]): Long = { + val totalBins = if (featureSubset.nonEmpty) { + featureSubset.get.map(featureIndex => metadata.numBins(featureIndex).toLong).sum + } else { + metadata.numBins.map(_.toLong).sum + } + if (metadata.isClassification) { + metadata.numClasses * totalBins + } else { + 3 * totalBins + } + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala new file mode 100644 index 0000000000000..937c8a2ac5836 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import cern.jet.random.Poisson +import cern.jet.random.engine.DRand + +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils + +/** + * Internal representation of a datapoint which belongs to several subsamples of the same dataset, + * particularly for bagging (e.g., for random forests). + * + * This holds one instance, as well as an array of weights which represent the (weighted) + * number of times which this instance appears in each subsample. + * E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that + * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively. + * + * @param datum Data instance + * @param subsampleWeights Weight of this instance in each subsampled dataset. + * + * TODO: This does not currently support (Double) weighted instances. Once MLlib has weighted + * dataset support, update. (We store subsampleWeights as Double for this future extension.) + */ +private[tree] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double]) + extends Serializable + +private[tree] object BaggedPoint { + + /** + * Convert an input dataset into its BaggedPoint representation, + * choosing subsample counts for each instance. + * Each subsample has the same number of instances as the original dataset, + * and is created by subsampling with replacement. + * @param input Input dataset. + * @param numSubsamples Number of subsamples of this RDD to take. + * @param seed Random seed. + * @return BaggedPoint dataset representation + */ + def convertToBaggedRDD[Datum]( + input: RDD[Datum], + numSubsamples: Int, + seed: Int = Utils.random.nextInt()): RDD[BaggedPoint[Datum]] = { + input.mapPartitionsWithIndex { (partitionIndex, instances) => + // TODO: Support different sampling rates, and sampling without replacement. + // Use random seed = seed + partitionIndex + 1 to make generation reproducible. + val poisson = new Poisson(1.0, new DRand(seed + partitionIndex + 1)) + instances.map { instance => + val subsampleWeights = new Array[Double](numSubsamples) + var subsampleIndex = 0 + while (subsampleIndex < numSubsamples) { + subsampleWeights(subsampleIndex) = poisson.nextInt() + subsampleIndex += 1 + } + new BaggedPoint(instance, subsampleWeights) + } + } + } + + def convertToBaggedRDDWithoutSampling[Datum](input: RDD[Datum]): RDD[BaggedPoint[Datum]] = { + input.map(datum => new BaggedPoint(datum, Array(1.0))) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala index 61a94246711bf..d49df7a016375 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala @@ -17,16 +17,17 @@ package org.apache.spark.mllib.tree.impl +import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo import org.apache.spark.mllib.tree.impurity._ /** * DecisionTree statistics aggregator. * This holds a flat array of statistics for a set of (nodes, features, bins) * and helps with indexing. + * This class is abstract to support learning with and without feature subsampling. 
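BaggedPoint.convertToBaggedRDD above approximates sampling with replacement by drawing each instance's multiplicity in each subsample from a Poisson(1) distribution, so every subsample has the same expected size as the original dataset. A dependency-free sketch of the same idea, using Knuth's algorithm rather than the Colt Poisson/DRand classes the patch uses (illustrative only):

    // Sketch: Poisson(1) counts stand in for how many times an instance would appear in a
    // bootstrap subsample drawn with replacement.
    def poissonOne(rng: scala.util.Random): Int = {
      val limit = math.exp(-1.0)                 // e^{-lambda} with lambda = 1
      var k = 0
      var p = 1.0
      do { k += 1; p *= rng.nextDouble() } while (p > limit)
      k - 1
    }

    val rng = new scala.util.Random(42)
    val numSubsamples = 3
    // One weight per subsample for a single instance, e.g. Array(1.0, 0.0, 2.0).
    val subsampleWeights = Array.fill(numSubsamples)(poissonOne(rng).toDouble)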
*/ -private[tree] class DTStatsAggregator( - val metadata: DecisionTreeMetadata, - val numNodes: Int) extends Serializable { +private[tree] abstract class DTStatsAggregator( + val metadata: DecisionTreeMetadata) extends Serializable { /** * [[ImpurityAggregator]] instance specifying the impurity type. @@ -43,49 +44,21 @@ private[tree] class DTStatsAggregator( */ val statsSize: Int = impurityAggregator.statsSize - val numFeatures: Int = metadata.numFeatures - - /** - * Number of bins for each feature. This is indexed by the feature index. - */ - val numBins: Array[Int] = metadata.numBins - - /** - * Number of splits for the given feature. - */ - def numSplits(featureIndex: Int): Int = metadata.numSplits(featureIndex) - /** * Indicator for each feature of whether that feature is an unordered feature. * TODO: Is Array[Boolean] any faster? */ def isUnordered(featureIndex: Int): Boolean = metadata.isUnordered(featureIndex) - /** - * Offset for each feature for calculating indices into the [[allStats]] array. - */ - private val featureOffsets: Array[Int] = { - numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins) - } - - /** - * Number of elements for each node, corresponding to stride between nodes in [[allStats]]. - */ - private val nodeStride: Int = featureOffsets.last - /** * Total number of elements stored in this aggregator. */ - val allStatsSize: Int = numNodes * nodeStride + def allStatsSize: Int /** - * Flat array of elements. - * Index for start of stats for a (node, feature, bin) is: - * index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize - * Note: For unordered features, the left child stats have binIndex in [0, numBins(featureIndex)) - * and the right child stats in [numBins(featureIndex), 2 * numBins(featureIndex)) + * Get flat array of elements stored in this aggregator. */ - val allStats: Array[Double] = new Array[Double](allStatsSize) + protected def allStats: Array[Double] /** * Get an [[ImpurityCalculator]] for a given (node, feature, bin). @@ -102,36 +75,39 @@ private[tree] class DTStatsAggregator( /** * Update the stats for a given (node, feature, bin) for ordered features, using the given label. */ - def update(nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = { - val i = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize - impurityAggregator.update(allStats, i, label) + def update( + nodeIndex: Int, + featureIndex: Int, + binIndex: Int, + label: Double, + instanceWeight: Double): Unit = { + val i = getNodeFeatureOffset(nodeIndex, featureIndex) + binIndex * statsSize + impurityAggregator.update(allStats, i, label, instanceWeight) } /** * Pre-compute node offset for use with [[nodeUpdate]]. */ - def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride + def getNodeOffset(nodeIndex: Int): Int /** * Faster version of [[update]]. * Update the stats for a given (node, feature, bin) for ordered features, using the given label. * @param nodeOffset Pre-computed node offset from [[getNodeOffset]]. */ - def nodeUpdate(nodeOffset: Int, featureIndex: Int, binIndex: Int, label: Double): Unit = { - val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize - impurityAggregator.update(allStats, i, label) - } + def nodeUpdate( + nodeOffset: Int, + nodeIndex: Int, + featureIndex: Int, + binIndex: Int, + label: Double, + instanceWeight: Double): Unit /** * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]]. * For ordered features only. 
*/ - def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = { - require(!isUnordered(featureIndex), - s"DTStatsAggregator.getNodeFeatureOffset is for ordered features only, but was called" + - s" for unordered feature $featureIndex.") - nodeIndex * nodeStride + featureOffsets(featureIndex) - } + def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int /** * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]]. @@ -140,9 +116,9 @@ private[tree] class DTStatsAggregator( def getLeftRightNodeFeatureOffsets(nodeIndex: Int, featureIndex: Int): (Int, Int) = { require(isUnordered(featureIndex), s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," + - s" but was called for ordered feature $featureIndex.") - val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex) - (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize) + s" but was called for ordered feature $featureIndex.") + val baseOffset = getNodeFeatureOffset(nodeIndex, featureIndex) + (baseOffset, baseOffset + (metadata.numBins(featureIndex) >> 1) * statsSize) } /** @@ -154,8 +130,13 @@ private[tree] class DTStatsAggregator( * (node, feature, left/right child) offset from * [[getLeftRightNodeFeatureOffsets]]. */ - def nodeFeatureUpdate(nodeFeatureOffset: Int, binIndex: Int, label: Double): Unit = { - impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label) + def nodeFeatureUpdate( + nodeFeatureOffset: Int, + binIndex: Int, + label: Double, + instanceWeight: Double): Unit = { + impurityAggregator.update(allStats, nodeFeatureOffset + binIndex * statsSize, label, + instanceWeight) } /** @@ -189,7 +170,139 @@ private[tree] class DTStatsAggregator( } this } +} + +/** + * DecisionTree statistics aggregator. + * This holds a flat array of statistics for a set of (nodes, features, bins) + * and helps with indexing. + * + * This instance of [[DTStatsAggregator]] is used when not subsampling features. + * + * @param numNodes Number of nodes to collect statistics for. + */ +private[tree] class DTStatsAggregatorFixedFeatures( + metadata: DecisionTreeMetadata, + numNodes: Int) extends DTStatsAggregator(metadata) { + + /** + * Offset for each feature for calculating indices into the [[allStats]] array. + * Mapping: featureIndex --> offset + */ + private val featureOffsets: Array[Int] = { + metadata.numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins) + } + + /** + * Number of elements for each node, corresponding to stride between nodes in [[allStats]]. + */ + private val nodeStride: Int = featureOffsets.last + override val allStatsSize: Int = numNodes * nodeStride + + /** + * Flat array of elements. + * Index for start of stats for a (node, feature, bin) is: + * index = nodeIndex * nodeStride + featureOffsets(featureIndex) + binIndex * statsSize + * Note: For unordered features, the left child stats precede the right child stats + * in the binIndex order. 
+ */ + override protected val allStats: Array[Double] = new Array[Double](allStatsSize) + + override def getNodeOffset(nodeIndex: Int): Int = nodeIndex * nodeStride + + override def nodeUpdate( + nodeOffset: Int, + nodeIndex: Int, + featureIndex: Int, + binIndex: Int, + label: Double, + instanceWeight: Double): Unit = { + val i = nodeOffset + featureOffsets(featureIndex) + binIndex * statsSize + impurityAggregator.update(allStats, i, label, instanceWeight) + } + + override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = { + nodeIndex * nodeStride + featureOffsets(featureIndex) + } +} + +/** + * DecisionTree statistics aggregator. + * This holds a flat array of statistics for a set of (nodes, features, bins) + * and helps with indexing. + * + * This instance of [[DTStatsAggregator]] is used when subsampling features. + * + * @param treeToNodeToIndexInfo Mapping: treeIndex --> nodeIndex --> nodeIndexInfo, + * where nodeIndexInfo stores the index in the group and the + * feature subsets (if using feature subsets). + */ +private[tree] class DTStatsAggregatorSubsampledFeatures( + metadata: DecisionTreeMetadata, + treeToNodeToIndexInfo: Map[Int, Map[Int, NodeIndexInfo]]) extends DTStatsAggregator(metadata) { + + /** + * For each node, offset for each feature for calculating indices into the [[allStats]] array. + * Mapping: nodeIndex --> featureIndex --> offset + */ + private val featureOffsets: Array[Array[Int]] = { + val numNodes: Int = treeToNodeToIndexInfo.values.map(_.size).sum + val offsets = new Array[Array[Int]](numNodes) + treeToNodeToIndexInfo.foreach { case (treeIndex, nodeToIndexInfo) => + nodeToIndexInfo.foreach { case (globalNodeIndex, nodeInfo) => + offsets(nodeInfo.nodeIndexInGroup) = nodeInfo.featureSubset.get.map(metadata.numBins(_)) + .scanLeft(0)((total, nBins) => total + statsSize * nBins) + } + } + offsets + } + + /** + * For each node, offset for each feature for calculating indices into the [[allStats]] array. + */ + protected val nodeOffsets: Array[Int] = featureOffsets.map(_.last).scanLeft(0)(_ + _) + + override val allStatsSize: Int = nodeOffsets.last + + /** + * Flat array of elements. + * Index for start of stats for a (node, feature, bin) is: + * index = nodeOffsets(nodeIndex) + featureOffsets(featureIndex) + binIndex * statsSize + * Note: For unordered features, the left child stats precede the right child stats + * in the binIndex order. + */ + override protected val allStats: Array[Double] = new Array[Double](allStatsSize) + + override def getNodeOffset(nodeIndex: Int): Int = nodeOffsets(nodeIndex) + + /** + * Faster version of [[update]]. + * Update the stats for a given (node, feature, bin) for ordered features, using the given label. + * @param nodeOffset Pre-computed node offset from [[getNodeOffset]]. + * @param featureIndex Index of feature in featuresForNodes(nodeIndex). + * Note: This is NOT the original feature index. + */ + override def nodeUpdate( + nodeOffset: Int, + nodeIndex: Int, + featureIndex: Int, + binIndex: Int, + label: Double, + instanceWeight: Double): Unit = { + val i = nodeOffset + featureOffsets(nodeIndex)(featureIndex) + binIndex * statsSize + impurityAggregator.update(allStats, i, label, instanceWeight) + } + + /** + * Pre-compute (node, feature) offset for use with [[nodeFeatureUpdate]]. + * For ordered features only. + * @param featureIndex Index of feature in featuresForNodes(nodeIndex). + * Note: This is NOT the original feature index. 
+ */ + override def getNodeFeatureOffset(nodeIndex: Int, featureIndex: Int): Int = { + nodeOffsets(nodeIndex) + featureOffsets(nodeIndex)(featureIndex) + } } private[tree] object DTStatsAggregator extends Serializable { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index b6d49e5555b1a..212dce25236e0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -48,7 +48,9 @@ private[tree] class DecisionTreeMetadata( val quantileStrategy: QuantileStrategy, val maxDepth: Int, val minInstancesPerNode: Int, - val minInfoGain: Double) extends Serializable { + val minInfoGain: Double, + val numTrees: Int, + val numFeaturesPerNode: Int) extends Serializable { def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex) @@ -73,6 +75,11 @@ private[tree] class DecisionTreeMetadata( numBins(featureIndex) - 1 } + /** + * Indicates if feature subsampling is being used. + */ + def subsamplingFeatures: Boolean = numFeatures != numFeaturesPerNode + } private[tree] object DecisionTreeMetadata { @@ -82,7 +89,11 @@ private[tree] object DecisionTreeMetadata { * This computes which categorical features will be ordered vs. unordered, * as well as the number of splits and bins for each feature. */ - def buildMetadata(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeMetadata = { + def buildMetadata( + input: RDD[LabeledPoint], + strategy: Strategy, + numTrees: Int, + featureSubsetStrategy: String): DecisionTreeMetadata = { val numFeatures = input.take(1)(0).features.size val numExamples = input.count() @@ -128,13 +139,43 @@ private[tree] object DecisionTreeMetadata { } } + // Set number of features to use per node (for random forests). + val _featureSubsetStrategy = featureSubsetStrategy match { + case "auto" => + if (numTrees == 1) { + "all" + } else { + if (strategy.algo == Classification) { + "sqrt" + } else { + "onethird" + } + } + case _ => featureSubsetStrategy + } + val numFeaturesPerNode: Int = _featureSubsetStrategy match { + case "all" => numFeatures + case "sqrt" => math.sqrt(numFeatures).ceil.toInt + case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt) + case "onethird" => (numFeatures / 3.0).ceil.toInt + } + new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth, - strategy.minInstancesPerNode, strategy.minInfoGain) + strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode) } /** + * Version of [[buildMetadata()]] for DecisionTree. + */ + def buildMetadata( + input: RDD[LabeledPoint], + strategy: Strategy): DecisionTreeMetadata = { + buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all") + } + + /** * Given the arity of a categorical feature (arity = number of categories), * return the number of bins for the feature if it is to be treated as an unordered feature. 
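The two match expressions above fully determine how many features are sampled per node. Pulled out into a standalone helper (hypothetical name, same case analysis as the patch):

    // Sketch: resolve featureSubsetStrategy into a per-node feature count.
    def numFeaturesPerNode(strategy: String, numFeatures: Int, numTrees: Int,
                           isClassification: Boolean): Int = {
      val resolved = strategy match {
        case "auto" =>
          if (numTrees == 1) "all"
          else if (isClassification) "sqrt"
          else "onethird"
        case other => other
      }
      resolved match {
        case "all" => numFeatures
        case "sqrt" => math.sqrt(numFeatures).ceil.toInt
        case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
        case "onethird" => (numFeatures / 3.0).ceil.toInt
      }
    }

    // e.g. numFeaturesPerNode("auto", numFeatures = 100, numTrees = 10, isClassification = true) == 10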
* There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets; diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 1c8afc2d0f4bc..0e02345aa3774 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -89,12 +89,12 @@ private[tree] class EntropyAggregator(numClasses: Int) * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. * @param offset Start index of stats for this (node, feature, bin). */ - def update(allStats: Array[Double], offset: Int, label: Double): Unit = { + def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = { if (label >= statsSize) { throw new IllegalArgumentException(s"EntropyAggregator given label $label" + s" but requires label < numClasses (= $statsSize).") } - allStats(offset + label.toInt) += 1 + allStats(offset + label.toInt) += instanceWeight } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 5cfdf345d163c..7c83cd48e16a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -85,12 +85,12 @@ private[tree] class GiniAggregator(numClasses: Int) * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. * @param offset Start index of stats for this (node, feature, bin). */ - def update(allStats: Array[Double], offset: Int, label: Double): Unit = { + def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = { if (label >= statsSize) { throw new IllegalArgumentException(s"GiniAggregator given label $label" + s" but requires label < numClasses (= $statsSize).") } - allStats(offset + label.toInt) += 1 + allStats(offset + label.toInt) += instanceWeight } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 5a047d6cb5480..60e2ab2bb829e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -78,7 +78,7 @@ private[tree] abstract class ImpurityAggregator(val statsSize: Int) extends Seri * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. * @param offset Start index of stats for this (node, feature, bin). */ - def update(allStats: Array[Double], offset: Int, label: Double): Unit + def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit /** * Get an [[ImpurityCalculator]] for a (node, feature, bin). diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index e9ccecb1b8067..df9eafa5da16a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -75,10 +75,10 @@ private[tree] class VarianceAggregator() * @param allStats Flat stats array, with stats for this (node, feature, bin) contiguous. * @param offset Start index of stats for this (node, feature, bin). 
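These aggregator changes all have the same shape: instead of incrementing a counter by one, each update adds the instance's bagging weight. For the Variance aggregator just below, the three slots are the weighted count, weighted sum, and weighted sum of squares, from which the weighted variance can be recovered. A sketch, with toy values:

    // Sketch: weighted sufficient statistics for variance (count, sum, sum of squares).
    def updateVarianceStats(stats: Array[Double], label: Double, instanceWeight: Double): Unit = {
      stats(0) += instanceWeight
      stats(1) += instanceWeight * label
      stats(2) += instanceWeight * label * label
    }

    def varianceOf(stats: Array[Double]): Double = {
      val Array(count, sum, sumSq) = stats
      if (count == 0.0) 0.0 else sumSq / count - (sum / count) * (sum / count)
    }

    val stats = Array(0.0, 0.0, 0.0)
    Seq((1.0, 1.0), (3.0, 2.0)).foreach { case (label, w) => updateVarianceStats(stats, label, w) }
    // varianceOf(stats) == 19.0 / 3 - (7.0 / 3) * (7.0 / 3), roughly 0.889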
*/ - def update(allStats: Array[Double], offset: Int, label: Double): Unit = { - allStats(offset) += 1 - allStats(offset + 1) += label - allStats(offset + 2) += label * label + def update(allStats: Array[Double], offset: Int, label: Double, instanceWeight: Double): Unit = { + allStats(offset) += instanceWeight + allStats(offset + 1) += instanceWeight * label + allStats(offset + 2) += instanceWeight * label * label } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index 5f0095d23c7ed..56c3e25d9285f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -41,12 +41,12 @@ import org.apache.spark.mllib.linalg.Vector @DeveloperApi class Node ( val id: Int, - val predict: Double, - val isLeaf: Boolean, - val split: Option[Split], + var predict: Double, + var isLeaf: Boolean, + var split: Option[Split], var leftNode: Option[Node], var rightNode: Option[Node], - val stats: Option[InformationGainStats]) extends Serializable with Logging { + var stats: Option[InformationGainStats]) extends Serializable with Logging { override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " + "split = " + split + ", stats = " + stats @@ -167,6 +167,11 @@ class Node ( private[tree] object Node { + /** + * Return a node with the given node id (but nothing else set). + */ + def emptyNode(nodeIndex: Int): Node = new Node(nodeIndex, 0, false, None, None, None, None) + /** * Return the index of the left child of this node. */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala new file mode 100644 index 0000000000000..538c0e233202a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.model + +import scala.collection.mutable + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * Random forest model for classification or regression. + * This model stores a collection of [[DecisionTreeModel]] instances and uses them to make + * aggregate predictions. + * @param trees Trees which make up this forest. This cannot be empty. 
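The predict method defined just below combines the individual trees in the usual ensemble way: a majority vote over predicted classes for classification, and the mean of predicted values for regression. In miniature, on a fixed set of per-tree predictions:

    // Sketch: combining per-tree predictions (toy values, no Spark, no DecisionTreeModel).
    val treePredictions = Seq(1.0, 0.0, 1.0, 1.0)
    // Classification: majority vote over the trees' predicted classes.
    val votedClass = treePredictions.groupBy(_.toInt).maxBy(_._2.size)._1      // 1
    // Regression: mean of the trees' predicted values.
    val averagedValue = treePredictions.sum / treePredictions.size             // 0.75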
+ * @param algo algorithm type -- classification or regression + */ +@Experimental +class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) extends Serializable { + + require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.") + + /** + * Predict values for a single data point. + * + * @param features array representing a single data point + * @return Double prediction from the trained model + */ + def predict(features: Vector): Double = { + algo match { + case Classification => + val predictionToCount = new mutable.HashMap[Int, Int]() + trees.foreach { tree => + val prediction = tree.predict(features).toInt + predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1 + } + predictionToCount.maxBy(_._2)._1 + case Regression => + trees.map(_.predict(features)).sum / trees.size + } + } + + /** + * Predict values for the given data set. + * + * @param features RDD representing data points to be predicted + * @return RDD[Double] where each entry contains the corresponding prediction + */ + def predict(features: RDD[Vector]): RDD[Double] = { + features.map(x => predict(x)) + } + + /** + * Get number of trees in forest. + */ + def numTrees: Int = trees.size + + /** + * Print full model. + */ + override def toString: String = { + val header = algo match { + case Classification => + s"RandomForestModel classifier with $numTrees trees\n" + case Regression => + s"RandomForestModel regressor with $numTrees trees\n" + case _ => throw new IllegalArgumentException( + s"RandomForestModel given unknown algo parameter: $algo.") + } + header + trees.zipWithIndex.map { case (tree, treeIndex) => + s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4) + }.fold("")(_ + _) + } + +} + +private[tree] object RandomForestModel { + + def build(trees: Array[DecisionTreeModel]): RandomForestModel = { + require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.") + val algo: Algo = trees(0).algo + require(trees.forall(_.algo == algo), + "RandomForestModel cannot combine trees which have different output types" + + " (classification/regression).") + new RandomForestModel(trees, algo) + } + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 2b2e579b992f6..a48ed71a1c5fc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ +import scala.collection.mutable import org.scalatest.FunSuite @@ -26,39 +27,13 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint} +import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node} import org.apache.spark.mllib.util.LocalSparkContext class DecisionTreeSuite extends FunSuite with LocalSparkContext { - def validateClassifier( - model: DecisionTreeModel, - input: Seq[LabeledPoint], - requiredAccuracy: Double) { - val predictions = 
input.map(x => model.predict(x.features)) - val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => - prediction != expected.label - } - val accuracy = (input.length - numOffPredictions).toDouble / input.length - assert(accuracy >= requiredAccuracy, - s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") - } - - def validateRegressor( - model: DecisionTreeModel, - input: Seq[LabeledPoint], - requiredMSE: Double) { - val predictions = input.map(x => model.predict(x.features)) - val squaredError = predictions.zip(input).map { case (prediction, expected) => - val err = prediction - expected.label - err * err - }.sum - val mse = squaredError / input.length - assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") - } - test("Binary classification with continuous features: split and bin calculation") { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) @@ -233,7 +208,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 100, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 10, 1-> 10)) - // 2^10 - 1 > 100, so categorical features will be ordered + // 2^(10-1) - 1 > 100, so categorical features will be ordered val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) assert(!metadata.isUnordered(featureIndex = 0)) @@ -269,9 +244,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(splits(0).length === 0) assert(bins(0).length === 0) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode: Node, doneTraining: Boolean) = - DecisionTree.findBestSplits(treeInput, metadata, 0, null, splits, bins, 10) + val rootNode = DecisionTree.train(rdd, strategy).topNode val split = rootNode.split.get assert(split.categories === List(1.0)) @@ -299,10 +272,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + val rootNode = DecisionTree.train(rdd, strategy).topNode val split = rootNode.split.get assert(split.categories.length === 1) @@ -331,7 +301,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(!metadata.isUnordered(featureIndex = 1)) val model = DecisionTree.train(rdd, strategy) - validateRegressor(model, arr, 0.0) + DecisionTreeSuite.validateRegressor(model, arr, 0.0) assert(model.numNodes === 3) assert(model.depth === 1) } @@ -352,12 +322,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins.length === 2) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) - - val split = rootNode.split.get - assert(split.feature === 0) + val rootNode = DecisionTree.train(rdd, strategy).topNode val stats = rootNode.stats.get assert(stats.gain === 0) @@ -381,12 +346,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins.length === 2) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = 
DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) - - val split = rootNode.split.get - assert(split.feature === 0) + val rootNode = DecisionTree.train(rdd, strategy).topNode val stats = rootNode.stats.get assert(stats.gain === 0) @@ -411,12 +371,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins.length === 2) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) - - val split = rootNode.split.get - assert(split.feature === 0) + val rootNode = DecisionTree.train(rdd, strategy).topNode val stats = rootNode.stats.get assert(stats.gain === 0) @@ -441,12 +396,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins.length === 2) assert(bins(0).length === 100) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) - - val split = rootNode.split.get - assert(split.feature === 0) + val rootNode = DecisionTree.train(rdd, strategy).topNode val stats = rootNode.stats.get assert(stats.gain === 0) @@ -471,25 +421,39 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1, numClassesForClassification = 2, maxBins = 100) val modelOneNode = DecisionTree.train(rdd, strategyOneNode) - val rootNodeCopy1 = modelOneNode.topNode.deepCopy() - val rootNodeCopy2 = modelOneNode.topNode.deepCopy() + val rootNode1 = modelOneNode.topNode.deepCopy() + val rootNode2 = modelOneNode.topNode.deepCopy() + assert(rootNode1.leftNode.nonEmpty) + assert(rootNode1.rightNode.nonEmpty) - // Single group second level tree construction. val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, _) = DecisionTree.findBestSplits(treeInput, metadata, 1, - rootNodeCopy1, splits, bins, 10) - assert(rootNode.leftNode.nonEmpty) - assert(rootNode.rightNode.nonEmpty) + val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput) + + // Single group second level tree construction. + val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get))) + val treeToNodeToIndexInfo = Map((0, Map( + (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)), + (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None))))) + val nodeQueue = new mutable.Queue[(Int, Node)]() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1), + nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue) val children1 = new Array[Node](2) - children1(0) = rootNode.leftNode.get - children1(1) = rootNode.rightNode.get - - // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second - // level tree construction. - val (rootNode2, _) = DecisionTree.findBestSplits(treeInput, metadata, 1, - rootNodeCopy2, splits, bins, 0) - assert(rootNode2.leftNode.nonEmpty) - assert(rootNode2.rightNode.nonEmpty) + children1(0) = rootNode1.leftNode.get + children1(1) = rootNode1.rightNode.get + + // Train one second-level node at a time. 
+ val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get))) + val treeToNodeToIndexInfoA = Map((0, Map( + (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) + nodeQueue.clear() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), + nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue) + val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get))) + val treeToNodeToIndexInfoB = Map((0, Map( + (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None))))) + nodeQueue.clear() + DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2), + nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue) val children2 = new Array[Node](2) children2(0) = rootNode2.leftNode.get children2(1) = rootNode2.rightNode.get @@ -521,10 +485,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(metadata.isUnordered(featureIndex = 0)) assert(metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + val rootNode = DecisionTree.train(rdd, strategy).topNode val split = rootNode.split.get assert(split.feature === 0) @@ -544,7 +505,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2) val model = DecisionTree.train(rdd, strategy) - validateClassifier(model, arr, 1.0) + DecisionTreeSuite.validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) } @@ -561,7 +522,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { numClassesForClassification = 2) val model = DecisionTree.train(rdd, strategy) - validateClassifier(model, arr, 1.0) + DecisionTreeSuite.validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) assert(model.topNode.split.get.feature === 1) @@ -581,14 +542,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(metadata.isUnordered(featureIndex = 1)) val model = DecisionTree.train(rdd, strategy) - validateClassifier(model, arr, 1.0) + DecisionTreeSuite.validateClassifier(model, arr, 1.0) assert(model.numNodes === 3) assert(model.depth === 1) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + val rootNode = model.topNode val split = rootNode.split.get assert(split.feature === 0) @@ -610,12 +568,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy) val model = DecisionTree.train(rdd, strategy) - validateClassifier(model, arr, 0.9) + DecisionTreeSuite.validateClassifier(model, arr, 0.9) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + val rootNode = model.topNode val split = rootNode.split.get assert(split.feature === 1) @@ -635,12 +590,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(metadata.isUnordered(featureIndex = 0)) val model = DecisionTree.train(rdd, strategy) - validateClassifier(model, arr, 0.9) + 
DecisionTreeSuite.validateClassifier(model, arr, 0.9) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + val rootNode = model.topNode val split = rootNode.split.get assert(split.feature === 1) @@ -660,10 +612,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(!metadata.isUnordered(featureIndex = 0)) assert(!metadata.isUnordered(featureIndex = 1)) - val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata) - val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + val rootNode = DecisionTree.train(rdd, strategy).topNode val split = rootNode.split.get assert(split.feature === 0) @@ -682,7 +631,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(strategy.isMulticlassClassification) val model = DecisionTree.train(rdd, strategy) - validateClassifier(model, arr, 0.6) + DecisionTreeSuite.validateClassifier(model, arr, 0.6) } test("split must satisfy min instances per node requirements") { @@ -691,24 +640,20 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))) arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))) - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, numClassesForClassification = 2, minInstancesPerNode = 2) - val model = DecisionTree.train(input, strategy) + val model = DecisionTree.train(rdd, strategy) assert(model.topNode.isLeaf) assert(model.topNode.predict == 0.0) - val predicts = input.map(p => model.predict(p.features)).collect() + val predicts = rdd.map(p => model.predict(p.features)).collect() predicts.foreach { predict => assert(predict == 0.0) } - // test for findBestSplits when no valid split can be found - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + // test when no valid split can be found + val rootNode = model.topNode val gain = rootNode.stats.get assert(gain == InformationGainStats.invalidInformationGainStats) @@ -723,15 +668,12 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)) - val input = sc.parallelize(arr) + val rdd = sc.parallelize(arr) val strategy = new Strategy(algo = Classification, impurity = Gini, maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2), numClassesForClassification = 2, minInstancesPerNode = 2) - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + + val rootNode = DecisionTree.train(rdd, strategy).topNode val split = rootNode.split.get val gain = rootNode.stats.get @@ -757,12 +699,8 @@ class DecisionTreeSuite extends 
FunSuite with LocalSparkContext { assert(predict == 0.0) } - // test for findBestSplits when no valid split can be found - val metadata = DecisionTreeMetadata.buildMetadata(input, strategy) - val (splits, bins) = DecisionTree.findSplitsBins(input, metadata) - val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata) - val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0, - null, splits, bins, 10) + // test when no valid split can be found + val rootNode = model.topNode val gain = rootNode.stats.get assert(gain == InformationGainStats.invalidInformationGainStats) @@ -771,6 +709,32 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { object DecisionTreeSuite { + def validateClassifier( + model: DecisionTreeModel, + input: Seq[LabeledPoint], + requiredAccuracy: Double) { + val predictions = input.map(x => model.predict(x.features)) + val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => + prediction != expected.label + } + val accuracy = (input.length - numOffPredictions).toDouble / input.length + assert(accuracy >= requiredAccuracy, + s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") + } + + def validateRegressor( + model: DecisionTreeModel, + input: Seq[LabeledPoint], + requiredMSE: Double) { + val predictions = input.map(x => model.predict(x.features)) + val squaredError = predictions.zip(input).map { case (prediction, expected) => + val err = prediction - expected.label + err * err + }.sum + val mse = squaredError / input.length + assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") + } + def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = { val arr = new Array[LabeledPoint](1000) for (i <- 0 until 1000) { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala new file mode 100644 index 0000000000000..30669fcd1c75b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.tree + +import scala.collection.mutable + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata} +import org.apache.spark.mllib.tree.impurity.{Gini, Variance} +import org.apache.spark.mllib.tree.model.{Node, RandomForestModel} +import org.apache.spark.mllib.util.LocalSparkContext +import org.apache.spark.util.StatCounter + +/** + * Test suite for [[RandomForest]]. + */ +class RandomForestSuite extends FunSuite with LocalSparkContext { + + test("BaggedPoint RDD: without subsampling") { + val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1) + val rdd = sc.parallelize(arr) + val baggedRDD = BaggedPoint.convertToBaggedRDDWithoutSampling(rdd) + baggedRDD.collect().foreach { baggedPoint => + assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1) + } + } + + test("BaggedPoint RDD: with subsampling") { + val numSubsamples = 100 + val (expectedMean, expectedStddev) = (1.0, 1.0) + + val seeds = Array(123, 5354, 230, 349867, 23987) + val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1) + val rdd = sc.parallelize(arr) + seeds.foreach { seed => + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, numSubsamples, seed = seed) + val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + RandomForestSuite.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, + expectedStddev, epsilon = 0.01) + } + } + + test("Binary classification with continuous features:" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + + val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50) + val rdd = sc.parallelize(arr) + val categoricalFeaturesInfo = Map.empty[Int, Int] + val numTrees = 1 + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + + val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees, + featureSubsetStrategy = "auto", seed = 123) + assert(rf.trees.size === 1) + val rfTree = rf.trees(0) + + val dt = DecisionTree.train(rdd, strategy) + + RandomForestSuite.validateClassifier(rf, arr, 0.9) + DecisionTreeSuite.validateClassifier(dt, arr, 0.9) + + // Make sure trees are the same. + assert(rfTree.toString == dt.toString) + } + + test("Regression with continuous features:" + + " comparing DecisionTree vs. RandomForest(numTrees = 1)") { + + val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50) + val rdd = sc.parallelize(arr) + val categoricalFeaturesInfo = Map.empty[Int, Int] + val numTrees = 1 + + val strategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + + val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees, + featureSubsetStrategy = "auto", seed = 123) + assert(rf.trees.size === 1) + val rfTree = rf.trees(0) + + val dt = DecisionTree.train(rdd, strategy) + + RandomForestSuite.validateRegressor(rf, arr, 0.01) + DecisionTreeSuite.validateRegressor(dt, arr, 0.01) + + // Make sure trees are the same. 
+ assert(rfTree.toString == dt.toString) + } + + test("Binary classification with continuous features: subsampling features") { + val numFeatures = 50 + val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures) + val rdd = sc.parallelize(arr) + val categoricalFeaturesInfo = Map.empty[Int, Int] + + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2, + numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo) + + // Select feature subset for top nodes. Return true if OK. + def checkFeatureSubsetStrategy( + numTrees: Int, + featureSubsetStrategy: String, + numFeaturesPerNode: Int): Unit = { + val seeds = Array(123, 5354, 230, 349867, 23987) + val maxMemoryUsage: Long = 128 * 1024L * 1024L + val metadata = + DecisionTreeMetadata.buildMetadata(rdd, strategy, numTrees, featureSubsetStrategy) + seeds.foreach { seed => + val failString = s"Failed on test with:" + + s"numTrees=$numTrees, featureSubsetStrategy=$featureSubsetStrategy," + + s" numFeaturesPerNode=$numFeaturesPerNode, seed=$seed" + val nodeQueue = new mutable.Queue[(Int, Node)]() + val topNodes: Array[Node] = new Array[Node](numTrees) + Range(0, numTrees).foreach { treeIndex => + topNodes(treeIndex) = Node.emptyNode(nodeIndex = 1) + nodeQueue.enqueue((treeIndex, topNodes(treeIndex))) + } + val rng = new scala.util.Random(seed = seed) + val (nodesForGroup: Map[Int, Array[Node]], + treeToNodeToIndexInfo: Map[Int, Map[Int, RandomForest.NodeIndexInfo]]) = + RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng) + + assert(nodesForGroup.size === numTrees, failString) + assert(nodesForGroup.values.forall(_.size == 1), failString) // 1 node per tree + if (numFeaturesPerNode == numFeatures) { + // featureSubset values should all be None + assert(treeToNodeToIndexInfo.values.forall(_.values.forall(_.featureSubset.isEmpty)), + failString) + } else { + // Check number of features. + assert(treeToNodeToIndexInfo.values.forall(_.values.forall( + _.featureSubset.get.size === numFeaturesPerNode)), failString) + } + } + } + + checkFeatureSubsetStrategy(numTrees = 1, "auto", numFeatures) + checkFeatureSubsetStrategy(numTrees = 1, "all", numFeatures) + checkFeatureSubsetStrategy(numTrees = 1, "sqrt", math.sqrt(numFeatures).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 1, "log2", + (math.log(numFeatures) / math.log(2)).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 1, "onethird", (numFeatures / 3.0).ceil.toInt) + + checkFeatureSubsetStrategy(numTrees = 2, "all", numFeatures) + checkFeatureSubsetStrategy(numTrees = 2, "auto", math.sqrt(numFeatures).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 2, "sqrt", math.sqrt(numFeatures).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 2, "log2", + (math.log(numFeatures) / math.log(2)).ceil.toInt) + checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt) + } + +} + +object RandomForestSuite { + + /** + * Aggregates all values in data, and tests whether the empirical mean and stddev are within + * epsilon of the expected values. + * @param data Every element of the data should be an i.i.d. sample from some distribution. 
+ */ + def testRandomArrays( + data: Array[Array[Double]], + numCols: Int, + expectedMean: Double, + expectedStddev: Double, + epsilon: Double) { + val values = new mutable.ArrayBuffer[Double]() + data.foreach { row => + assert(row.size == numCols) + values ++= row + } + val stats = new StatCounter(values) + assert(math.abs(stats.mean - expectedMean) < epsilon) + assert(math.abs(stats.stdev - expectedStddev) < epsilon) + } + + def validateClassifier( + model: RandomForestModel, + input: Seq[LabeledPoint], + requiredAccuracy: Double) { + val predictions = input.map(x => model.predict(x.features)) + val numOffPredictions = predictions.zip(input).count { case (prediction, expected) => + prediction != expected.label + } + val accuracy = (input.length - numOffPredictions).toDouble / input.length + assert(accuracy >= requiredAccuracy, + s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.") + } + + def validateRegressor( + model: RandomForestModel, + input: Seq[LabeledPoint], + requiredMSE: Double) { + val predictions = input.map(x => model.predict(x.features)) + val squaredError = predictions.zip(input).map { case (prediction, expected) => + val err = prediction - expected.label + err * err + }.sum + val mse = squaredError / input.length + assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.") + } + + def generateOrderedLabeledPoints(numFeatures: Int): Array[LabeledPoint] = { + val numInstances = 1000 + val arr = new Array[LabeledPoint](numInstances) + for (i <- 0 until numInstances) { + val label = if (i < numInstances / 10) { + 0.0 + } else if (i < numInstances / 2) { + 1.0 + } else if (i < numInstances * 0.9) { + 0.0 + } else { + 1.0 + } + val features = Array.fill[Double](numFeatures)(i.toDouble) + arr(i) = new LabeledPoint(label, Vectors.dense(features)) + } + arr + } + +} From 1651cc117d73f0af6ec9f55b0c6c9b2bd565906c Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Sun, 28 Sep 2014 21:55:09 -0700 Subject: [PATCH 38/50] [EC2] Cleanup Python parens and disk dict Minor fixes: * Remove unnecessary parens (Python style) * Sort `disks_by_instance` dict and remove duplicate `t1.micro` key Author: Nicholas Chammas Closes #2571 from nchammas/ec2-polish and squashes the following commits: 9d203d5 [Nicholas Chammas] paren and dict cleanup --- ec2/spark_ec2.py | 60 ++++++++++++++++++++++++------------------------ 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 7f2cd7d94de39..5776d0b519309 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -508,7 +508,7 @@ def tag_instance(instance, name): break except: print "Failed attempt %i of 5 to tag %s" % ((i + 1), name) - if (i == 5): + if i == 5: raise "Error - failed max attempts to add name tag" time.sleep(5) @@ -530,7 +530,7 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): for res in reservations: active = [i for i in res.instances if is_active(i)] for instance in active: - if (instance.tags.get(u'Name') is None): + if instance.tags.get(u'Name') is None: tag_instance(instance, name) # Now proceed to detect master and slaves instances. 
reservations = conn.get_all_instances() @@ -545,7 +545,7 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): elif name.startswith(cluster_name + "-slave"): slave_nodes.append(inst) if any((master_nodes, slave_nodes)): - print ("Found %d master(s), %d slaves" % (len(master_nodes), len(slave_nodes))) + print "Found %d master(s), %d slaves" % (len(master_nodes), len(slave_nodes)) if master_nodes != [] or not die_on_error: return (master_nodes, slave_nodes) else: @@ -626,43 +626,43 @@ def wait_for_cluster(conn, wait_secs, master_nodes, slave_nodes): def get_num_disks(instance_type): # From http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html # Updated 2014-6-20 + # For easy maintainability, please keep this manually-inputted dictionary sorted by key. disks_by_instance = { - "m1.small": 1, - "m1.medium": 1, - "m1.large": 2, - "m1.xlarge": 4, - "t1.micro": 1, "c1.medium": 1, "c1.xlarge": 4, - "m2.xlarge": 1, - "m2.2xlarge": 1, - "m2.4xlarge": 2, + "c3.2xlarge": 2, + "c3.4xlarge": 2, + "c3.8xlarge": 2, + "c3.large": 2, + "c3.xlarge": 2, "cc1.4xlarge": 2, "cc2.8xlarge": 4, "cg1.4xlarge": 2, - "hs1.8xlarge": 24, "cr1.8xlarge": 2, + "g2.2xlarge": 1, "hi1.4xlarge": 2, - "m3.medium": 1, - "m3.large": 1, - "m3.xlarge": 2, - "m3.2xlarge": 2, - "i2.xlarge": 1, + "hs1.8xlarge": 24, "i2.2xlarge": 2, "i2.4xlarge": 4, "i2.8xlarge": 8, - "c3.large": 2, - "c3.xlarge": 2, - "c3.2xlarge": 2, - "c3.4xlarge": 2, - "c3.8xlarge": 2, - "r3.large": 1, - "r3.xlarge": 1, + "i2.xlarge": 1, + "m1.large": 2, + "m1.medium": 1, + "m1.small": 1, + "m1.xlarge": 4, + "m2.2xlarge": 1, + "m2.4xlarge": 2, + "m2.xlarge": 1, + "m3.2xlarge": 2, + "m3.large": 1, + "m3.medium": 1, + "m3.xlarge": 2, "r3.2xlarge": 1, "r3.4xlarge": 1, "r3.8xlarge": 2, - "g2.2xlarge": 1, - "t1.micro": 0 + "r3.large": 1, + "r3.xlarge": 1, + "t1.micro": 0, } if instance_type in disks_by_instance: return disks_by_instance[instance_type] @@ -785,7 +785,7 @@ def ssh(host, opts, command): ssh_command(opts) + ['-t', '-t', '%s@%s' % (opts.user, host), stringify_command(command)]) except subprocess.CalledProcessError as e: - if (tries > 5): + if tries > 5: # If this was an ssh failure, provide the user with hints. 
if e.returncode == 255: raise UsageError( @@ -820,18 +820,18 @@ def ssh_read(host, opts, command): ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)]) -def ssh_write(host, opts, command, input): +def ssh_write(host, opts, command, arguments): tries = 0 while True: proc = subprocess.Popen( ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)], stdin=subprocess.PIPE) - proc.stdin.write(input) + proc.stdin.write(arguments) proc.stdin.close() status = proc.wait() if status == 0: break - elif (tries > 5): + elif tries > 5: raise RuntimeError("ssh_write failed with error %s" % proc.returncode) else: print >> stderr, \ From 657bdff41a27568a981b3e342ad380fe92aa08a0 Mon Sep 17 00:00:00 2001 From: "Zhang, Liye" Date: Mon, 29 Sep 2014 01:13:15 -0700 Subject: [PATCH 39/50] [CORE] Bugfix: LogErr format in DAGScheduler.scala Author: Zhang, Liye Closes #2572 from liyezhang556520/DAGLogErr and squashes the following commits: 5be2491 [Zhang, Liye] Bugfix: LogErr format in DAGScheduler.scala --- .../main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 70c235dffff70..5a96f52a10cd4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1209,7 +1209,7 @@ class DAGScheduler( .format(job.jobId, stageId)) } else if (jobsForStage.get.size == 1) { if (!stageIdToStage.contains(stageId)) { - logError("Missing Stage for stage with id $stageId") + logError(s"Missing Stage for stage with id $stageId") } else { // This is the only job that uses this stage, so fail the stage if it is running. val stage = stageIdToStage(stageId) From aedd251c54fd130fe6e2f28d7587d39136e7ad1c Mon Sep 17 00:00:00 2001 From: Nicholas Chammas Date: Mon, 29 Sep 2014 10:45:08 -0700 Subject: [PATCH 40/50] [EC2] Sort long, manually-inputted dictionaries Similar to the work done in #2571, this PR just sorts the remaining manually-inputted dicts in the EC2 script so they are easier to maintain. Author: Nicholas Chammas Closes #2578 from nchammas/ec2-dict-sort and squashes the following commits: f55c692 [Nicholas Chammas] sort long dictionaries --- ec2/spark_ec2.py | 69 ++++++++++++++++++++++++++---------------------- 1 file changed, 38 insertions(+), 31 deletions(-) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index 5776d0b519309..941dfb988b9fb 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -217,8 +217,15 @@ def is_active(instance): # Return correct versions of Spark and Shark, given the supplied Spark version def get_spark_shark_version(opts): spark_shark_map = { - "0.7.3": "0.7.1", "0.8.0": "0.8.0", "0.8.1": "0.8.1", "0.9.0": "0.9.0", "0.9.1": "0.9.1", - "1.0.0": "1.0.0", "1.0.1": "1.0.1", "1.0.2": "1.0.2", "1.1.0": "1.1.0" + "0.7.3": "0.7.1", + "0.8.0": "0.8.0", + "0.8.1": "0.8.1", + "0.9.0": "0.9.0", + "0.9.1": "0.9.1", + "1.0.0": "1.0.0", + "1.0.1": "1.0.1", + "1.0.2": "1.0.2", + "1.1.0": "1.1.0", } version = opts.spark_version.replace("v", "") if version not in spark_shark_map: @@ -227,49 +234,49 @@ def get_spark_shark_version(opts): return (version, spark_shark_map[version]) -# Attempt to resolve an appropriate AMI given the architecture and -# region of the request. 
-# Information regarding Amazon Linux AMI instance type was update on 2014-6-20: -# http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/ +# Attempt to resolve an appropriate AMI given the architecture and region of the request. +# Source: http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/ +# Last Updated: 2014-06-20 +# For easy maintainability, please keep this manually-inputted dictionary sorted by key. def get_spark_ami(opts): instance_types = { - "m1.small": "pvm", - "m1.medium": "pvm", - "m1.large": "pvm", - "m1.xlarge": "pvm", - "t1.micro": "pvm", "c1.medium": "pvm", "c1.xlarge": "pvm", - "m2.xlarge": "pvm", - "m2.2xlarge": "pvm", - "m2.4xlarge": "pvm", + "c3.2xlarge": "pvm", + "c3.4xlarge": "pvm", + "c3.8xlarge": "pvm", + "c3.large": "pvm", + "c3.xlarge": "pvm", "cc1.4xlarge": "hvm", "cc2.8xlarge": "hvm", "cg1.4xlarge": "hvm", - "hs1.8xlarge": "pvm", - "hi1.4xlarge": "pvm", - "m3.medium": "hvm", - "m3.large": "hvm", - "m3.xlarge": "hvm", - "m3.2xlarge": "hvm", "cr1.8xlarge": "hvm", - "i2.xlarge": "hvm", + "hi1.4xlarge": "pvm", + "hs1.8xlarge": "pvm", "i2.2xlarge": "hvm", "i2.4xlarge": "hvm", "i2.8xlarge": "hvm", - "c3.large": "pvm", - "c3.xlarge": "pvm", - "c3.2xlarge": "pvm", - "c3.4xlarge": "pvm", - "c3.8xlarge": "pvm", - "r3.large": "hvm", - "r3.xlarge": "hvm", + "i2.xlarge": "hvm", + "m1.large": "pvm", + "m1.medium": "pvm", + "m1.small": "pvm", + "m1.xlarge": "pvm", + "m2.2xlarge": "pvm", + "m2.4xlarge": "pvm", + "m2.xlarge": "pvm", + "m3.2xlarge": "hvm", + "m3.large": "hvm", + "m3.medium": "hvm", + "m3.xlarge": "hvm", "r3.2xlarge": "hvm", "r3.4xlarge": "hvm", "r3.8xlarge": "hvm", + "r3.large": "hvm", + "r3.xlarge": "hvm", + "t1.micro": "pvm", + "t2.medium": "hvm", "t2.micro": "hvm", "t2.small": "hvm", - "t2.medium": "hvm" } if opts.instance_type in instance_types: instance_type = instance_types[opts.instance_type] @@ -624,8 +631,8 @@ def wait_for_cluster(conn, wait_secs, master_nodes, slave_nodes): # Get number of local disks available for a given EC2 instance type. def get_num_disks(instance_type): - # From http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html - # Updated 2014-6-20 + # Source: http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html + # Last Updated: 2014-06-20 # For easy maintainability, please keep this manually-inputted dictionary sorted by key. disks_by_instance = { "c1.medium": 1, From 587a0cd7ed964ebfca2c97924c4f1e363f1fd3cb Mon Sep 17 00:00:00 2001 From: Reza Zadeh Date: Mon, 29 Sep 2014 11:15:09 -0700 Subject: [PATCH 41/50] [MLlib] [SPARK-2885] DIMSUM: All-pairs similarity MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # All-pairs similarity via DIMSUM Compute all pairs of similar vectors using brute force approach, and also DIMSUM sampling approach. Laying down some notation: we are looking for all pairs of similar columns in an m x n RowMatrix whose entries are denoted a_ij, with the i’th row denoted r_i and the j’th column denoted c_j. There is an oversampling parameter labeled ɣ that should be set to 4 log(n)/s to get provably correct results (with high probability), where s is the similarity threshold. The algorithm is stated with a Map and Reduce, with proofs of correctness and efficiency in published papers [1] [2]. The reducer is simply the summation reducer. The mapper is more interesting, and is also the heart of the scheme. As an exercise, you should try to see why in expectation, the map-reduce below outputs cosine similarities. 
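To make the mapper concrete, here is a minimal, single-machine sketch of the scheme described above (a toy illustration, not the patched RowMatrix code: the object and method names are made up, the input is a small dense array of rows, and every column is assumed to have nonzero norm):

```scala
import scala.util.Random
import scala.collection.mutable

// Toy DIMSUM: estimate cosine similarities between columns of a small dense matrix.
object DimsumSketch {
  def columnSimilarities(rows: Array[Array[Double]], gamma: Double, seed: Long = 42L)
    : Map[(Int, Int), Double] = {
    val n = rows.head.length
    // Column magnitudes ||c_j|| (assumed nonzero here; the real code guards against zero).
    val colNorms = Array.tabulate(n)(j => math.sqrt(rows.map(r => r(j) * r(j)).sum))
    val sg = math.sqrt(gamma)
    val rand = new Random(seed)
    val sums = mutable.Map.empty[(Int, Int), Double].withDefaultValue(0.0)
    for (row <- rows; j <- 0 until n; k <- (j + 1) until n) {
      // Map step: keep the pair with probability min(1, sqrt(gamma)/||c_j||) per column,
      // scaling each entry by min(sqrt(gamma), ||c_j||) so the expected sum is the cosine.
      if (rand.nextDouble() < sg / colNorms(j) && rand.nextDouble() < sg / colNorms(k)) {
        sums((j, k)) += (row(j) / math.min(sg, colNorms(j))) * (row(k) / math.min(sg, colNorms(k)))
      }
    }
    sums.toMap // Reduce step: plain summation per column pair.
  }
}
```

In expectation each summed term contributes (a_ij/||c_j||)(a_ik/||c_k||), which is exactly the cosine similarity; making ɣ large enough that sqrt(ɣ) exceeds every column norm keeps every pair and recovers the brute-force result.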
![dimsumv2](https://cloud.githubusercontent.com/assets/3220351/3807272/d1d9514e-1c62-11e4-9f12-3cfdb1d78b3a.png) [1] Bosagh-Zadeh, Reza and Carlsson, Gunnar (2013), Dimension Independent Matrix Square using MapReduce, arXiv:1304.1467 http://arxiv.org/abs/1304.1467 [2] Bosagh-Zadeh, Reza and Goel, Ashish (2012), Dimension Independent Similarity Computation, arXiv:1206.2082 http://arxiv.org/abs/1206.2082 # Testing Tests for all invocations included. Added L1 and L2 norm computation to MultivariateStatisticalSummary since it was needed. Added tests for both of them. Author: Reza Zadeh Author: Xiangrui Meng Closes #1778 from rezazadeh/dimsumv2 and squashes the following commits: 404c64c [Reza Zadeh] Merge remote-tracking branch 'upstream/master' into dimsumv2 4eb71c6 [Reza Zadeh] Add excludes for normL1 and normL2 ee8bd65 [Reza Zadeh] Merge remote-tracking branch 'upstream/master' into dimsumv2 976ddd4 [Reza Zadeh] Broadcast colMags. Avoid div by zero. 3467cff [Reza Zadeh] Merge remote-tracking branch 'upstream/master' into dimsumv2 aea0247 [Reza Zadeh] Allow large thresholds to promote sparsity 9fe17c0 [Xiangrui Meng] organize imports 2196ba5 [Xiangrui Meng] Merge branch 'rezazadeh-dimsumv2' into dimsumv2 254ca08 [Reza Zadeh] Merge remote-tracking branch 'upstream/master' into dimsumv2 f2947e4 [Xiangrui Meng] some optimization 3c4cf41 [Xiangrui Meng] Merge branch 'master' into rezazadeh-dimsumv2 0e4eda4 [Reza Zadeh] Use partition index for RNG 251bb9c [Reza Zadeh] Documentation 25e9d0d [Reza Zadeh] Line length for style fb296f6 [Reza Zadeh] renamed to normL1 and normL2 3764983 [Reza Zadeh] Documentation e9c6791 [Reza Zadeh] New interface and documentation 613f261 [Reza Zadeh] Column magnitude summary 75a0b51 [Reza Zadeh] Use Ints instead of Longs in the shuffle 0f12ade [Reza Zadeh] Style changes eb1dc20 [Reza Zadeh] Use Double.PositiveInfinity instead of Double.Max f56a882 [Reza Zadeh] Remove changes to MultivariateOnlineSummarizer dbc55ba [Reza Zadeh] Make colMagnitudes a method in RowMatrix 41e8ece [Reza Zadeh] style changes 139c8e1 [Reza Zadeh] Syntax changes 029aa9c [Reza Zadeh] javadoc and new test 75edb25 [Reza Zadeh] All tests passing! 
05e59b8 [Reza Zadeh] Add test 502ce52 [Reza Zadeh] new interface 654c4fb [Reza Zadeh] default methods 3726ca9 [Reza Zadeh] Remove MatrixAlgebra 6bebabb [Reza Zadeh] remove changes to MatrixSuite 5b8cd7d [Reza Zadeh] Initial files --- .../mllib/linalg/distributed/RowMatrix.scala | 171 +++++++++++++++++- .../stat/MultivariateOnlineSummarizer.scala | 38 +++- .../stat/MultivariateStatisticalSummary.scala | 10 + .../linalg/distributed/RowMatrixSuite.scala | 37 ++++ project/MimaExcludes.scala | 9 +- 5 files changed, 259 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 4174f45d231c7..8380058cf9b41 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -19,17 +19,21 @@ package org.apache.spark.mllib.linalg.distributed import java.util.Arrays -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV} -import breeze.linalg.{svd => brzSvd, axpy => brzAxpy} +import scala.collection.mutable.ListBuffer + +import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, axpy => brzAxpy, + svd => brzSvd} import breeze.numerics.{sqrt => brzSqrt} import com.github.fommil.netlib.BLAS.{getInstance => blas} +import org.apache.spark.Logging +import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg._ -import org.apache.spark.rdd.RDD -import org.apache.spark.Logging import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.random.XORShiftRandom import org.apache.spark.storage.StorageLevel /** @@ -411,6 +415,165 @@ class RowMatrix( new RowMatrix(AB, nRows, B.numCols) } + /** + * Compute all cosine similarities between columns of this matrix using the brute-force + * approach of computing normalized dot products. + * + * @return An n x n sparse upper-triangular matrix of cosine similarities between + * columns of this matrix. + */ + def columnSimilarities(): CoordinateMatrix = { + columnSimilarities(0.0) + } + + /** + * Compute similarities between columns of this matrix using a sampling approach. + * + * The threshold parameter is a trade-off knob between estimate quality and computational cost. + * + * Setting a threshold of 0 guarantees deterministic correct results, but comes at exactly + * the same cost as the brute-force approach. Setting the threshold to positive values + * incurs strictly less computational cost than the brute-force approach, however the + * similarities computed will be estimates. + * + * The sampling guarantees relative-error correctness for those pairs of columns that have + * similarity greater than the given similarity threshold. + * + * To describe the guarantee, we set some notation: + * Let A be the smallest in magnitude non-zero element of this matrix. + * Let B be the largest in magnitude non-zero element of this matrix. + * Let L be the maximum number of non-zeros per row. + * + * For example, for {0,1} matrices: A=B=1. 
+ * Another example, for the Netflix matrix: A=1, B=5 + * + * For those column pairs that are above the threshold, + * the computed similarity is correct to within 20% relative error with probability + * at least 1 - (0.981)^10/B^ + * + * The shuffle size is bounded by the *smaller* of the following two expressions: + * + * O(n log(n) L / (threshold * A)) + * O(m L^2^) + * + * The latter is the cost of the brute-force approach, so for non-zero thresholds, + * the cost is always cheaper than the brute-force approach. + * + * @param threshold Set to 0 for deterministic guaranteed correctness. + * Similarities above this threshold are estimated + * with the cost vs estimate quality trade-off described above. + * @return An n x n sparse upper-triangular matrix of cosine similarities + * between columns of this matrix. + */ + def columnSimilarities(threshold: Double): CoordinateMatrix = { + require(threshold >= 0, s"Threshold cannot be negative: $threshold") + + if (threshold > 1) { + logWarning(s"Threshold is greater than 1: $threshold " + + "Computation will be more efficient with promoted sparsity, " + + " however there is no correctness guarantee.") + } + + val gamma = if (threshold < 1e-6) { + Double.PositiveInfinity + } else { + 10 * math.log(numCols()) / threshold + } + + columnSimilaritiesDIMSUM(computeColumnSummaryStatistics().normL2.toArray, gamma) + } + + /** + * Find all similar columns using the DIMSUM sampling algorithm, described in two papers + * + * http://arxiv.org/abs/1206.2082 + * http://arxiv.org/abs/1304.1467 + * + * @param colMags A vector of column magnitudes + * @param gamma The oversampling parameter. For provable results, set to 10 * log(n) / s, + * where s is the smallest similarity score to be estimated, + * and n is the number of columns + * @return An n x n sparse upper-triangular matrix of cosine similarities + * between columns of this matrix. 
+ */ + private[mllib] def columnSimilaritiesDIMSUM( + colMags: Array[Double], + gamma: Double): CoordinateMatrix = { + require(gamma > 1.0, s"Oversampling should be greater than 1: $gamma") + require(colMags.size == this.numCols(), "Number of magnitudes didn't match column dimension") + val sg = math.sqrt(gamma) // sqrt(gamma) used many times + + // Don't divide by zero for those columns with zero magnitude + val colMagsCorrected = colMags.map(x => if (x == 0) 1.0 else x) + + val sc = rows.context + val pBV = sc.broadcast(colMagsCorrected.map(c => sg / c)) + val qBV = sc.broadcast(colMagsCorrected.map(c => math.min(sg, c))) + + val sims = rows.mapPartitionsWithIndex { (indx, iter) => + val p = pBV.value + val q = qBV.value + + val rand = new XORShiftRandom(indx) + val scaled = new Array[Double](p.size) + iter.flatMap { row => + val buf = new ListBuffer[((Int, Int), Double)]() + row match { + case sv: SparseVector => + val nnz = sv.indices.size + var k = 0 + while (k < nnz) { + scaled(k) = sv.values(k) / q(sv.indices(k)) + k += 1 + } + k = 0 + while (k < nnz) { + val i = sv.indices(k) + val iVal = scaled(k) + if (iVal != 0 && rand.nextDouble() < p(i)) { + var l = k + 1 + while (l < nnz) { + val j = sv.indices(l) + val jVal = scaled(l) + if (jVal != 0 && rand.nextDouble() < p(j)) { + buf += (((i, j), iVal * jVal)) + } + l += 1 + } + } + k += 1 + } + case dv: DenseVector => + val n = dv.values.size + var i = 0 + while (i < n) { + scaled(i) = dv.values(i) / q(i) + i += 1 + } + i = 0 + while (i < n) { + val iVal = scaled(i) + if (iVal != 0 && rand.nextDouble() < p(i)) { + var j = i + 1 + while (j < n) { + val jVal = scaled(j) + if (jVal != 0 && rand.nextDouble() < p(j)) { + buf += (((i, j), iVal * jVal)) + } + j += 1 + } + } + i += 1 + } + } + buf + } + }.reduceByKey(_ + _).map { case ((i, j), sim) => + MatrixEntry(i.toLong, j.toLong, sim) + } + new CoordinateMatrix(sims, numCols(), numCols()) + } + private[mllib] override def toBreeze(): BDM[Double] = { val m = numRows().toInt val n = numCols().toInt diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 7d845c44365dd..3025d4837cab4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -42,6 +42,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S private var n = 0 private var currMean: BDV[Double] = _ private var currM2n: BDV[Double] = _ + private var currM2: BDV[Double] = _ + private var currL1: BDV[Double] = _ private var totalCnt: Long = 0 private var nnz: BDV[Double] = _ private var currMax: BDV[Double] = _ @@ -60,6 +62,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currMean = BDV.zeros[Double](n) currM2n = BDV.zeros[Double](n) + currM2 = BDV.zeros[Double](n) + currL1 = BDV.zeros[Double](n) nnz = BDV.zeros[Double](n) currMax = BDV.fill(n)(Double.MinValue) currMin = BDV.fill(n)(Double.MaxValue) @@ -81,6 +85,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S val tmpPrevMean = currMean(i) currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0) currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean) + currM2(i) += value * value + currL1(i) += math.abs(value) nnz(i) += 1.0 } @@ -97,7 +103,7 @@ class MultivariateOnlineSummarizer extends 
MultivariateStatisticalSummary with S * @return This MultivariateOnlineSummarizer object. */ def merge(other: MultivariateOnlineSummarizer): this.type = { - if (this.totalCnt != 0 && other.totalCnt != 0) { + if (this.totalCnt != 0 && other.totalCnt != 0) { require(n == other.n, s"Dimensions mismatch when merging with another summarizer. " + s"Expecting $n but got ${other.n}.") totalCnt += other.totalCnt @@ -114,6 +120,15 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) / (nnz(i) + other.nnz(i)) } + // merge m2 together + if (nnz(i) + other.nnz(i) != 0.0) { + currM2(i) += other.currM2(i) + } + // merge l1 together + if (nnz(i) + other.nnz(i) != 0.0) { + currL1(i) += other.currL1(i) + } + if (currMax(i) < other.currMax(i)) { currMax(i) = other.currMax(i) } @@ -127,6 +142,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S this.n = other.n this.currMean = other.currMean.copy this.currM2n = other.currM2n.copy + this.currM2 = other.currM2.copy + this.currL1 = other.currL1.copy this.totalCnt = other.totalCnt this.nnz = other.nnz.copy this.currMax = other.currMax.copy @@ -198,4 +215,23 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S } Vectors.fromBreeze(currMin) } + + override def normL2: Vector = { + require(totalCnt > 0, s"Nothing has been added to this summarizer.") + + val realMagnitude = BDV.zeros[Double](n) + + var i = 0 + while (i < currM2.size) { + realMagnitude(i) = math.sqrt(currM2(i)) + i += 1 + } + + Vectors.fromBreeze(realMagnitude) + } + + override def normL1: Vector = { + require(totalCnt > 0, s"Nothing has been added to this summarizer.") + Vectors.fromBreeze(currL1) + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala index f9eb343da2b82..6a364c93284af 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala @@ -53,4 +53,14 @@ trait MultivariateStatisticalSummary { * Minimum value of each column. 
*/ def min: Vector + + /** + * Euclidean magnitude of each column + */ + def normL2: Vector + + /** + * L1 norm of each column + */ + def normL1: Vector } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 1d3a3221365cc..63f3ed58c0d4d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -95,6 +95,40 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext { } } + test("similar columns") { + val colMags = Vectors.dense(Math.sqrt(126), Math.sqrt(66), Math.sqrt(94)) + val expected = BDM( + (0.0, 54.0, 72.0), + (0.0, 0.0, 78.0), + (0.0, 0.0, 0.0)) + + for (i <- 0 until n; j <- 0 until n) { + expected(i, j) /= (colMags(i) * colMags(j)) + } + + for (mat <- Seq(denseMat, sparseMat)) { + val G = mat.columnSimilarities(0.11).toBreeze() + for (i <- 0 until n; j <- 0 until n) { + if (expected(i, j) > 0) { + val actual = expected(i, j) + val estimate = G(i, j) + assert(math.abs(actual - estimate) / actual < 0.2, + s"Similarities not close enough: $actual vs $estimate") + } + } + } + + for (mat <- Seq(denseMat, sparseMat)) { + val G = mat.columnSimilarities() + assert(closeToZero(G.toBreeze() - expected)) + } + + for (mat <- Seq(denseMat, sparseMat)) { + val G = mat.columnSimilaritiesDIMSUM(colMags.toArray, 150.0) + assert(closeToZero(G.toBreeze() - expected)) + } + } + test("svd of a full-rank matrix") { for (mat <- Seq(denseMat, sparseMat)) { for (mode <- Seq("auto", "local-svd", "local-eigs", "dist-eigs")) { @@ -190,6 +224,9 @@ class RowMatrixSuite extends FunSuite with LocalSparkContext { assert(summary.numNonzeros === Vectors.dense(3.0, 3.0, 4.0), "nnz mismatch") assert(summary.max === Vectors.dense(9.0, 7.0, 8.0), "max mismatch") assert(summary.min === Vectors.dense(0.0, 0.0, 1.0), "column mismatch.") + assert(summary.normL2 === Vectors.dense(Math.sqrt(126), Math.sqrt(66), Math.sqrt(94)), + "magnitude mismatch.") + assert(summary.normL1 === Vectors.dense(18.0, 12.0, 16.0), "L1 norm mismatch") } } } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 3280e662fa0b1..1adfaa18c6202 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -39,7 +39,14 @@ object MimaExcludes { MimaBuild.excludeSparkPackage("graphx") ) ++ MimaBuild.excludeSparkClass("mllib.linalg.Matrix") ++ - MimaBuild.excludeSparkClass("mllib.linalg.Vector") + MimaBuild.excludeSparkClass("mllib.linalg.Vector") ++ + Seq( + // Added normL1 and normL2 to trait MultivariateStatisticalSummary + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL1"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL2") + ) case v if v.startsWith("1.1") => Seq( From dab1b0ae29a6d3017bdca23464f22a51d51eaae1 Mon Sep 17 00:00:00 2001 From: jerryshao Date: Mon, 29 Sep 2014 11:25:32 -0700 Subject: [PATCH 42/50] [SPARK-3032][Shuffle] Fix key comparison integer overflow introduced sorting exception Previous key comparison in `ExternalSorter` will get wrong sorting result or exception when key comparison overflows, details can be seen in [SPARK-3032](https://issues.apache.org/jira/browse/SPARK-3032). Here fix this and add a unit test to prove it. 
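To see why the old comparator misbehaves, here is a small standalone illustration (not code from the patch; the hash values are hypothetical): subtracting two Int hash codes can overflow and flip sign, so the comparator becomes inconsistent and TimSort may fail with "Comparison method violates its general contract", whereas an explicit three-way comparison stays correct.

```scala
object HashCompareOverflow {
  def main(args: Array[String]): Unit = {
    val h1 = -2000000000 // hash code of one key
    val h2 = 2000000000  // hash code of another key; clearly h1 < h2

    // Old, subtraction-based comparison: -4000000000 does not fit in an Int,
    // so the result wraps around to a positive value and claims h1 > h2.
    val buggy = h1 - h2
    println(s"h1 - h2 = $buggy")

    // Overflow-safe comparison (the form used by the fix): returns -1 here as expected.
    val safe = if (h1 < h2) -1 else if (h1 == h2) 0 else 1
    println(s"three-way compare = $safe")
  }
}
```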
Author: jerryshao Closes #2514 from jerryshao/SPARK-3032 and squashes the following commits: 6f3c302 [jerryshao] Improve the unit test according to comments 01911e6 [jerryshao] Change the test to show the contract violate exception 83acb38 [jerryshao] Minor changes according to comments fa2a08f [jerryshao] Fix key comparison integer overflow introduced sorting exception --- .../util/collection/ExternalSorter.scala | 2 +- .../util/collection/ExternalSorterSuite.scala | 55 +++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 0a152cb97ad9e..644fa36818647 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -144,7 +144,7 @@ private[spark] class ExternalSorter[K, V, C]( override def compare(a: K, b: K): Int = { val h1 = if (a == null) 0 else a.hashCode() val h2 = if (b == null) 0 else b.hashCode() - h1 - h2 + if (h1 < h2) -1 else if (h1 == h2) 0 else 1 } }) diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index 706faed980f31..f26e40fbd4b36 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -24,6 +24,8 @@ import org.scalatest.{PrivateMethodTester, FunSuite} import org.apache.spark._ import org.apache.spark.SparkContext._ +import scala.util.Random + class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMethodTester { private def createSparkConf(loadDefaults: Boolean): SparkConf = { val conf = new SparkConf(loadDefaults) @@ -707,4 +709,57 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe Some(agg), Some(new HashPartitioner(FEW_PARTITIONS)), None, None) assertDidNotBypassMergeSort(sorter4) } + + test("sort without breaking sorting contracts") { + val conf = createSparkConf(true) + conf.set("spark.shuffle.memoryFraction", "0.01") + conf.set("spark.shuffle.manager", "sort") + sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + + // Using wrongOrdering to show integer overflow introduced exception. + val rand = new Random(100L) + val wrongOrdering = new Ordering[String] { + override def compare(a: String, b: String) = { + val h1 = if (a == null) 0 else a.hashCode() + val h2 = if (b == null) 0 else b.hashCode() + h1 - h2 + } + } + + val testData = Array.tabulate(100000) { _ => rand.nextInt().toString } + + val sorter1 = new ExternalSorter[String, String, String]( + None, None, Some(wrongOrdering), None) + val thrown = intercept[IllegalArgumentException] { + sorter1.insertAll(testData.iterator.map(i => (i, i))) + sorter1.iterator + } + + assert(thrown.getClass() === classOf[IllegalArgumentException]) + assert(thrown.getMessage().contains("Comparison method violates its general contract")) + sorter1.stop() + + // Using aggregation and external spill to make sure ExternalSorter using + // partitionKeyComparator. 
+ def createCombiner(i: String) = ArrayBuffer(i) + def mergeValue(c: ArrayBuffer[String], i: String) = c += i + def mergeCombiners(c1: ArrayBuffer[String], c2: ArrayBuffer[String]) = c1 ++= c2 + + val agg = new Aggregator[String, String, ArrayBuffer[String]]( + createCombiner, mergeValue, mergeCombiners) + + val sorter2 = new ExternalSorter[String, String, ArrayBuffer[String]]( + Some(agg), None, None, None) + sorter2.insertAll(testData.iterator.map(i => (i, i))) + + // To validate the hash ordering of key + var minKey = Int.MinValue + sorter2.iterator.foreach { case (k, v) => + val h = k.hashCode() + assert(h >= minKey) + minKey = h + } + + sorter2.stop() + } } From e43c72fe04d4fbf2a108b456d533e641b71b0a2a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 12:38:24 -0700 Subject: [PATCH 43/50] Add more debug message for ManagedBuffer This is to help debug the error reported at http://apache-spark-user-list.1001560.n3.nabble.com/SQL-queries-fail-in-1-2-0-SNAPSHOT-td15327.html Author: Reynold Xin Closes #2580 from rxin/buffer-debug and squashes the following commits: 5814292 [Reynold Xin] Logging close() in case close() fails. 323dfec [Reynold Xin] Add more debug message. --- .../apache/spark/network/ManagedBuffer.scala | 43 ++++++++++++++++--- .../scala/org/apache/spark/util/Utils.scala | 14 ++++++ 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala index e990c1da6730f..a4409181ec907 100644 --- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala +++ b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala @@ -17,15 +17,17 @@ package org.apache.spark.network -import java.io.{FileInputStream, RandomAccessFile, File, InputStream} +import java.io._ import java.nio.ByteBuffer import java.nio.channels.FileChannel import java.nio.channels.FileChannel.MapMode +import scala.util.Try + import com.google.common.io.ByteStreams import io.netty.buffer.{ByteBufInputStream, ByteBuf} -import org.apache.spark.util.ByteBufferInputStream +import org.apache.spark.util.{ByteBufferInputStream, Utils} /** @@ -71,18 +73,47 @@ final class FileSegmentManagedBuffer(val file: File, val offset: Long, val lengt try { channel = new RandomAccessFile(file, "r").getChannel channel.map(MapMode.READ_ONLY, offset, length) + } catch { + case e: IOException => + Try(channel.size).toOption match { + case Some(fileLen) => + throw new IOException(s"Error in reading $this (actual file length $fileLen)", e) + case None => + throw new IOException(s"Error in opening $this", e) + } } finally { if (channel != null) { - channel.close() + Utils.tryLog(channel.close()) } } } override def inputStream(): InputStream = { - val is = new FileInputStream(file) - is.skip(offset) - ByteStreams.limit(is, length) + var is: FileInputStream = null + try { + is = new FileInputStream(file) + is.skip(offset) + ByteStreams.limit(is, length) + } catch { + case e: IOException => + if (is != null) { + Utils.tryLog(is.close()) + } + Try(file.length).toOption match { + case Some(fileLen) => + throw new IOException(s"Error in reading $this (actual file length $fileLen)", e) + case None => + throw new IOException(s"Error in opening $this", e) + } + case e: Throwable => + if (is != null) { + Utils.tryLog(is.close()) + } + throw e + } } + + override def toString: String = s"${getClass.getName}($file, $offset, $length)" } diff --git 
a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2755887feeeff..10d440828e323 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1304,6 +1304,20 @@ private[spark] object Utils extends Logging { } } + /** Executes the given block in a Try, logging any uncaught exceptions. */ + def tryLog[T](f: => T): Try[T] = { + try { + val res = f + scala.util.Success(res) + } catch { + case ct: ControlThrowable => + throw ct + case t: Throwable => + logError(s"Uncaught exception in thread ${Thread.currentThread().getName}", t) + scala.util.Failure(t) + } + } + /** Returns true if the given exception was fatal. See docs for scala.util.control.NonFatal. */ def isFatalError(e: Throwable): Boolean = { e match { From 0bbe7faeffa17577ae8a33dfcd8c4c783db5c909 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?baishuo=28=E7=99=BD=E7=A1=95=29?= Date: Mon, 29 Sep 2014 15:51:55 -0700 Subject: [PATCH 44/50] [SPARK-3007][SQL]Add Dynamic Partition support to Spark Sql hive MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit a new PR base on new master. changes are the same as https://github.com/apache/spark/pull/1919 Author: baishuo(白硕) Author: baishuo Author: Cheng Lian Closes #2226 from baishuo/patch-3007 and squashes the following commits: e69ce88 [Cheng Lian] Adds tests to verify dynamic partitioning folder layout b20a3dc [Cheng Lian] Addresses @yhuai's comments 096bbbc [baishuo(白硕)] Merge pull request #1 from liancheng/refactor-dp 1093c20 [Cheng Lian] Adds more tests 5004542 [Cheng Lian] Minor refactoring fae9eff [Cheng Lian] Refactors InsertIntoHiveTable to a Command 528e84c [Cheng Lian] Fixes typo in test name, regenerated golden answer files c464b26 [Cheng Lian] Refactors dynamic partitioning support 5033928 [baishuo] pass check style 2201c75 [baishuo] use HiveConf.DEFAULTPARTITIONNAME to replace hive.exec.default.partition.name b47c9bf [baishuo] modify according micheal's advice c3ab36d [baishuo] modify for some bad indentation 7ce2d9f [baishuo] modify code to pass scala style checks 37c1c43 [baishuo] delete a empty else branch 66e33fc [baishuo] do a little modify 88d0110 [baishuo] update file after test a3961d9 [baishuo(白硕)] Update Cast.scala f7467d0 [baishuo(白硕)] Update InsertIntoHiveTable.scala c1a59dd [baishuo(白硕)] Update Cast.scala 0e18496 [baishuo(白硕)] Update HiveQuerySuite.scala 60f70aa [baishuo(白硕)] Update InsertIntoHiveTable.scala 0a50db9 [baishuo(白硕)] Update HiveCompatibilitySuite.scala 491c7d0 [baishuo(白硕)] Update InsertIntoHiveTable.scala a2374a8 [baishuo(白硕)] Update InsertIntoHiveTable.scala 701a814 [baishuo(白硕)] Update SparkHadoopWriter.scala dc24c41 [baishuo(白硕)] Update HiveQl.scala --- .../execution/HiveCompatibilitySuite.scala | 17 ++ .../org/apache/spark/SparkHadoopWriter.scala | 195 ---------------- .../org/apache/spark/sql/hive/HiveQl.scala | 5 - .../hive/execution/InsertIntoHiveTable.scala | 207 +++++++++-------- .../spark/sql/hive/hiveWriterContainers.scala | 217 ++++++++++++++++++ ...rtition-0-be33aaa7253c8f248ff3921cd7dae340 | 0 ...rtition-1-640552dd462707563fd255a713f83b41 | 0 ...rtition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 | 1 + ...rtition-3-b7f7fa7ebf666f4fee27e149d8c6961f | 0 ...rtition-4-8bdb71ad8cb3cc3026043def2525de3a | 0 ...rtition-5-c630dce438f3792e7fb0f523fbbb3e1e | 0 ...rtition-6-7abc9ec8a36cdc5e89e955265a7fd7cf | 0 ...rtition-7-be33aaa7253c8f248ff3921cd7dae340 | 0 
.../sql/hive/execution/HiveQuerySuite.scala | 100 +++++++- 14 files changed, 443 insertions(+), 299 deletions(-) delete mode 100644 sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-0-be33aaa7253c8f248ff3921cd7dae340 create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-1-640552dd462707563fd255a713f83b41 create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-3-b7f7fa7ebf666f4fee27e149d8c6961f create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-4-8bdb71ad8cb3cc3026043def2525de3a create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-5-c630dce438f3792e7fb0f523fbbb3e1e create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-6-7abc9ec8a36cdc5e89e955265a7fd7cf create mode 100644 sql/hive/src/test/resources/golden/dynamic_partition-7-be33aaa7253c8f248ff3921cd7dae340 diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 556c984ad392b..35e9c9939d4b7 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -220,6 +220,23 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { */ override def whiteList = Seq( "add_part_exist", + "dynamic_partition_skip_default", + "infer_bucket_sort_dyn_part", + "load_dyn_part1", + "load_dyn_part2", + "load_dyn_part3", + "load_dyn_part4", + "load_dyn_part5", + "load_dyn_part6", + "load_dyn_part7", + "load_dyn_part8", + "load_dyn_part9", + "load_dyn_part10", + "load_dyn_part11", + "load_dyn_part12", + "load_dyn_part13", + "load_dyn_part14", + "load_dyn_part14_win", "add_part_multiple", "add_partition_no_whitelist", "add_partition_with_whitelist", diff --git a/sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala deleted file mode 100644 index ab7862f4f9e06..0000000000000 --- a/sql/hive/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.hive - -import java.io.IOException -import java.text.NumberFormat -import java.util.Date - -import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities} -import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} -import org.apache.hadoop.hive.ql.plan.FileSinkDesc -import org.apache.hadoop.mapred._ -import org.apache.hadoop.io.Writable - -import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} - -/** - * Internal helper class that saves an RDD using a Hive OutputFormat. - * It is based on [[SparkHadoopWriter]]. - */ -private[hive] class SparkHiveHadoopWriter( - @transient jobConf: JobConf, - fileSinkConf: FileSinkDesc) - extends Logging - with SparkHadoopMapRedUtil - with Serializable { - - private val now = new Date() - private val conf = new SerializableWritable(jobConf) - - private var jobID = 0 - private var splitID = 0 - private var attemptID = 0 - private var jID: SerializableWritable[JobID] = null - private var taID: SerializableWritable[TaskAttemptID] = null - - @transient private var writer: FileSinkOperator.RecordWriter = null - @transient private var format: HiveOutputFormat[AnyRef, Writable] = null - @transient private var committer: OutputCommitter = null - @transient private var jobContext: JobContext = null - @transient private var taskContext: TaskAttemptContext = null - - def preSetup() { - setIDs(0, 0, 0) - setConfParams() - - val jCtxt = getJobContext() - getOutputCommitter().setupJob(jCtxt) - } - - - def setup(jobid: Int, splitid: Int, attemptid: Int) { - setIDs(jobid, splitid, attemptid) - setConfParams() - } - - def open() { - val numfmt = NumberFormat.getInstance() - numfmt.setMinimumIntegerDigits(5) - numfmt.setGroupingUsed(false) - - val extension = Utilities.getFileExtension( - conf.value, - fileSinkConf.getCompressed, - getOutputFormat()) - - val outputName = "part-" + numfmt.format(splitID) + extension - val path = FileOutputFormat.getTaskOutputPath(conf.value, outputName) - - getOutputCommitter().setupTask(getTaskContext()) - writer = HiveFileFormatUtils.getHiveRecordWriter( - conf.value, - fileSinkConf.getTableInfo, - conf.value.getOutputValueClass.asInstanceOf[Class[Writable]], - fileSinkConf, - path, - null) - } - - def write(value: Writable) { - if (writer != null) { - writer.write(value) - } else { - throw new IOException("Writer is null, open() has not been called") - } - } - - def close() { - // Seems the boolean value passed into close does not matter. - writer.close(false) - } - - def commit() { - val taCtxt = getTaskContext() - val cmtr = getOutputCommitter() - if (cmtr.needsTaskCommit(taCtxt)) { - try { - cmtr.commitTask(taCtxt) - logInfo (taID + ": Committed") - } catch { - case e: IOException => - logError("Error committing the output of task: " + taID.value, e) - cmtr.abortTask(taCtxt) - throw e - } - } else { - logWarning ("No need to commit output of task: " + taID.value) - } - } - - def commitJob() { - // always ? Or if cmtr.needsTaskCommit ? 
- val cmtr = getOutputCommitter() - cmtr.commitJob(getJobContext()) - } - - // ********* Private Functions ********* - - private def getOutputFormat(): HiveOutputFormat[AnyRef,Writable] = { - if (format == null) { - format = conf.value.getOutputFormat() - .asInstanceOf[HiveOutputFormat[AnyRef,Writable]] - } - format - } - - private def getOutputCommitter(): OutputCommitter = { - if (committer == null) { - committer = conf.value.getOutputCommitter - } - committer - } - - private def getJobContext(): JobContext = { - if (jobContext == null) { - jobContext = newJobContext(conf.value, jID.value) - } - jobContext - } - - private def getTaskContext(): TaskAttemptContext = { - if (taskContext == null) { - taskContext = newTaskAttemptContext(conf.value, taID.value) - } - taskContext - } - - private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { - jobID = jobId - splitID = splitId - attemptID = attemptId - - jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobId)) - taID = new SerializableWritable[TaskAttemptID]( - new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) - } - - private def setConfParams() { - conf.value.set("mapred.job.id", jID.value.toString) - conf.value.set("mapred.tip.id", taID.value.getTaskID.toString) - conf.value.set("mapred.task.id", taID.value.toString) - conf.value.setBoolean("mapred.task.is.map", true) - conf.value.setInt("mapred.task.partition", splitID) - } -} - -private[hive] object SparkHiveHadoopWriter { - def createPathFromString(path: String, conf: JobConf): Path = { - if (path == null) { - throw new IllegalArgumentException("Output path is null") - } - val outputPath = new Path(path) - val fs = outputPath.getFileSystem(conf) - if (outputPath == null || fs == null) { - throw new IllegalArgumentException("Incorrectly formatted output path") - } - outputPath.makeQualified(fs) - } -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 0aa6292c0184e..4e30e6e06fe21 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -837,11 +837,6 @@ private[hive] object HiveQl { cleanIdentifier(key.toLowerCase) -> None }.toMap).getOrElse(Map.empty) - if (partitionKeys.values.exists(p => p.isEmpty)) { - throw new NotImplementedError(s"Do not support INSERT INTO/OVERWRITE with" + - s"dynamic partitioning.") - } - InsertIntoTable(UnresolvedRelation(db, tableName, None), partitionKeys, query, overwrite) case a: ASTNode => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index a284a91a91e31..3d2ee010696f6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -19,27 +19,25 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConversions._ -import java.util.{HashMap => JHashMap} - import org.apache.hadoop.hive.common.`type`.{HiveDecimal, HiveVarchar} +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.metastore.MetaStoreUtils -import org.apache.hadoop.hive.ql.Context import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.hive.ql.plan.{FileSinkDesc, TableDesc} +import 
org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde2.Serializer -import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption -import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveDecimalObjectInspector -import org.apache.hadoop.hive.serde2.objectinspector.primitive.JavaHiveVarcharObjectInspector -import org.apache.hadoop.io.Writable +import org.apache.hadoop.hive.serde2.objectinspector._ +import org.apache.hadoop.hive.serde2.objectinspector.primitive.{JavaHiveDecimalObjectInspector, JavaHiveVarcharObjectInspector} import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} -import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.Row -import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} -import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation, SparkHiveHadoopWriter} +import org.apache.spark.sql.execution.{Command, SparkPlan, UnaryNode} +import org.apache.spark.sql.hive._ +import org.apache.spark.{SerializableWritable, SparkException, TaskContext} /** * :: DeveloperApi :: @@ -51,7 +49,7 @@ case class InsertIntoHiveTable( child: SparkPlan, overwrite: Boolean) (@transient sc: HiveContext) - extends UnaryNode { + extends UnaryNode with Command { @transient lazy val outputClass = newSerializer(table.tableDesc).getSerializedClass @transient private lazy val hiveContext = new Context(sc.hiveconf) @@ -101,66 +99,74 @@ case class InsertIntoHiveTable( } def saveAsHiveFile( - rdd: RDD[Writable], + rdd: RDD[Row], valueClass: Class[_], fileSinkConf: FileSinkDesc, - conf: JobConf, - isCompressed: Boolean) { - if (valueClass == null) { - throw new SparkException("Output value class not set") - } - conf.setOutputValueClass(valueClass) - if (fileSinkConf.getTableInfo.getOutputFileFormatClassName == null) { - throw new SparkException("Output format class not set") - } - // Doesn't work in Scala 2.9 due to what may be a generics bug - // TODO: Should we uncomment this for Scala 2.10? - // conf.setOutputFormat(outputFormatClass) - conf.set("mapred.output.format.class", fileSinkConf.getTableInfo.getOutputFileFormatClassName) + conf: SerializableWritable[JobConf], + writerContainer: SparkHiveWriterContainer) { + assert(valueClass != null, "Output value class not set") + conf.value.setOutputValueClass(valueClass) + + val outputFileFormatClassName = fileSinkConf.getTableInfo.getOutputFileFormatClassName + assert(outputFileFormatClassName != null, "Output format class not set") + conf.value.set("mapred.output.format.class", outputFileFormatClassName) + + val isCompressed = conf.value.getBoolean( + ConfVars.COMPRESSRESULT.varname, ConfVars.COMPRESSRESULT.defaultBoolVal) + if (isCompressed) { // Please note that isCompressed, "mapred.output.compress", "mapred.output.compression.codec", // and "mapred.output.compression.type" have no impact on ORC because it uses table properties // to store compression information. 
- conf.set("mapred.output.compress", "true") + conf.value.set("mapred.output.compress", "true") fileSinkConf.setCompressed(true) - fileSinkConf.setCompressCodec(conf.get("mapred.output.compression.codec")) - fileSinkConf.setCompressType(conf.get("mapred.output.compression.type")) + fileSinkConf.setCompressCodec(conf.value.get("mapred.output.compression.codec")) + fileSinkConf.setCompressType(conf.value.get("mapred.output.compression.type")) } - conf.setOutputCommitter(classOf[FileOutputCommitter]) - FileOutputFormat.setOutputPath( - conf, - SparkHiveHadoopWriter.createPathFromString(fileSinkConf.getDirName, conf)) + conf.value.setOutputCommitter(classOf[FileOutputCommitter]) + FileOutputFormat.setOutputPath( + conf.value, + SparkHiveWriterContainer.createPathFromString(fileSinkConf.getDirName, conf.value)) log.debug("Saving as hadoop file of type " + valueClass.getSimpleName) - val writer = new SparkHiveHadoopWriter(conf, fileSinkConf) - writer.preSetup() + writerContainer.driverSideSetup() + sc.sparkContext.runJob(rdd, writeToFile _) + writerContainer.commitJob() + + // Note that this function is executed on executor side + def writeToFile(context: TaskContext, iterator: Iterator[Row]) { + val serializer = newSerializer(fileSinkConf.getTableInfo) + val standardOI = ObjectInspectorUtils + .getStandardObjectInspector( + fileSinkConf.getTableInfo.getDeserializer.getObjectInspector, + ObjectInspectorCopyOption.JAVA) + .asInstanceOf[StructObjectInspector] + + val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray + val outputData = new Array[Any](fieldOIs.length) - def writeToFile(context: TaskContext, iter: Iterator[Writable]) { // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it // around by taking a mod. We expect that no task will be attempted 2 billion times. val attemptNumber = (context.attemptId % Int.MaxValue).toInt + writerContainer.executorSideSetup(context.stageId, context.partitionId, attemptNumber) - writer.setup(context.stageId, context.partitionId, attemptNumber) - writer.open() + iterator.foreach { row => + var i = 0 + while (i < fieldOIs.length) { + // TODO (lian) avoid per row dynamic dispatching and pattern matching cost in `wrap` + outputData(i) = wrap(row(i), fieldOIs(i)) + i += 1 + } - var count = 0 - while(iter.hasNext) { - val record = iter.next() - count += 1 - writer.write(record) + val writer = writerContainer.getLocalFileWriter(row) + writer.write(serializer.serialize(outputData, standardOI)) } - writer.close() - writer.commit() + writerContainer.close() } - - sc.sparkContext.runJob(rdd, writeToFile _) - writer.commitJob() } - override def execute() = result - /** * Inserts all the rows in the table into Hive. Row objects are properly serialized with the * `org.apache.hadoop.hive.serde2.SerDe` and the @@ -168,50 +174,57 @@ case class InsertIntoHiveTable( * * Note: this is run once and then kept to avoid double insertions. */ - private lazy val result: RDD[Row] = { - val childRdd = child.execute() - assert(childRdd != null) - + override protected[sql] lazy val sideEffectResult: Seq[Row] = { // Have to pass the TableDesc object to RDD.mapPartitions and then instantiate new serializer // instances within the closure, since Serializer is not serializable while TableDesc is. 
val tableDesc = table.tableDesc val tableLocation = table.hiveQlTable.getDataLocation val tmpLocation = hiveContext.getExternalTmpFileURI(tableLocation) val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) - val rdd = childRdd.mapPartitions { iter => - val serializer = newSerializer(fileSinkConf.getTableInfo) - val standardOI = ObjectInspectorUtils - .getStandardObjectInspector( - fileSinkConf.getTableInfo.getDeserializer.getObjectInspector, - ObjectInspectorCopyOption.JAVA) - .asInstanceOf[StructObjectInspector] + val numDynamicPartitions = partition.values.count(_.isEmpty) + val numStaticPartitions = partition.values.count(_.nonEmpty) + val partitionSpec = partition.map { + case (key, Some(value)) => key -> value + case (key, None) => key -> "" + } - val fieldOIs = standardOI.getAllStructFieldRefs.map(_.getFieldObjectInspector).toArray - val outputData = new Array[Any](fieldOIs.length) - iter.map { row => - var i = 0 - while (i < row.length) { - // Casts Strings to HiveVarchars when necessary. - outputData(i) = wrap(row(i), fieldOIs(i)) - i += 1 - } + // All partition column names in the format of "//..." + val partitionColumns = fileSinkConf.getTableInfo.getProperties.getProperty("partition_columns") + val partitionColumnNames = Option(partitionColumns).map(_.split("/")).orNull - serializer.serialize(outputData, standardOI) + // Validate partition spec if there exist any dynamic partitions + if (numDynamicPartitions > 0) { + // Report error if dynamic partitioning is not enabled + if (!sc.hiveconf.getBoolVar(HiveConf.ConfVars.DYNAMICPARTITIONING)) { + throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_DISABLED.getMsg) + } + + // Report error if dynamic partition strict mode is on but no static partition is found + if (numStaticPartitions == 0 && + sc.hiveconf.getVar(HiveConf.ConfVars.DYNAMICPARTITIONINGMODE).equalsIgnoreCase("strict")) { + throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_STRICT_MODE.getMsg) + } + + // Report error if any static partition appears after a dynamic partition + val isDynamic = partitionColumnNames.map(partitionSpec(_).isEmpty) + isDynamic.init.zip(isDynamic.tail).find(_ == (true, false)).foreach { _ => + throw new SparkException(ErrorMsg.PARTITION_DYN_STA_ORDER.getMsg) } } - // ORC stores compression information in table properties. While, there are other formats - // (e.g. RCFile) that rely on hadoop configurations to store compression information. val jobConf = new JobConf(sc.hiveconf) - saveAsHiveFile( - rdd, - outputClass, - fileSinkConf, - jobConf, - sc.hiveconf.getBoolean("hive.exec.compress.output", false)) - - // TODO: Handle dynamic partitioning. + val jobConfSer = new SerializableWritable(jobConf) + + val writerContainer = if (numDynamicPartitions > 0) { + val dynamicPartColNames = partitionColumnNames.takeRight(numDynamicPartitions) + new SparkHiveDynamicPartitionWriterContainer(jobConf, fileSinkConf, dynamicPartColNames) + } else { + new SparkHiveWriterContainer(jobConf, fileSinkConf) + } + + saveAsHiveFile(child.execute(), outputClass, fileSinkConf, jobConfSer, writerContainer) + val outputPath = FileOutputFormat.getOutputPath(jobConf) // Have to construct the format of dbname.tablename. val qualifiedTableName = s"${table.databaseName}.${table.tableName}" @@ -220,10 +233,6 @@ case class InsertIntoHiveTable( // holdDDLTime will be true when TOK_HOLD_DDLTIME presents in the query as a hint. 
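// Illustrative note, not in the original diff: when dynamic partitions are present, the writers
// have already laid the data out as one sub-directory per partition value (e.g.
// partcol1=1/partcol2=__HIVE_DEFAULT_PARTITION__ for a NULL value), so the branch below hands the
// whole output path to loadDynamicPartitions; purely static inserts still go through loadPartition.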
val holdDDLTime = false if (partition.nonEmpty) { - val partitionSpec = partition.map { - case (key, Some(value)) => key -> value - case (key, None) => key -> "" // Should not reach here right now. - } val partVals = MetaStoreUtils.getPvals(table.hiveQlTable.getPartCols, partitionSpec) db.validatePartitionNameCharacters(partVals) // inheritTableSpecs is set to true. It should be set to false for a IMPORT query @@ -231,14 +240,26 @@ case class InsertIntoHiveTable( val inheritTableSpecs = true // TODO: Correctly set isSkewedStoreAsSubdir. val isSkewedStoreAsSubdir = false - db.loadPartition( - outputPath, - qualifiedTableName, - partitionSpec, - overwrite, - holdDDLTime, - inheritTableSpecs, - isSkewedStoreAsSubdir) + if (numDynamicPartitions > 0) { + db.loadDynamicPartitions( + outputPath, + qualifiedTableName, + partitionSpec, + overwrite, + numDynamicPartitions, + holdDDLTime, + isSkewedStoreAsSubdir + ) + } else { + db.loadPartition( + outputPath, + qualifiedTableName, + partitionSpec, + overwrite, + holdDDLTime, + inheritTableSpecs, + isSkewedStoreAsSubdir) + } } else { db.loadTable( outputPath, @@ -251,6 +272,6 @@ case class InsertIntoHiveTable( // however for now we return an empty list to simplify compatibility checks with hive, which // does not return anything for insert operations. // TODO: implement hive compatibility as rules. - sc.sparkContext.makeRDD(Nil, 1) + Seq.empty[Row] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala new file mode 100644 index 0000000000000..a667188fa53bd --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.io.IOException +import java.text.NumberFormat +import java.util.Date + +import scala.collection.mutable + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities} +import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} +import org.apache.hadoop.hive.ql.plan.FileSinkDesc +import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapred._ + +import org.apache.spark.sql.Row +import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} + +/** + * Internal helper class that saves an RDD using a Hive OutputFormat. + * It is based on [[SparkHadoopWriter]]. 
+ */ +private[hive] class SparkHiveWriterContainer( + @transient jobConf: JobConf, + fileSinkConf: FileSinkDesc) + extends Logging + with SparkHadoopMapRedUtil + with Serializable { + + private val now = new Date() + protected val conf = new SerializableWritable(jobConf) + + private var jobID = 0 + private var splitID = 0 + private var attemptID = 0 + private var jID: SerializableWritable[JobID] = null + private var taID: SerializableWritable[TaskAttemptID] = null + + @transient private var writer: FileSinkOperator.RecordWriter = null + @transient private lazy val committer = conf.value.getOutputCommitter + @transient private lazy val jobContext = newJobContext(conf.value, jID.value) + @transient private lazy val taskContext = newTaskAttemptContext(conf.value, taID.value) + @transient private lazy val outputFormat = + conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef,Writable]] + + def driverSideSetup() { + setIDs(0, 0, 0) + setConfParams() + committer.setupJob(jobContext) + } + + def executorSideSetup(jobId: Int, splitId: Int, attemptId: Int) { + setIDs(jobId, splitId, attemptId) + setConfParams() + committer.setupTask(taskContext) + initWriters() + } + + protected def getOutputName: String = { + val numberFormat = NumberFormat.getInstance() + numberFormat.setMinimumIntegerDigits(5) + numberFormat.setGroupingUsed(false) + val extension = Utilities.getFileExtension(conf.value, fileSinkConf.getCompressed, outputFormat) + "part-" + numberFormat.format(splitID) + extension + } + + def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = writer + + def close() { + // Seems the boolean value passed into close does not matter. + writer.close(false) + commit() + } + + def commitJob() { + committer.commitJob(jobContext) + } + + protected def initWriters() { + // NOTE this method is executed at the executor side. + // For Hive tables without partitions or with only static partitions, only 1 writer is needed. 
+ writer = HiveFileFormatUtils.getHiveRecordWriter( + conf.value, + fileSinkConf.getTableInfo, + conf.value.getOutputValueClass.asInstanceOf[Class[Writable]], + fileSinkConf, + FileOutputFormat.getTaskOutputPath(conf.value, getOutputName), + Reporter.NULL) + } + + protected def commit() { + if (committer.needsTaskCommit(taskContext)) { + try { + committer.commitTask(taskContext) + logInfo (taID + ": Committed") + } catch { + case e: IOException => + logError("Error committing the output of task: " + taID.value, e) + committer.abortTask(taskContext) + throw e + } + } else { + logInfo("No need to commit output of task: " + taID.value) + } + } + + // ********* Private Functions ********* + + private def setIDs(jobId: Int, splitId: Int, attemptId: Int) { + jobID = jobId + splitID = splitId + attemptID = attemptId + + jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobId)) + taID = new SerializableWritable[TaskAttemptID]( + new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) + } + + private def setConfParams() { + conf.value.set("mapred.job.id", jID.value.toString) + conf.value.set("mapred.tip.id", taID.value.getTaskID.toString) + conf.value.set("mapred.task.id", taID.value.toString) + conf.value.setBoolean("mapred.task.is.map", true) + conf.value.setInt("mapred.task.partition", splitID) + } +} + +private[hive] object SparkHiveWriterContainer { + def createPathFromString(path: String, conf: JobConf): Path = { + if (path == null) { + throw new IllegalArgumentException("Output path is null") + } + val outputPath = new Path(path) + val fs = outputPath.getFileSystem(conf) + if (outputPath == null || fs == null) { + throw new IllegalArgumentException("Incorrectly formatted output path") + } + outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + } +} + +private[spark] class SparkHiveDynamicPartitionWriterContainer( + @transient jobConf: JobConf, + fileSinkConf: FileSinkDesc, + dynamicPartColNames: Array[String]) + extends SparkHiveWriterContainer(jobConf, fileSinkConf) { + + private val defaultPartName = jobConf.get( + ConfVars.DEFAULTPARTITIONNAME.varname, ConfVars.DEFAULTPARTITIONNAME.defaultVal) + + @transient private var writers: mutable.HashMap[String, FileSinkOperator.RecordWriter] = _ + + override protected def initWriters(): Unit = { + // NOTE: This method is executed at the executor side. + // Actual writers are created for each dynamic partition on the fly. 
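// Editorial comment, not in the original diff: getLocalFileWriter below derives a "/col=value"
// suffix from the trailing dynamic-partition columns of each row and lazily creates (and caches)
// one RecordWriter per distinct partition directory.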
+ writers = mutable.HashMap.empty[String, FileSinkOperator.RecordWriter] + } + + override def close(): Unit = { + writers.values.foreach(_.close(false)) + commit() + } + + override def getLocalFileWriter(row: Row): FileSinkOperator.RecordWriter = { + val dynamicPartPath = dynamicPartColNames + .zip(row.takeRight(dynamicPartColNames.length)) + .map { case (col, rawVal) => + val string = String.valueOf(rawVal) + s"/$col=${if (rawVal == null || string.isEmpty) defaultPartName else string}" + } + .mkString + + def newWriter = { + val newFileSinkDesc = new FileSinkDesc( + fileSinkConf.getDirName + dynamicPartPath, + fileSinkConf.getTableInfo, + fileSinkConf.getCompressed) + newFileSinkDesc.setCompressCodec(fileSinkConf.getCompressCodec) + newFileSinkDesc.setCompressType(fileSinkConf.getCompressType) + + val path = { + val outputPath = FileOutputFormat.getOutputPath(conf.value) + assert(outputPath != null, "Undefined job output-path") + val workPath = new Path(outputPath, dynamicPartPath.stripPrefix("/")) + new Path(workPath, getOutputName) + } + + HiveFileFormatUtils.getHiveRecordWriter( + conf.value, + fileSinkConf.getTableInfo, + conf.value.getOutputValueClass.asInstanceOf[Class[Writable]], + newFileSinkDesc, + path, + Reporter.NULL) + } + + writers.getOrElseUpdate(dynamicPartPath, newWriter) + } +} diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-0-be33aaa7253c8f248ff3921cd7dae340 b/sql/hive/src/test/resources/golden/dynamic_partition-0-be33aaa7253c8f248ff3921cd7dae340 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-1-640552dd462707563fd255a713f83b41 b/sql/hive/src/test/resources/golden/dynamic_partition-1-640552dd462707563fd255a713f83b41 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 b/sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/dynamic_partition-2-36456c9d0d2e3ef72ab5ba9ba48e5493 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-3-b7f7fa7ebf666f4fee27e149d8c6961f b/sql/hive/src/test/resources/golden/dynamic_partition-3-b7f7fa7ebf666f4fee27e149d8c6961f new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-4-8bdb71ad8cb3cc3026043def2525de3a b/sql/hive/src/test/resources/golden/dynamic_partition-4-8bdb71ad8cb3cc3026043def2525de3a new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-5-c630dce438f3792e7fb0f523fbbb3e1e b/sql/hive/src/test/resources/golden/dynamic_partition-5-c630dce438f3792e7fb0f523fbbb3e1e new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-6-7abc9ec8a36cdc5e89e955265a7fd7cf b/sql/hive/src/test/resources/golden/dynamic_partition-6-7abc9ec8a36cdc5e89e955265a7fd7cf new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/dynamic_partition-7-be33aaa7253c8f248ff3921cd7dae340 b/sql/hive/src/test/resources/golden/dynamic_partition-7-be33aaa7253c8f248ff3921cd7dae340 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 2da8a6fac3d99..5d743a51b47c5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -19,6 +19,9 @@ package org.apache.spark.sql.hive.execution import scala.util.Try +import org.apache.hadoop.hive.conf.HiveConf.ConfVars + +import org.apache.spark.SparkException import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ @@ -380,7 +383,7 @@ class HiveQuerySuite extends HiveComparisonTest { def isExplanation(result: SchemaRDD) = { val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } - explanation.exists(_ == "== Physical Plan ==") + explanation.contains("== Physical Plan ==") } test("SPARK-1704: Explain commands as a SchemaRDD") { @@ -568,6 +571,91 @@ class HiveQuerySuite extends HiveComparisonTest { case class LogEntry(filename: String, message: String) case class LogFile(name: String) + createQueryTest("dynamic_partition", + """ + |DROP TABLE IF EXISTS dynamic_part_table; + |CREATE TABLE dynamic_part_table(intcol INT) PARTITIONED BY (partcol1 INT, partcol2 INT); + | + |SET hive.exec.dynamic.partition.mode=nonstrict; + | + |INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT 1, 1, 1 FROM src WHERE key=150; + | + |INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT 1, NULL, 1 FROM src WHERE key=150; + | + |INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT 1, 1, NULL FROM src WHERE key=150; + | + |INSERT INTO TABLe dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT 1, NULL, NULL FROM src WHERE key=150; + | + |DROP TABLE IF EXISTS dynamic_part_table; + """.stripMargin) + + test("Dynamic partition folder layout") { + sql("DROP TABLE IF EXISTS dynamic_part_table") + sql("CREATE TABLE dynamic_part_table(intcol INT) PARTITIONED BY (partcol1 INT, partcol2 INT)") + sql("SET hive.exec.dynamic.partition.mode=nonstrict") + + val data = Map( + Seq("1", "1") -> 1, + Seq("1", "NULL") -> 2, + Seq("NULL", "1") -> 3, + Seq("NULL", "NULL") -> 4) + + data.foreach { case (parts, value) => + sql( + s"""INSERT INTO TABLE dynamic_part_table PARTITION(partcol1, partcol2) + |SELECT $value, ${parts.mkString(", ")} FROM src WHERE key=150 + """.stripMargin) + + val partFolder = Seq("partcol1", "partcol2") + .zip(parts) + .map { case (k, v) => + if (v == "NULL") { + s"$k=${ConfVars.DEFAULTPARTITIONNAME.defaultVal}" + } else { + s"$k=$v" + } + } + .mkString("/") + + // Loads partition data to a temporary table to verify contents + val path = s"$warehousePath/dynamic_part_table/$partFolder/part-00000" + + sql("DROP TABLE IF EXISTS dp_verify") + sql("CREATE TABLE dp_verify(intcol INT)") + sql(s"LOAD DATA LOCAL INPATH '$path' INTO TABLE dp_verify") + + assert(sql("SELECT * FROM dp_verify").collect() === Array(Row(value))) + } + } + + test("Partition spec validation") { + sql("DROP TABLE IF EXISTS dp_test") + sql("CREATE TABLE dp_test(key INT, value STRING) PARTITIONED BY (dp INT, sp INT)") + sql("SET hive.exec.dynamic.partition.mode=strict") + + // Should throw when using strict dynamic partition mode without any static partition + intercept[SparkException] { + sql( + """INSERT INTO TABLE dp_test PARTITION(dp) + |SELECT key, value, key % 5 FROM src + """.stripMargin) + } + + sql("SET 
hive.exec.dynamic.partition.mode=nonstrict") + + // Should throw when a static partition appears after a dynamic partition + intercept[SparkException] { + sql( + """INSERT INTO TABLE dp_test PARTITION(dp, sp = 1) + |SELECT key, value, key % 5 FROM src + """.stripMargin) + } + } + test("SPARK-3414 regression: should store analyzed logical plan when registering a temp table") { sparkContext.makeRDD(Seq.empty[LogEntry]).registerTempTable("rawLogs") sparkContext.makeRDD(Seq.empty[LogFile]).registerTempTable("logFiles") @@ -625,27 +713,27 @@ class HiveQuerySuite extends HiveComparisonTest { assert(sql("SET").collect().size == 0) assertResult(Set(testKey -> testVal)) { - collectResults(hql(s"SET $testKey=$testVal")) + collectResults(sql(s"SET $testKey=$testVal")) } assert(hiveconf.get(testKey, "") == testVal) assertResult(Set(testKey -> testVal)) { - collectResults(hql("SET")) + collectResults(sql("SET")) } sql(s"SET ${testKey + testKey}=${testVal + testVal}") assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) assertResult(Set(testKey -> testVal, (testKey + testKey) -> (testVal + testVal))) { - collectResults(hql("SET")) + collectResults(sql("SET")) } // "set key" assertResult(Set(testKey -> testVal)) { - collectResults(hql(s"SET $testKey")) + collectResults(sql(s"SET $testKey")) } assertResult(Set(nonexistentKey -> "")) { - collectResults(hql(s"SET $nonexistentKey")) + collectResults(sql(s"SET $nonexistentKey")) } // Assert that sql() should have the same effects as sql() by repeating the above using sql(). From 51229ff7f4d3517706a1cdc1a2943ede1c605089 Mon Sep 17 00:00:00 2001 From: yingjieMiao Date: Mon, 29 Sep 2014 18:01:27 -0700 Subject: [PATCH 45/50] [graphX] GraphOps: random pick vertex bug When `numVertices > 50`, probability is set to 0. This would cause infinite loop. Author: yingjieMiao Closes #2553 from yingjieMiao/graphx and squashes the following commits: 6adf3c8 [yingjieMiao] [graphX] GraphOps: random pick vertex bug --- graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index 02afaa987d40d..d0dd45dba618e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -254,7 +254,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali * Picks a random vertex from the graph and returns its ID. 
*/ def pickRandomVertex(): VertexId = { - val probability = 50 / graph.numVertices + val probability = 50.0 / graph.numVertices var found = false var retVal: VertexId = null.asInstanceOf[VertexId] while (!found) { From dc30e4504abcda1774f5f09a08bba73d29a2898b Mon Sep 17 00:00:00 2001 From: oded Date: Mon, 29 Sep 2014 18:05:53 -0700 Subject: [PATCH 46/50] Fixed the condition in StronglyConnectedComponents Issue: SPARK-3635 Author: oded Closes #2486 from odedz/master and squashes the following commits: dd7890a [oded] Fixed the condition in StronglyConnectedComponents Issue: SPARK-3635 --- .../apache/spark/graphx/lib/StronglyConnectedComponents.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala index 46da38eeb725a..8dd958033b338 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala @@ -75,7 +75,7 @@ object StronglyConnectedComponents { sccWorkGraph, Long.MaxValue, activeDirection = EdgeDirection.Out)( (vid, myScc, neighborScc) => (math.min(myScc._1, neighborScc), myScc._2), e => { - if (e.srcId < e.dstId) { + if (e.srcAttr._1 < e.dstAttr._1) { Iterator((e.dstId, e.srcAttr._1)) } else { Iterator() From 210404a56197ad347f1e621ed53ef01327fba2bd Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 21:53:21 -0700 Subject: [PATCH 47/50] Minor cleanup of code. Author: Reynold Xin Closes #2581 from rxin/minor-cleanup and squashes the following commits: 736a91b [Reynold Xin] Minor cleanup of code. --- .../apache/spark/scheduler/JobLogger.scala | 17 +----- .../org/apache/spark/util/JsonProtocol.scala | 1 - .../scala/org/apache/spark/util/Utils.scala | 60 +++++++++---------- 3 files changed, 31 insertions(+), 47 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index ceb434feb6ca1..54904bffdf10b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -20,15 +20,12 @@ package org.apache.spark.scheduler import java.io.{File, FileNotFoundException, IOException, PrintWriter} import java.text.SimpleDateFormat import java.util.{Date, Properties} -import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.HashMap import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.executor.{DataReadMethod, TaskMetrics} -import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel +import org.apache.spark.executor.TaskMetrics /** * :: DeveloperApi :: @@ -62,24 +59,16 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener private val dateFormat = new ThreadLocal[SimpleDateFormat]() { override def initialValue() = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") } - private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent] createLogDir() - // The following 5 functions are used only in testing. 
- private[scheduler] def getLogDir = logDir - private[scheduler] def getJobIdToPrintWriter = jobIdToPrintWriter - private[scheduler] def getStageIdToJobId = stageIdToJobId - private[scheduler] def getJobIdToStageIds = jobIdToStageIds - private[scheduler] def getEventQueue = eventQueue - /** Create a folder for log files, the folder's name is the creation time of jobLogger */ protected def createLogDir() { val dir = new File(logDir + "/" + logDirName + "/") if (dir.exists()) { return } - if (dir.mkdirs() == false) { + if (!dir.mkdirs()) { // JobLogger should throw a exception rather than continue to construct this object. throw new IOException("create log directory error:" + logDir + "/" + logDirName + "/") } @@ -261,7 +250,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener protected def recordJobProperties(jobId: Int, properties: Properties) { if (properties != null) { val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "") - jobLogInfo(jobId, description, false) + jobLogInfo(jobId, description, withTime = false) } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 6a48f673c4e78..5b2e7d3a7edb9 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -25,7 +25,6 @@ import scala.collection.Map import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.JsonAST._ -import org.json4s.jackson.JsonMethods._ import org.apache.spark.executor.{DataReadMethod, InputMetrics, ShuffleReadMetrics, diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 10d440828e323..dbe0cfa2b8ff9 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -23,8 +23,6 @@ import java.nio.ByteBuffer import java.util.{Properties, Locale, Random, UUID} import java.util.concurrent.{ThreadFactory, ConcurrentHashMap, Executors, ThreadPoolExecutor} -import org.apache.log4j.PropertyConfigurator - import scala.collection.JavaConversions._ import scala.collection.Map import scala.collection.mutable.ArrayBuffer @@ -37,12 +35,12 @@ import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration +import org.apache.log4j.PropertyConfigurator import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} import org.json4s._ import tachyon.client.{TachyonFile,TachyonFS} import org.apache.spark._ -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.ExecutorUncaughtExceptionHandler import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} @@ -86,7 +84,7 @@ private[spark] object Utils extends Logging { ois.readObject.asInstanceOf[T] } - /** Deserialize a Long value (used for {@link org.apache.spark.api.python.PythonPartitioner}) */ + /** Deserialize a Long value (used for [[org.apache.spark.api.python.PythonPartitioner]]) */ def deserializeLongValue(bytes: Array[Byte]) : Long = { // Note: we assume that we are given a Long value encoded in network (big-endian) byte order var result = bytes(7) & 0xFFL @@ -153,7 +151,7 @@ private[spark] object Utils extends Logging { def classForName(className: String) = Class.forName(className, true, getContextOrSparkClassLoader) /** 
- * Primitive often used when writing {@link java.nio.ByteBuffer} to {@link java.io.DataOutput}. + * Primitive often used when writing [[java.nio.ByteBuffer]] to [[java.io.DataOutput]] */ def writeByteBuffer(bb: ByteBuffer, out: ObjectOutput) = { if (bb.hasArray) { @@ -333,7 +331,7 @@ private[spark] object Utils extends Logging { val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir)) val targetFile = new File(targetDir, filename) val uri = new URI(url) - val fileOverwrite = conf.getBoolean("spark.files.overwrite", false) + val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false) uri.getScheme match { case "http" | "https" | "ftp" => logInfo("Fetching " + url + " to " + tempFile) @@ -355,7 +353,7 @@ private[spark] object Utils extends Logging { uc.connect() val in = uc.getInputStream() val out = new FileOutputStream(tempFile) - Utils.copyStream(in, out, true) + Utils.copyStream(in, out, closeStreams = true) if (targetFile.exists && !Files.equal(tempFile, targetFile)) { if (fileOverwrite) { targetFile.delete() @@ -402,7 +400,7 @@ private[spark] object Utils extends Logging { val fs = getHadoopFileSystem(uri, hadoopConf) val in = fs.open(new Path(uri)) val out = new FileOutputStream(tempFile) - Utils.copyStream(in, out, true) + Utils.copyStream(in, out, closeStreams = true) if (targetFile.exists && !Files.equal(tempFile, targetFile)) { if (fileOverwrite) { targetFile.delete() @@ -666,7 +664,7 @@ private[spark] object Utils extends Logging { */ def deleteRecursively(file: File) { if (file != null) { - if ((file.isDirectory) && !isSymlink(file)) { + if (file.isDirectory() && !isSymlink(file)) { for (child <- listFilesSafely(file)) { deleteRecursively(child) } @@ -701,11 +699,7 @@ private[spark] object Utils extends Logging { new File(file.getParentFile().getCanonicalFile(), file.getName()) } - if (fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile())) { - return false - } else { - return true - } + !fileInCanonicalDir.getCanonicalFile().equals(fileInCanonicalDir.getAbsoluteFile()) } /** @@ -804,7 +798,7 @@ private[spark] object Utils extends Logging { .start() new Thread("read stdout for " + command(0)) { override def run() { - for (line <- Source.fromInputStream(process.getInputStream).getLines) { + for (line <- Source.fromInputStream(process.getInputStream).getLines()) { System.err.println(line) } } @@ -818,8 +812,10 @@ private[spark] object Utils extends Logging { /** * Execute a command and get its output, throwing an exception if it yields a code other than 0. 
*/ - def executeAndGetOutput(command: Seq[String], workingDir: File = new File("."), - extraEnvironment: Map[String, String] = Map.empty): String = { + def executeAndGetOutput( + command: Seq[String], + workingDir: File = new File("."), + extraEnvironment: Map[String, String] = Map.empty): String = { val builder = new ProcessBuilder(command: _*) .directory(workingDir) val environment = builder.environment() @@ -829,7 +825,7 @@ private[spark] object Utils extends Logging { val process = builder.start() new Thread("read stderr for " + command(0)) { override def run() { - for (line <- Source.fromInputStream(process.getErrorStream).getLines) { + for (line <- Source.fromInputStream(process.getErrorStream).getLines()) { System.err.println(line) } } @@ -837,7 +833,7 @@ private[spark] object Utils extends Logging { val output = new StringBuffer val stdoutThread = new Thread("read stdout for " + command(0)) { override def run() { - for (line <- Source.fromInputStream(process.getInputStream).getLines) { + for (line <- Source.fromInputStream(process.getInputStream).getLines()) { output.append(line) } } @@ -846,8 +842,8 @@ private[spark] object Utils extends Logging { val exitCode = process.waitFor() stdoutThread.join() // Wait for it to finish reading output if (exitCode != 0) { - logError(s"Process $command exited with code $exitCode: ${output}") - throw new SparkException("Process " + command + " exited with code " + exitCode) + logError(s"Process $command exited with code $exitCode: $output") + throw new SparkException(s"Process $command exited with code $exitCode") } output.toString } @@ -860,6 +856,7 @@ private[spark] object Utils extends Logging { try { block } catch { + case e: ControlThrowable => throw e case t: Throwable => ExecutorUncaughtExceptionHandler.uncaughtException(t) } } @@ -884,13 +881,12 @@ private[spark] object Utils extends Logging { * @param skipClass Function that is used to exclude non-user-code classes. */ def getCallSite(skipClass: String => Boolean = coreExclusionFunction): CallSite = { - val trace = Thread.currentThread.getStackTrace() - .filterNot { ste:StackTraceElement => - // When running under some profilers, the current stack trace might contain some bogus - // frames. This is intended to ensure that we don't crash in these situations by - // ignoring any frames that we can't examine. - (ste == null || ste.getMethodName == null || ste.getMethodName.contains("getStackTrace")) - } + val trace = Thread.currentThread.getStackTrace().filterNot { ste: StackTraceElement => + // When running under some profilers, the current stack trace might contain some bogus + // frames. This is intended to ensure that we don't crash in these situations by + // ignoring any frames that we can't examine. + ste == null || ste.getMethodName == null || ste.getMethodName.contains("getStackTrace") + } // Keep crawling up the stack trace until we find the first function not inside of the spark // package. We track the last (shallowest) contiguous Spark method. 
This might be an RDD @@ -924,7 +920,7 @@ private[spark] object Utils extends Logging { } val callStackDepth = System.getProperty("spark.callstack.depth", "20").toInt CallSite( - shortForm = "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine), + shortForm = s"$lastSparkMethod at $firstUserFile:$firstUserLine", longForm = callStack.take(callStackDepth).mkString("\n")) } @@ -1027,7 +1023,7 @@ private[spark] object Utils extends Logging { false } - def isSpace(c: Char): Boolean = { + private def isSpace(c: Char): Boolean = { " \t\r\n".indexOf(c) != -1 } @@ -1179,7 +1175,7 @@ private[spark] object Utils extends Logging { } import scala.sys.process._ (linkCmd + src.getAbsolutePath() + " " + dst.getPath() + cmdSuffix) lines_! - ProcessLogger(line => (logInfo(line))) + ProcessLogger(line => logInfo(line)) } @@ -1260,7 +1256,7 @@ private[spark] object Utils extends Logging { val startTime = System.currentTimeMillis while (!terminated) { try { - process.exitValue + process.exitValue() terminated = true } catch { case e: IllegalThreadStateException => From 6b79bfb42580b6bd4c4cd99fb521534a94150693 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 22:56:22 -0700 Subject: [PATCH 48/50] [SPARK-3613] Record only average block size in MapStatus for large stages This changes the way we send MapStatus from executors back to driver for large stages (>2000 tasks). For large stages, we no longer send one byte per block. Instead, we just send the average block size. This makes large jobs (tens of thousands of tasks) much more reliable since the driver no longer sends huge amount of data. Author: Reynold Xin Closes #2470 from rxin/mapstatus and squashes the following commits: 822ff54 [Reynold Xin] Code review feedback. 3b86f56 [Reynold Xin] Added MimaExclude. f89d182 [Reynold Xin] Fixed a bug in MapStatus 6a0401c [Reynold Xin] [SPARK-3613] Record only average block size in MapStatus for large stages. --- .../org/apache/spark/MapOutputTracker.scala | 29 +---- .../apache/spark/scheduler/MapStatus.scala | 119 ++++++++++++++++-- .../shuffle/hash/HashShuffleWriter.scala | 8 +- .../shuffle/sort/SortShuffleWriter.scala | 3 +- .../apache/spark/MapOutputTrackerSuite.scala | 66 ++++------ .../spark/scheduler/DAGSchedulerSuite.scala | 2 +- .../spark/scheduler/MapStatusSuite.scala | 92 ++++++++++++++ .../apache/spark/util/AkkaUtilsSuite.scala | 14 +-- project/MimaExcludes.scala | 5 +- 9 files changed, 240 insertions(+), 98 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index f92189b707fb5..4cb0bd4142435 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -349,7 +349,6 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr } private[spark] object MapOutputTracker { - private val LOG_BASE = 1.1 // Serialize an array of map output locations into an efficient byte format so that we can send // it to reduce tasks. We do this by compressing the serialized bytes using GZIP. 
They will @@ -385,34 +384,8 @@ private[spark] object MapOutputTracker { throw new MetadataFetchFailedException( shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId) } else { - (status.location, decompressSize(status.compressedSizes(reduceId))) + (status.location, status.getSizeForBlock(reduceId)) } } } - - /** - * Compress a size in bytes to 8 bits for efficient reporting of map output sizes. - * We do this by encoding the log base 1.1 of the size as an integer, which can support - * sizes up to 35 GB with at most 10% error. - */ - def compressSize(size: Long): Byte = { - if (size == 0) { - 0 - } else if (size <= 1L) { - 1 - } else { - math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte - } - } - - /** - * Decompress an 8-bit encoded block size, using the reverse operation of compressSize. - */ - def decompressSize(compressedSize: Byte): Long = { - if (compressedSize == 0) { - 0 - } else { - math.pow(LOG_BASE, compressedSize & 0xFF).toLong - } - } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index d3f63ff92ac6f..e25096ea92d70 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -24,22 +24,123 @@ import org.apache.spark.storage.BlockManagerId /** * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the * task ran on as well as the sizes of outputs for each reducer, for passing on to the reduce tasks. - * The map output sizes are compressed using MapOutputTracker.compressSize. */ -private[spark] class MapStatus(var location: BlockManagerId, var compressedSizes: Array[Byte]) - extends Externalizable { +private[spark] sealed trait MapStatus { + /** Location where this task was run. */ + def location: BlockManagerId - def this() = this(null, null) // For deserialization only + /** Estimated size for the reduce block, in bytes. */ + def getSizeForBlock(reduceId: Int): Long +} + + +private[spark] object MapStatus { + + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { + if (uncompressedSizes.length > 2000) { + new HighlyCompressedMapStatus(loc, uncompressedSizes) + } else { + new CompressedMapStatus(loc, uncompressedSizes) + } + } + + private[this] val LOG_BASE = 1.1 + + /** + * Compress a size in bytes to 8 bits for efficient reporting of map output sizes. + * We do this by encoding the log base 1.1 of the size as an integer, which can support + * sizes up to 35 GB with at most 10% error. + */ + def compressSize(size: Long): Byte = { + if (size == 0) { + 0 + } else if (size <= 1L) { + 1 + } else { + math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte + } + } + + /** + * Decompress an 8-bit encoded block size, using the reverse operation of compressSize. + */ + def decompressSize(compressedSize: Byte): Long = { + if (compressedSize == 0) { + 0 + } else { + math.pow(LOG_BASE, compressedSize & 0xFF).toLong + } + } +} + + +/** + * A [[MapStatus]] implementation that tracks the size of each block. Size for each block is + * represented using a single byte. + * + * @param loc location where the task is being executed. + * @param compressedSizes size of the blocks, indexed by reduce partition id. 
+ */ +private[spark] class CompressedMapStatus( + private[this] var loc: BlockManagerId, + private[this] var compressedSizes: Array[Byte]) + extends MapStatus with Externalizable { + + protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only + + def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { + this(loc, uncompressedSizes.map(MapStatus.compressSize)) + } - def writeExternal(out: ObjectOutput) { - location.writeExternal(out) + override def location: BlockManagerId = loc + + override def getSizeForBlock(reduceId: Int): Long = { + MapStatus.decompressSize(compressedSizes(reduceId)) + } + + override def writeExternal(out: ObjectOutput): Unit = { + loc.writeExternal(out) out.writeInt(compressedSizes.length) out.write(compressedSizes) } - def readExternal(in: ObjectInput) { - location = BlockManagerId(in) - compressedSizes = new Array[Byte](in.readInt()) + override def readExternal(in: ObjectInput): Unit = { + loc = BlockManagerId(in) + val len = in.readInt() + compressedSizes = new Array[Byte](len) in.readFully(compressedSizes) } } + + +/** + * A [[MapStatus]] implementation that only stores the average size of the blocks. + * + * @param loc location where the task is being executed. + * @param avgSize average size of all the blocks + */ +private[spark] class HighlyCompressedMapStatus( + private[this] var loc: BlockManagerId, + private[this] var avgSize: Long) + extends MapStatus with Externalizable { + + def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { + this(loc, uncompressedSizes.sum / uncompressedSizes.length) + } + + protected def this() = this(null, 0L) // For deserialization only + + override def location: BlockManagerId = loc + + override def getSizeForBlock(reduceId: Int): Long = avgSize + + override def writeExternal(out: ObjectOutput): Unit = { + loc.writeExternal(out) + out.writeLong(avgSize) + } + + override def readExternal(in: ObjectInput): Unit = { + loc = BlockManagerId(in) + avgSize = in.readLong() + } +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala index 4b9454d75abb7..746ed33b54c00 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala @@ -103,13 +103,11 @@ private[spark] class HashShuffleWriter[K, V]( private def commitWritesAndBuildStatus(): MapStatus = { // Commit the writes. Get the size of each bucket block (total block size). 
- val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter => + val sizes: Array[Long] = shuffle.writers.map { writer: BlockObjectWriter => writer.commitAndClose() - val size = writer.fileSegment().length - MapOutputTracker.compressSize(size) + writer.fileSegment().length } - - new MapStatus(blockManager.blockManagerId, compressedSizes) + MapStatus(blockManager.blockManagerId, sizes) } private def revertWrites(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 89a78d6982ba0..927481b72cf4f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -70,8 +70,7 @@ private[spark] class SortShuffleWriter[K, V, C]( val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile) shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths) - mapStatus = new MapStatus(blockManager.blockManagerId, - partitionLengths.map(MapOutputTracker.compressSize)) + mapStatus = MapStatus(blockManager.blockManagerId, partitionLengths) } /** Close this writer, passing along whether the map completed */ diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 5369169811f81..1fef79ad1001f 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -23,32 +23,13 @@ import akka.actor._ import akka.testkit.TestActorRef import org.scalatest.FunSuite -import org.apache.spark.scheduler.MapStatus +import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.AkkaUtils class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { private val conf = new SparkConf - test("compressSize") { - assert(MapOutputTracker.compressSize(0L) === 0) - assert(MapOutputTracker.compressSize(1L) === 1) - assert(MapOutputTracker.compressSize(2L) === 8) - assert(MapOutputTracker.compressSize(10L) === 25) - assert((MapOutputTracker.compressSize(1000000L) & 0xFF) === 145) - assert((MapOutputTracker.compressSize(1000000000L) & 0xFF) === 218) - // This last size is bigger than we can encode in a byte, so check that we just return 255 - assert((MapOutputTracker.compressSize(1000000000000000000L) & 0xFF) === 255) - } - - test("decompressSize") { - assert(MapOutputTracker.decompressSize(0) === 0) - for (size <- Seq(2L, 10L, 100L, 50000L, 1000000L, 1000000000L)) { - val size2 = MapOutputTracker.decompressSize(MapOutputTracker.compressSize(size)) - assert(size2 >= 0.99 * size && size2 <= 1.11 * size, - "size " + size + " decompressed to " + size2 + ", which is out of range") - } - } test("master start and stop") { val actorSystem = ActorSystem("test") @@ -65,14 +46,12 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) tracker.registerShuffle(10, 2) assert(tracker.containsShuffle(10)) - val compressedSize1000 = MapOutputTracker.compressSize(1000L) - val compressedSize10000 = MapOutputTracker.compressSize(10000L) - val size1000 = MapOutputTracker.decompressSize(compressedSize1000) - val size10000 = MapOutputTracker.decompressSize(compressedSize10000) - 
tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize10000))) - tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000))) + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) + tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), + Array(1000L, 10000L))) + tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), + Array(10000L, 1000L))) val statuses = tracker.getServerStatuses(10, 0) assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000), (BlockManagerId("b", "hostB", 1000), size10000))) @@ -84,11 +63,11 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) tracker.registerShuffle(10, 2) - val compressedSize1000 = MapOutputTracker.compressSize(1000L) - val compressedSize10000 = MapOutputTracker.compressSize(10000L) - tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000), + val compressedSize1000 = MapStatus.compressSize(1000L) + val compressedSize10000 = MapStatus.compressSize(10000L) + tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), Array(compressedSize1000, compressedSize10000))) - tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000), + tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(compressedSize10000, compressedSize1000))) assert(tracker.containsShuffle(10)) assert(tracker.getServerStatuses(10, 0).nonEmpty) @@ -103,11 +82,11 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) tracker.registerShuffle(10, 2) - val compressedSize1000 = MapOutputTracker.compressSize(1000L) - val compressedSize10000 = MapOutputTracker.compressSize(10000L) - tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000), + val compressedSize1000 = MapStatus.compressSize(1000L) + val compressedSize10000 = MapStatus.compressSize(10000L) + tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), Array(compressedSize1000, compressedSize1000, compressedSize1000))) - tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000), + tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(compressedSize10000, compressedSize1000, compressedSize1000))) // As if we had two simultaneous fetch failures @@ -142,10 +121,9 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { slaveTracker.updateEpoch(masterTracker.getEpoch) intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } - val compressedSize1000 = MapOutputTracker.compressSize(1000L) - val size1000 = MapOutputTracker.decompressSize(compressedSize1000) - masterTracker.registerMapOutput(10, 0, new MapStatus( - BlockManagerId("a", "hostA", 1000), Array(compressedSize1000))) + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + masterTracker.registerMapOutput(10, 0, MapStatus( + BlockManagerId("a", "hostA", 1000), Array(1000L))) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) 
assert(slaveTracker.getServerStatuses(10, 0).toSeq === @@ -173,8 +151,8 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { // Frame size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) - masterTracker.registerMapOutput(10, 0, new MapStatus( - BlockManagerId("88", "mph", 1000), Array.fill[Byte](10)(0))) + masterTracker.registerMapOutput(10, 0, MapStatus( + BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) masterActor.receive(GetMapOutputStatuses(10)) } @@ -194,8 +172,8 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { // being sent. masterTracker.registerShuffle(20, 100) (0 until 100).foreach { i => - masterTracker.registerMapOutput(20, i, new MapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Byte](4000000)(0))) + masterTracker.registerMapOutput(20, i, new CompressedMapStatus( + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) } intercept[SparkException] { masterActor.receive(GetMapOutputStatuses(20)) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index aa73469b6acd8..a2e4f712db55b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -740,7 +740,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F } private def makeMapStatus(host: String, reduces: Int): MapStatus = - new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2)) + MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(2)) private def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345) diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala new file mode 100644 index 0000000000000..79e04f046e4c4 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import org.apache.spark.storage.BlockManagerId +import org.scalatest.FunSuite + +import org.apache.spark.SparkConf +import org.apache.spark.serializer.JavaSerializer + + +class MapStatusSuite extends FunSuite { + + test("compressSize") { + assert(MapStatus.compressSize(0L) === 0) + assert(MapStatus.compressSize(1L) === 1) + assert(MapStatus.compressSize(2L) === 8) + assert(MapStatus.compressSize(10L) === 25) + assert((MapStatus.compressSize(1000000L) & 0xFF) === 145) + assert((MapStatus.compressSize(1000000000L) & 0xFF) === 218) + // This last size is bigger than we can encode in a byte, so check that we just return 255 + assert((MapStatus.compressSize(1000000000000000000L) & 0xFF) === 255) + } + + test("decompressSize") { + assert(MapStatus.decompressSize(0) === 0) + for (size <- Seq(2L, 10L, 100L, 50000L, 1000000L, 1000000000L)) { + val size2 = MapStatus.decompressSize(MapStatus.compressSize(size)) + assert(size2 >= 0.99 * size && size2 <= 1.11 * size, + "size " + size + " decompressed to " + size2 + ", which is out of range") + } + } + + test("large tasks should use " + classOf[HighlyCompressedMapStatus].getName) { + val sizes = Array.fill[Long](2001)(150L) + val status = MapStatus(null, sizes) + assert(status.isInstanceOf[HighlyCompressedMapStatus]) + assert(status.getSizeForBlock(10) === 150L) + assert(status.getSizeForBlock(50) === 150L) + assert(status.getSizeForBlock(99) === 150L) + assert(status.getSizeForBlock(2000) === 150L) + } + + test(classOf[HighlyCompressedMapStatus].getName + ": estimated size is within 10%") { + val sizes = Array.tabulate[Long](50) { i => i.toLong } + val loc = BlockManagerId("a", "b", 10) + val status = MapStatus(loc, sizes) + val ser = new JavaSerializer(new SparkConf) + val buf = ser.newInstance().serialize(status) + val status1 = ser.newInstance().deserialize[MapStatus](buf) + assert(status1.location == loc) + for (i <- 0 until sizes.length) { + // make sure the estimated size is within 10% of the input; note that we skip the very small + // sizes because the compression is very lossy there. 
+ val estimate = status1.getSizeForBlock(i) + if (estimate > 100) { + assert(math.abs(estimate - sizes(i)) * 10 <= sizes(i), + s"incorrect estimated size $estimate, original was ${sizes(i)}") + } + } + } + + test(classOf[HighlyCompressedMapStatus].getName + ": estimated size should be the average size") { + val sizes = Array.tabulate[Long](3000) { i => i.toLong } + val avg = sizes.sum / sizes.length + val loc = BlockManagerId("a", "b", 10) + val status = MapStatus(loc, sizes) + val ser = new JavaSerializer(new SparkConf) + val buf = ser.newInstance().serialize(status) + val status1 = ser.newInstance().deserialize[MapStatus](buf) + assert(status1.location == loc) + for (i <- 0 until 3000) { + val estimate = status1.getSizeForBlock(i) + assert(estimate === avg) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 76bf4cfd11267..7bca1711ae226 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -106,10 +106,9 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) - val compressedSize1000 = MapOutputTracker.compressSize(1000L) - val size1000 = MapOutputTracker.decompressSize(compressedSize1000) - masterTracker.registerMapOutput(10, 0, new MapStatus( - BlockManagerId("a", "hostA", 1000), Array(compressedSize1000))) + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + masterTracker.registerMapOutput(10, 0, + MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L))) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) @@ -157,10 +156,9 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext { masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) - val compressedSize1000 = MapOutputTracker.compressSize(1000L) - val size1000 = MapOutputTracker.decompressSize(compressedSize1000) - masterTracker.registerMapOutput(10, 0, new MapStatus( - BlockManagerId("a", "hostA", 1000), Array(compressedSize1000))) + val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) + masterTracker.registerMapOutput(10, 0, MapStatus( + BlockManagerId("a", "hostA", 1000), Array(1000L))) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 1adfaa18c6202..4076ebc6fc8d5 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -45,7 +45,10 @@ object MimaExcludes { ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL1"), ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL2") + "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL2"), + // MapStatus should be private[spark] + ProblemFilters.exclude[IncompatibleTemplateDefProblem]( + "org.apache.spark.scheduler.MapStatus") ) case v if v.startsWith("1.1") => From de700d31778eb68807183cf32be8034abdc0120e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Sep 2014 23:17:53 -0700 Subject: [PATCH 49/50] [SPARK-3709] Executors don't always report broadcast block removal properly back to the driver The problem was that the 2nd argument in RemoveBroadcast is not tellMaster! It is "removeFromDriver". 
Basically when removeFromDriver is not true, we don't report broadcast block removal back to the driver, and then other executors mistakenly think that the executor would still have the block, and try to fetch from it. cc @tdas Author: Reynold Xin Closes #2588 from rxin/debug and squashes the following commits: 6dab2e3 [Reynold Xin] Don't log random messages. f430686 [Reynold Xin] Always report broadcast removal back to master. 2a13f70 [Reynold Xin] iii --- .../apache/spark/network/nio/NioBlockTransferService.scala | 2 +- .../org/apache/spark/storage/BlockManagerSlaveActor.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala index 59958ee894230..b389b9a2022c6 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala @@ -200,6 +200,6 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa val buffer = blockDataManager.getBlockData(blockId).orNull logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs) + " and got buffer " + buffer) - buffer.nioByteBuffer() + if (buffer == null) null else buffer.nioByteBuffer() } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 14ae2f38c5670..8462871e798a5 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -58,9 +58,9 @@ class BlockManagerSlaveActor( SparkEnv.get.shuffleManager.unregisterShuffle(shuffleId) } - case RemoveBroadcast(broadcastId, tellMaster) => + case RemoveBroadcast(broadcastId, _) => doAsync[Int]("removing broadcast " + broadcastId, sender) { - blockManager.removeBroadcast(broadcastId, tellMaster) + blockManager.removeBroadcast(broadcastId, tellMaster = true) } case GetBlockStatus(blockId, _) => From b167a8c7e75d9e816784bd655bce1feb6c447210 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 29 Sep 2014 23:36:10 -0700 Subject: [PATCH 50/50] [SPARK-3734] DriverRunner should not read SPARK_HOME from submitter's environment When using spark-submit in `cluster` mode to submit a job to a Spark Standalone cluster, if the JAVA_HOME environment variable was set on the submitting machine then DriverRunner would attempt to use the submitter's JAVA_HOME to launch the driver process (instead of the worker's JAVA_HOME), causing the driver to fail unless the submitter and worker had the same Java location. This commit fixes this by reading JAVA_HOME from sys.env instead of command.environment. Author: Josh Rosen Closes #2586 from JoshRosen/SPARK-3734 and squashes the following commits: e9513d9 [Josh Rosen] [SPARK-3734] DriverRunner should not read SPARK_HOME from submitter's environment. 
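For illustration only (not part of the patch): a minimal, self-contained Scala sketch of the resolution change described above, using a hypothetical simplified stand-in for the deploy Command class, since only its environment map matters here. The old lookup consulted the environment serialized from the submitter before falling back to the worker's own environment; the fix consults only the worker-local sys.env.

// Hypothetical, simplified stand-in for org.apache.spark.deploy.Command;
// only the environment map is relevant to this illustration.
case class SubmittedCommand(environment: Map[String, String])

object JavaHomeResolutionSketch {
  // Behavior before this patch: a JAVA_HOME shipped from the submitting machine
  // (inside the command's environment) wins over the worker's own JAVA_HOME.
  def runnerBefore(command: SubmittedCommand): String =
    command.environment.get("JAVA_HOME")
      .orElse(sys.env.get("JAVA_HOME"))
      .map(_ + "/bin/java")
      .getOrElse("java")

  // Behavior after this patch: only the worker's local environment is consulted.
  def runnerAfter(): String =
    sys.env.get("JAVA_HOME").map(_ + "/bin/java").getOrElse("java")

  def main(args: Array[String]): Unit = {
    // A JAVA_HOME path that exists only on the submitting machine.
    val submitted = SubmittedCommand(Map("JAVA_HOME" -> "/opt/jdk-on-submitter"))
    println(runnerBefore(submitted)) // "/opt/jdk-on-submitter/bin/java" -- wrong machine's path
    println(runnerAfter())           // worker-local JAVA_HOME, or plain "java"
  }
}

The names SubmittedCommand and JavaHomeResolutionSketch are invented for this example; the actual change is the one-line edit to buildCommandSeq and the removal of getEnv shown in the diff below.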
--- .../scala/org/apache/spark/deploy/worker/CommandUtils.scala | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index 12e98fd40d6c9..2e9be2a180c68 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -30,7 +30,7 @@ import org.apache.spark.util.Utils private[spark] object CommandUtils extends Logging { def buildCommandSeq(command: Command, memory: Int, sparkHome: String): Seq[String] = { - val runner = getEnv("JAVA_HOME", command).map(_ + "/bin/java").getOrElse("java") + val runner = sys.env.get("JAVA_HOME").map(_ + "/bin/java").getOrElse("java") // SPARK-698: do not call the run.cmd script, as process.destroy() // fails to kill a process tree on Windows @@ -38,9 +38,6 @@ object CommandUtils extends Logging { command.arguments } - private def getEnv(key: String, command: Command): Option[String] = - command.environment.get(key).orElse(Option(System.getenv(key))) - /** * Attention: this must always be aligned with the environment variables in the run scripts and * the way the JAVA_OPTS are assembled there.