From e09e18b3123c20e9b9497cf606473da500349d4d Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Sat, 2 Aug 2014 12:11:50 -0700 Subject: [PATCH 01/13] [HOTFIX] Do not throw NPE if spark.test.home is not set `spark.test.home` was introduced in #1734. This is fine for SBT but is failing maven tests. Either way it shouldn't throw an NPE. Author: Andrew Or Closes #1739 from andrewor14/fix-spark-test-home and squashes the following commits: ce2624c [Andrew Or] Do not throw NPE if spark.test.home is not set --- .../scala/org/apache/spark/deploy/worker/Worker.scala | 9 +++++++-- core/src/test/scala/org/apache/spark/DriverSuite.scala | 2 +- .../scala/org/apache/spark/deploy/SparkSubmitSuite.scala | 2 +- .../apache/spark/deploy/worker/ExecutorRunnerTest.scala | 2 +- pom.xml | 8 ++++---- 5 files changed, 14 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index c6ea42fceb659..458d9947bd873 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -71,7 +71,7 @@ private[spark] class Worker( // TTL for app folders/data; after TTL expires it will be cleaned up val APP_DATA_RETENTION_SECS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) - + val testing: Boolean = sys.props.contains("spark.testing") val masterLock: Object = new Object() var master: ActorSelection = null var masterAddress: Address = null @@ -82,7 +82,12 @@ private[spark] class Worker( @volatile var connected = false val workerId = generateWorkerId() val sparkHome = - new File(sys.props.get("spark.test.home").orElse(sys.env.get("SPARK_HOME")).getOrElse(".")) + if (testing) { + assert(sys.props.contains("spark.test.home"), "spark.test.home is not set!") + new File(sys.props("spark.test.home")) + } else { + new File(sys.env.get("SPARK_HOME").getOrElse(".")) + } var workDir: File = null val executors = new HashMap[String, ExecutorRunner] val finishedExecutors = new HashMap[String, ExecutorRunner] diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index e36902ec81e08..a73e1ef0288a5 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -34,7 +34,7 @@ import scala.language.postfixOps class DriverSuite extends FunSuite with Timeouts { test("driver should exit after finishing") { - val sparkHome = sys.props("spark.test.home") + val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) forAll(masters) { (master: String) => diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 8126ef1bb23aa..a5cdcfb5de03b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -295,7 +295,7 @@ class SparkSubmitSuite extends FunSuite with Matchers { // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. def runSparkSubmit(args: Seq[String]): String = { - val sparkHome = sys.props("spark.test.home") + val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) Utils.executeAndGetOutput( Seq("./bin/spark-submit") ++ args, new File(sparkHome), diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index 149a2b3d95b86..39ab53cf0b5b1 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -27,7 +27,7 @@ import org.apache.spark.SparkConf class ExecutorRunnerTest extends FunSuite { test("command includes appId") { def f(s:String) = new File(s) - val sparkHome = sys.props("spark.test.home") + val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) val appDesc = new ApplicationDescription("app name", Some(8), 500, Command("foo", Seq(), Map(), Seq(), Seq(), Seq()), "appUiUrl") val appId = "12345-worker321-9876" diff --git a/pom.xml b/pom.xml index ae97bf03c53a2..99ae4b8b33f94 100644 --- a/pom.xml +++ b/pom.xml @@ -868,10 +868,10 @@ ${project.build.directory}/SparkTestSuite.txt -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m - - ${session.executionRootDirectory} - 1 - + + ${session.executionRootDirectory} + 1 + From 3f67382e7c9c3f6a8f6ce124ab3fcb1a9c1a264f Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Sat, 2 Aug 2014 13:07:17 -0700 Subject: [PATCH 02/13] [SPARK-2478] [mllib] DecisionTree Python API Added experimental Python API for Decision Trees. API: * class DecisionTreeModel ** predict() for single examples and RDDs, taking both feature vectors and LabeledPoints ** numNodes() ** depth() ** __str__() * class DecisionTree ** trainClassifier() ** trainRegressor() ** train() Examples and testing: * Added example testing classification and regression with batch prediction: examples/src/main/python/mllib/tree.py * Have also tested example usage in doc of python/pyspark/mllib/tree.py which tests single-example prediction with dense and sparse vectors Also: Small bug fix in python/pyspark/mllib/_common.py: In _linear_predictor_typecheck, changed check for RDD to use isinstance() instead of type() in order to catch RDD subclasses. CC mengxr manishamde Author: Joseph K. Bradley Closes #1727 from jkbradley/decisiontree-python-new and squashes the following commits: 3744488 [Joseph K. Bradley] Renamed test tree.py to decision_tree_runner.py Small updates based on github review. 6b86a9d [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new affceb9 [Joseph K. Bradley] * Fixed bug in doc tests in pyspark/mllib/util.py caused by change in loadLibSVMFile behavior. (It used to threshold labels at 0 to make them 0/1, but it now leaves them as they are.) * Fixed small bug in loadLibSVMFile: If a data file had no features, then loadLibSVMFile would create a single all-zero feature. 67a29bc [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new cf46ad7 [Joseph K. Bradley] Python DecisionTreeModel * predict(empty RDD) returns an empty RDD instead of an error. * Removed support for calling predict() on LabeledPoint and RDD[LabeledPoint] * predict() does not cache serialized RDD any more. aa29873 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new bf21be4 [Joseph K. Bradley] removed old run() func from DecisionTree fa10ea7 [Joseph K. Bradley] Small style update 7968692 [Joseph K. Bradley] small braces typo fix e34c263 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new 4801b40 [Joseph K. Bradley] Small style update to DecisionTreeSuite db0eab2 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix2' into decisiontree-python-new 6873fa9 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new 225822f [Joseph K. Bradley] Bug: In DecisionTree, the method sequentialBinSearchForOrderedCategoricalFeatureInClassification() indexed bins from 0 to (math.pow(2, featureCategories.toInt - 1) - 1). This upper bound is the bound for unordered categorical features, not ordered ones. The upper bound should be the arity (i.e., max value) of the feature. 93953f1 [Joseph K. Bradley] Likely done with Python API. 6df89a9 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new 4562c08 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new 665ba78 [Joseph K. Bradley] Small updates towards Python DecisionTree API 188cb0d [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new 6622247 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new b8fac57 [Joseph K. Bradley] Finished Python DecisionTree API and example but need to test a bit more. 2b20c61 [Joseph K. Bradley] Small doc and style updates 1b29c13 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new 584449a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new dab0b67 [Joseph K. Bradley] Added documentation for DecisionTree internals 8bb8aa0 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix 978cfcf [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix 6eed482 [Joseph K. Bradley] In DecisionTree: Changed from using procedural syntax for functions returning Unit to explicitly writing Unit return type. 376dca2 [Joseph K. Bradley] Updated meaning of maxDepth by 1 to fit scikit-learn and rpart. * In code, replaced usages of maxDepth <-- maxDepth + 1 * In params, replace settings of maxDepth <-- maxDepth - 1 e06e423 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new bab3f19 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new 59750f8 [Joseph K. Bradley] * Updated Strategy to check numClassesForClassification only if algo=Classification. * Updates based on comments: ** DecisionTreeRunner *** Made dataFormat arg default to libsvm ** Small cleanups ** tree.Node: Made recursive helper methods private, and renamed them. 52e17c5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix f5a036c [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new da50db7 [Joseph K. Bradley] Added one more test to DecisionTreeSuite: stump with 2 continuous variables for binary classification. Caused problems in past, but fixed now. 8e227ea [Joseph K. Bradley] Changed Strategy so it only requires numClassesForClassification >= 2 for classification cd1d933 [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new 8ea8750 [Joseph K. Bradley] Bug fix: Off-by-1 when finding thresholds for splits for continuous features. 8a758db [Joseph K. Bradley] Merge branch 'decisiontree-bugfix' into decisiontree-python-new 5fe44ed [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-python-new 2283df8 [Joseph K. Bradley] 2 bug fixes. 73fbea2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into decisiontree-bugfix 5f920a1 [Joseph K. Bradley] Demonstration of bug before submitting fix: Updated DecisionTreeSuite so that 3 tests fail. Will describe bug in next commit. f825352 [Joseph K. Bradley] Wrote Python API and example for DecisionTree. Also added toString, depth, and numNodes methods to DecisionTreeModel. --- .../main/python/mllib/decision_tree_runner.py | 133 +++++++++++ .../main/python/mllib/logistic_regression.py | 4 +- .../mllib/api/python/PythonMLLibAPI.scala | 78 ++++++ .../mllib/tree/configuration/Strategy.scala | 3 +- .../spark/mllib/tree/DecisionTreeSuite.scala | 3 +- python/pyspark/mllib/_common.py | 33 ++- python/pyspark/mllib/tests.py | 36 +++ python/pyspark/mllib/tree.py | 225 ++++++++++++++++++ python/pyspark/mllib/util.py | 14 +- python/run-tests | 1 + 10 files changed, 509 insertions(+), 21 deletions(-) create mode 100755 examples/src/main/python/mllib/decision_tree_runner.py create mode 100644 python/pyspark/mllib/tree.py diff --git a/examples/src/main/python/mllib/decision_tree_runner.py b/examples/src/main/python/mllib/decision_tree_runner.py new file mode 100755 index 0000000000000..8efadb5223f56 --- /dev/null +++ b/examples/src/main/python/mllib/decision_tree_runner.py @@ -0,0 +1,133 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decision tree classification and regression using MLlib. +""" + +import numpy, os, sys + +from operator import add + +from pyspark import SparkContext +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.tree import DecisionTree +from pyspark.mllib.util import MLUtils + + +def getAccuracy(dtModel, data): + """ + Return accuracy of DecisionTreeModel on the given RDD[LabeledPoint]. + """ + seqOp = (lambda acc, x: acc + (x[0] == x[1])) + predictions = dtModel.predict(data.map(lambda x: x.features)) + truth = data.map(lambda p: p.label) + trainCorrect = predictions.zip(truth).aggregate(0, seqOp, add) + if data.count() == 0: + return 0 + return trainCorrect / (0.0 + data.count()) + + +def getMSE(dtModel, data): + """ + Return mean squared error (MSE) of DecisionTreeModel on the given + RDD[LabeledPoint]. + """ + seqOp = (lambda acc, x: acc + numpy.square(x[0] - x[1])) + predictions = dtModel.predict(data.map(lambda x: x.features)) + truth = data.map(lambda p: p.label) + trainMSE = predictions.zip(truth).aggregate(0, seqOp, add) + if data.count() == 0: + return 0 + return trainMSE / (0.0 + data.count()) + + +def reindexClassLabels(data): + """ + Re-index class labels in a dataset to the range {0,...,numClasses-1}. + If all labels in that range already appear at least once, + then the returned RDD is the same one (without a mapping). + Note: If a label simply does not appear in the data, + the index will not include it. + Be aware of this when reindexing subsampled data. + :param data: RDD of LabeledPoint where labels are integer values + denoting labels for a classification problem. + :return: Pair (reindexedData, origToNewLabels) where + reindexedData is an RDD of LabeledPoint with labels in + the range {0,...,numClasses-1}, and + origToNewLabels is a dictionary mapping original labels + to new labels. + """ + # classCounts: class --> # examples in class + classCounts = data.map(lambda x: x.label).countByValue() + numExamples = sum(classCounts.values()) + sortedClasses = sorted(classCounts.keys()) + numClasses = len(classCounts) + # origToNewLabels: class --> index in 0,...,numClasses-1 + if (numClasses < 2): + print >> sys.stderr, \ + "Dataset for classification should have at least 2 classes." + \ + " The given dataset had only %d classes." % numClasses + exit(1) + origToNewLabels = dict([(sortedClasses[i], i) for i in range(0, numClasses)]) + + print "numClasses = %d" % numClasses + print "Per-class example fractions, counts:" + print "Class\tFrac\tCount" + for c in sortedClasses: + frac = classCounts[c] / (numExamples + 0.0) + print "%g\t%g\t%d" % (c, frac, classCounts[c]) + + if (sortedClasses[0] == 0 and sortedClasses[-1] == numClasses - 1): + return (data, origToNewLabels) + else: + reindexedData = \ + data.map(lambda x: LabeledPoint(origToNewLabels[x.label], x.features)) + return (reindexedData, origToNewLabels) + + +def usage(): + print >> sys.stderr, \ + "Usage: decision_tree_runner [libsvm format data filepath]\n" + \ + " Note: This only supports binary classification." + exit(1) + + +if __name__ == "__main__": + if len(sys.argv) > 2: + usage() + sc = SparkContext(appName="PythonDT") + + # Load data. + dataPath = 'data/mllib/sample_libsvm_data.txt' + if len(sys.argv) == 2: + dataPath = sys.argv[1] + if not os.path.isfile(dataPath): + usage() + points = MLUtils.loadLibSVMFile(sc, dataPath) + + # Re-index class labels if needed. + (reindexedData, origToNewLabels) = reindexClassLabels(points) + + # Train a classifier. + model = DecisionTree.trainClassifier(reindexedData, numClasses=2) + # Print learned tree and stats. + print "Trained DecisionTree for classification:" + print " Model numNodes: %d\n" % model.numNodes() + print " Model depth: %d\n" % model.depth() + print " Training accuracy: %g\n" % getAccuracy(model, reindexedData) + print model diff --git a/examples/src/main/python/mllib/logistic_regression.py b/examples/src/main/python/mllib/logistic_regression.py index 6e0f7a4ee5a81..9d547ff77c984 100755 --- a/examples/src/main/python/mllib/logistic_regression.py +++ b/examples/src/main/python/mllib/logistic_regression.py @@ -30,8 +30,10 @@ from pyspark.mllib.classification import LogisticRegressionWithSGD -# Parse a line of text into an MLlib LabeledPoint object def parsePoint(line): + """ + Parse a line of text into an MLlib LabeledPoint object. + """ values = [float(s) for s in line.split(' ')] if values[0] == -1: # Convert -1 labels to 0 for MLlib values[0] = 0 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 7d912737b8f0b..1d5d3762ed8e9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -19,6 +19,8 @@ package org.apache.spark.mllib.api.python import java.nio.{ByteBuffer, ByteOrder} +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.mllib.classification._ @@ -29,6 +31,11 @@ import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors} import org.apache.spark.mllib.random.{RandomRDDGenerators => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.DecisionTree +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} +import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.stat.Statistics import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.util.MLUtils @@ -472,6 +479,76 @@ class PythonMLLibAPI extends Serializable { ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha) } + /** + * Java stub for Python mllib DecisionTree.train(). + * This stub returns a handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on exit; + * see the Py4J documentation. + * @param dataBytesJRDD Training data + * @param categoricalFeaturesInfoJMap Categorical features info, as Java map + */ + def trainDecisionTreeModel( + dataBytesJRDD: JavaRDD[Array[Byte]], + algoStr: String, + numClasses: Int, + categoricalFeaturesInfoJMap: java.util.Map[Int, Int], + impurityStr: String, + maxDepth: Int, + maxBins: Int): DecisionTreeModel = { + + val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) + + val algo: Algo = algoStr match { + case "classification" => Classification + case "regression" => Regression + case _ => throw new IllegalArgumentException(s"Bad algoStr parameter: $algoStr") + } + val impurity: Impurity = impurityStr match { + case "gini" => Gini + case "entropy" => Entropy + case "variance" => Variance + case _ => throw new IllegalArgumentException(s"Bad impurityStr parameter: $impurityStr") + } + + val strategy = new Strategy( + algo = algo, + impurity = impurity, + maxDepth = maxDepth, + numClassesForClassification = numClasses, + maxBins = maxBins, + categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap) + + DecisionTree.train(data, strategy) + } + + /** + * Predict the label of the given data point. + * This is a Java stub for python DecisionTreeModel.predict() + * + * @param featuresBytes Serialized feature vector for data point + * @return predicted label + */ + def predictDecisionTreeModel( + model: DecisionTreeModel, + featuresBytes: Array[Byte]): Double = { + val features: Vector = deserializeDoubleVector(featuresBytes) + model.predict(features) + } + + /** + * Predict the labels of the given data points. + * This is a Java stub for python DecisionTreeModel.predict() + * + * @param dataJRDD A JavaRDD with serialized feature vectors + * @return JavaRDD of serialized predictions + */ + def predictDecisionTreeModel( + model: DecisionTreeModel, + dataJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = { + val data = dataJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes)) + model.predict(data).map(serializeDouble) + } + /** * Java stub for mllib Statistics.corr(X: RDD[Vector], method: String). * Returns the correlation matrix serialized into a byte array understood by deserializers in @@ -597,4 +674,5 @@ class PythonMLLibAPI extends Serializable { val s = getSeedOrDefault(seed) RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(serializeDoubleVector) } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 5c65b537b6867..fdad4f029aa99 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -56,7 +56,8 @@ class Strategy ( if (algo == Classification) { require(numClassesForClassification >= 2) } - val isMulticlassClassification = numClassesForClassification > 2 + val isMulticlassClassification = + algo == Classification && numClassesForClassification > 2 val isMulticlassWithCategoricalFeatures = isMulticlassClassification && (categoricalFeaturesInfo.size > 0) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index 546a132559326..8665a00f3b356 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -48,7 +48,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { requiredMSE: Double) { val predictions = input.map(x => model.predict(x.features)) val squaredError = predictions.zip(input).map { case (prediction, expected) => - (prediction - expected.label) * (prediction - expected.label) + val err = prediction - expected.label + err * err }.sum val mse = squaredError / input.length assert(mse <= requiredMSE) diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index c6ca6a75df746..9c1565affbdac 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -343,22 +343,35 @@ def _copyto(array, buffer, offset, shape, dtype): temp_array[...] = array -def _get_unmangled_rdd(data, serializer): +def _get_unmangled_rdd(data, serializer, cache=True): + """ + :param cache: If True, the serialized RDD is cached. (default = True) + WARNING: Users should unpersist() this later! + """ dataBytes = data.map(serializer) dataBytes._bypass_serializer = True - dataBytes.cache() # TODO: users should unpersist() this later! + if cache: + dataBytes.cache() return dataBytes -# Map a pickled Python RDD of Python dense or sparse vectors to a Java RDD of -# _serialized_double_vectors -def _get_unmangled_double_vector_rdd(data): - return _get_unmangled_rdd(data, _serialize_double_vector) +def _get_unmangled_double_vector_rdd(data, cache=True): + """ + Map a pickled Python RDD of Python dense or sparse vectors to a Java RDD of + _serialized_double_vectors. + :param cache: If True, the serialized RDD is cached. (default = True) + WARNING: Users should unpersist() this later! + """ + return _get_unmangled_rdd(data, _serialize_double_vector, cache) -# Map a pickled Python RDD of LabeledPoint to a Java RDD of _serialized_labeled_points -def _get_unmangled_labeled_point_rdd(data): - return _get_unmangled_rdd(data, _serialize_labeled_point) +def _get_unmangled_labeled_point_rdd(data, cache=True): + """ + Map a pickled Python RDD of LabeledPoint to a Java RDD of _serialized_labeled_points. + :param cache: If True, the serialized RDD is cached. (default = True) + WARNING: Users should unpersist() this later! + """ + return _get_unmangled_rdd(data, _serialize_labeled_point, cache) # Common functions for dealing with and training linear models @@ -380,7 +393,7 @@ def _linear_predictor_typecheck(x, coeffs): if x.size != coeffs.shape[0]: raise RuntimeError("Got sparse vector of size %d; wanted %d" % ( x.size, coeffs.shape[0])) - elif (type(x) == RDD): + elif isinstance(x, RDD): raise RuntimeError("Bulk predict not yet supported.") else: raise TypeError("Argument of type " + type(x).__name__ + " unsupported") diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 37ccf1d590743..9d1e5be637a9a 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -100,6 +100,7 @@ def test_clustering(self): def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes + from pyspark.mllib.tree import DecisionTree data = [ LabeledPoint(0.0, [1, 0, 0]), LabeledPoint(1.0, [0, 1, 1]), @@ -127,9 +128,19 @@ def test_classification(self): self.assertTrue(nb_model.predict(features[2]) <= 0) self.assertTrue(nb_model.predict(features[3]) > 0) + categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories + dt_model = \ + DecisionTree.trainClassifier(rdd, numClasses=2, + categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + def test_regression(self): from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ RidgeRegressionWithSGD + from pyspark.mllib.tree import DecisionTree data = [ LabeledPoint(-1.0, [0, -1]), LabeledPoint(1.0, [0, 1]), @@ -157,6 +168,14 @@ def test_regression(self): self.assertTrue(rr_model.predict(features[2]) <= 0) self.assertTrue(rr_model.predict(features[3]) > 0) + categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories + dt_model = \ + DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): @@ -229,6 +248,7 @@ def test_clustering(self): def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes + from pyspark.mllib.tree import DecisionTree data = [ LabeledPoint(0.0, self.scipy_matrix(2, {0: 1.0})), LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), @@ -256,9 +276,18 @@ def test_classification(self): self.assertTrue(nb_model.predict(features[2]) <= 0) self.assertTrue(nb_model.predict(features[3]) > 0) + categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories + dt_model = DecisionTree.trainClassifier(rdd, numClasses=2, + categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + def test_regression(self): from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ RidgeRegressionWithSGD + from pyspark.mllib.tree import DecisionTree data = [ LabeledPoint(-1.0, self.scipy_matrix(2, {1: -1.0})), LabeledPoint(1.0, self.scipy_matrix(2, {1: 1.0})), @@ -286,6 +315,13 @@ def test_regression(self): self.assertTrue(rr_model.predict(features[2]) <= 0) self.assertTrue(rr_model.predict(features[3]) > 0) + categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories + dt_model = DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(dt_model.predict(features[0]) <= 0) + self.assertTrue(dt_model.predict(features[1]) > 0) + self.assertTrue(dt_model.predict(features[2]) <= 0) + self.assertTrue(dt_model.predict(features[3]) > 0) + if __name__ == "__main__": if not _have_scipy: diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py new file mode 100644 index 0000000000000..1e0006df75ac6 --- /dev/null +++ b/python/pyspark/mllib/tree.py @@ -0,0 +1,225 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from py4j.java_collections import MapConverter + +from pyspark import SparkContext, RDD +from pyspark.mllib._common import \ + _get_unmangled_rdd, _get_unmangled_double_vector_rdd, _serialize_double_vector, \ + _deserialize_labeled_point, _get_unmangled_labeled_point_rdd, \ + _deserialize_double +from pyspark.mllib.regression import LabeledPoint +from pyspark.serializers import NoOpSerializer + +class DecisionTreeModel(object): + """ + A decision tree model for classification or regression. + + EXPERIMENTAL: This is an experimental API. + It will probably be modified for Spark v1.2. + """ + + def __init__(self, sc, java_model): + """ + :param sc: Spark context + :param java_model: Handle to Java model object + """ + self._sc = sc + self._java_model = java_model + + def __del__(self): + self._sc._gateway.detach(self._java_model) + + def predict(self, x): + """ + Predict the label of one or more examples. + :param x: Data point (feature vector), + or an RDD of data points (feature vectors). + """ + pythonAPI = self._sc._jvm.PythonMLLibAPI() + if isinstance(x, RDD): + # Bulk prediction + if x.count() == 0: + return self._sc.parallelize([]) + dataBytes = _get_unmangled_double_vector_rdd(x, cache=False) + jSerializedPreds = \ + pythonAPI.predictDecisionTreeModel(self._java_model, + dataBytes._jrdd) + serializedPreds = RDD(jSerializedPreds, self._sc, NoOpSerializer()) + return serializedPreds.map(lambda bytes: _deserialize_double(bytearray(bytes))) + else: + # Assume x is a single data point. + x_ = _serialize_double_vector(x) + return pythonAPI.predictDecisionTreeModel(self._java_model, x_) + + def numNodes(self): + return self._java_model.numNodes() + + def depth(self): + return self._java_model.depth() + + def __str__(self): + return self._java_model.toString() + + +class DecisionTree(object): + """ + Learning algorithm for a decision tree model + for classification or regression. + + EXPERIMENTAL: This is an experimental API. + It will probably be modified for Spark v1.2. + + Example usage: + >>> from numpy import array, ndarray + >>> from pyspark.mllib.regression import LabeledPoint + >>> from pyspark.mllib.tree import DecisionTree + >>> from pyspark.mllib.linalg import SparseVector + >>> + >>> data = [ + ... LabeledPoint(0.0, [0.0]), + ... LabeledPoint(1.0, [1.0]), + ... LabeledPoint(1.0, [2.0]), + ... LabeledPoint(1.0, [3.0]) + ... ] + >>> + >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2) + >>> print(model) + DecisionTreeModel classifier + If (feature 0 <= 0.5) + Predict: 0.0 + Else (feature 0 > 0.5) + Predict: 1.0 + + >>> model.predict(array([1.0])) > 0 + True + >>> model.predict(array([0.0])) == 0 + True + >>> sparse_data = [ + ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})), + ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) + ... ] + >>> + >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data)) + >>> model.predict(array([0.0, 1.0])) == 1 + True + >>> model.predict(array([0.0, 0.0])) == 0 + True + >>> model.predict(SparseVector(2, {1: 1.0})) == 1 + True + >>> model.predict(SparseVector(2, {1: 0.0})) == 0 + True + """ + + @staticmethod + def trainClassifier(data, numClasses, categoricalFeaturesInfo={}, + impurity="gini", maxDepth=4, maxBins=100): + """ + Train a DecisionTreeModel for classification. + + :param data: Training data: RDD of LabeledPoint. + Labels are integers {0,1,...,numClasses}. + :param numClasses: Number of classes for classification. + :param categoricalFeaturesInfo: Map from categorical feature index + to number of categories. + Any feature not in this map + is treated as continuous. + :param impurity: Supported values: "entropy" or "gini" + :param maxDepth: Max depth of tree. + E.g., depth 0 means 1 leaf node. + Depth 1 means 1 internal node + 2 leaf nodes. + :param maxBins: Number of bins used for finding splits at each node. + :return: DecisionTreeModel + """ + return DecisionTree.train(data, "classification", numClasses, + categoricalFeaturesInfo, + impurity, maxDepth, maxBins) + + @staticmethod + def trainRegressor(data, categoricalFeaturesInfo={}, + impurity="variance", maxDepth=4, maxBins=100): + """ + Train a DecisionTreeModel for regression. + + :param data: Training data: RDD of LabeledPoint. + Labels are real numbers. + :param categoricalFeaturesInfo: Map from categorical feature index + to number of categories. + Any feature not in this map + is treated as continuous. + :param impurity: Supported values: "variance" + :param maxDepth: Max depth of tree. + E.g., depth 0 means 1 leaf node. + Depth 1 means 1 internal node + 2 leaf nodes. + :param maxBins: Number of bins used for finding splits at each node. + :return: DecisionTreeModel + """ + return DecisionTree.train(data, "regression", 0, + categoricalFeaturesInfo, + impurity, maxDepth, maxBins) + + + @staticmethod + def train(data, algo, numClasses, categoricalFeaturesInfo, + impurity, maxDepth, maxBins=100): + """ + Train a DecisionTreeModel for classification or regression. + + :param data: Training data: RDD of LabeledPoint. + For classification, labels are integers + {0,1,...,numClasses}. + For regression, labels are real numbers. + :param algo: "classification" or "regression" + :param numClasses: Number of classes for classification. + :param categoricalFeaturesInfo: Map from categorical feature index + to number of categories. + Any feature not in this map + is treated as continuous. + :param impurity: For classification: "entropy" or "gini". + For regression: "variance". + :param maxDepth: Max depth of tree. + E.g., depth 0 means 1 leaf node. + Depth 1 means 1 internal node + 2 leaf nodes. + :param maxBins: Number of bins used for finding splits at each node. + :return: DecisionTreeModel + """ + sc = data.context + dataBytes = _get_unmangled_labeled_point_rdd(data) + categoricalFeaturesInfoJMap = \ + MapConverter().convert(categoricalFeaturesInfo, + sc._gateway._gateway_client) + model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel( + dataBytes._jrdd, algo, + numClasses, categoricalFeaturesInfoJMap, + impurity, maxDepth, maxBins) + dataBytes.unpersist() + return DecisionTreeModel(sc, model) + + +def _test(): + import doctest + globs = globals().copy() + globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + if failure_count: + exit(-1) + +if __name__ == "__main__": + _test() diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index d94900cefdb77..639cda6350229 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -16,6 +16,7 @@ # import numpy as np +import warnings from pyspark.mllib.linalg import Vectors, SparseVector from pyspark.mllib.regression import LabeledPoint @@ -29,9 +30,9 @@ class MLUtils: Helper methods to load, save and pre-process data used in MLlib. """ - @deprecated @staticmethod def _parse_libsvm_line(line, multiclass): + warnings.warn("deprecated", DeprecationWarning) return _parse_libsvm_line(line) @staticmethod @@ -67,9 +68,9 @@ def _convert_labeled_point_to_libsvm(p): " but got " % type(v)) return " ".join(items) - @deprecated @staticmethod def loadLibSVMFile(sc, path, multiclass=False, numFeatures=-1, minPartitions=None): + warnings.warn("deprecated", DeprecationWarning) return loadLibSVMFile(sc, path, numFeatures, minPartitions) @staticmethod @@ -106,7 +107,6 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None): >>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0") >>> tempFile.flush() >>> examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect() - >>> multiclass_examples = MLUtils.loadLibSVMFile(sc, tempFile.name).collect() >>> tempFile.close() >>> type(examples[0]) == LabeledPoint True @@ -115,20 +115,18 @@ def loadLibSVMFile(sc, path, numFeatures=-1, minPartitions=None): >>> type(examples[1]) == LabeledPoint True >>> print examples[1] - (0.0,(6,[],[])) + (-1.0,(6,[],[])) >>> type(examples[2]) == LabeledPoint True >>> print examples[2] - (0.0,(6,[1,3,5],[4.0,5.0,6.0])) - >>> multiclass_examples[1].label - -1.0 + (-1.0,(6,[1,3,5],[4.0,5.0,6.0])) """ lines = sc.textFile(path, minPartitions) parsed = lines.map(lambda l: MLUtils._parse_libsvm_line(l)) if numFeatures <= 0: parsed.cache() - numFeatures = parsed.map(lambda x: 0 if x[1].size == 0 else x[1][-1]).reduce(max) + 1 + numFeatures = parsed.map(lambda x: -1 if x[1].size == 0 else x[1][-1]).reduce(max) + 1 return parsed.map(lambda x: LabeledPoint(x[0], Vectors.sparse(numFeatures, x[1], x[2]))) @staticmethod diff --git a/python/run-tests b/python/run-tests index 5049e15ce5f8a..48feba2f5bd63 100755 --- a/python/run-tests +++ b/python/run-tests @@ -71,6 +71,7 @@ run_test "pyspark/mllib/random.py" run_test "pyspark/mllib/recommendation.py" run_test "pyspark/mllib/regression.py" run_test "pyspark/mllib/tests.py" +run_test "pyspark/mllib/util.py" if [[ $FAILED == 0 ]]; then echo -en "\033[32m" # Green From 67bd8e3c217a80c3117a6e3853aa60fe13d08c91 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sat, 2 Aug 2014 13:16:41 -0700 Subject: [PATCH 03/13] [SQL] Set outputPartitioning of BroadcastHashJoin correctly. I think we will not generate the plan triggering this bug at this moment. But, let me explain it... Right now, we are using `left.outputPartitioning` as the `outputPartitioning` of a `BroadcastHashJoin`. We may have a wrong physical plan for cases like... ```sql SELECT l.key, count(*) FROM (SELECT key, count(*) as cnt FROM src GROUP BY key) l // This is buildPlan JOIN r // This is the streamedPlan ON (l.cnt = r.value) GROUP BY l.key ``` Let's say we have a `BroadcastHashJoin` on `l` and `r`. For this case, we will pick `l`'s `outputPartitioning` for the `outputPartitioning`of the `BroadcastHashJoin` on `l` and `r`. Also, because the last `GROUP BY` is using `l.key` as the key, we will not introduce an `Exchange` for this aggregation. However, `r`'s outputPartitioning may not match the required distribution of the last `GROUP BY` and we fail to group data correctly. JIRA is being reindexed. I will create a JIRA ticket once it is back online. Author: Yin Huai Closes #1735 from yhuai/BroadcastHashJoin and squashes the following commits: 96d9cb3 [Yin Huai] Set outputPartitioning correctly. --- .../src/main/scala/org/apache/spark/sql/execution/joins.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index cc138c749949d..51bb61530744c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -405,8 +405,7 @@ case class BroadcastHashJoin( left: SparkPlan, right: SparkPlan) extends BinaryNode with HashJoin { - - override def outputPartitioning: Partitioning = left.outputPartitioning + override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning override def requiredChildDistribution = UnspecifiedDistribution :: UnspecifiedDistribution :: Nil From 91f9504e6086fac05b40545099f9818949c24bca Mon Sep 17 00:00:00 2001 From: Chris Fregly Date: Sat, 2 Aug 2014 13:35:35 -0700 Subject: [PATCH 04/13] [SPARK-1981] Add AWS Kinesis streaming support Author: Chris Fregly Closes #1434 from cfregly/master and squashes the following commits: 4774581 [Chris Fregly] updated docs, renamed retry to retryRandom to be more clear, removed retries around store() method 0393795 [Chris Fregly] moved Kinesis examples out of examples/ and back into extras/kinesis-asl 691a6be [Chris Fregly] fixed tests and formatting, fixed a bug with JavaKinesisWordCount during union of streams 0e1c67b [Chris Fregly] Merge remote-tracking branch 'upstream/master' 74e5c7c [Chris Fregly] updated per TD's feedback. simplified examples, updated docs e33cbeb [Chris Fregly] Merge remote-tracking branch 'upstream/master' bf614e9 [Chris Fregly] per matei's feedback: moved the kinesis examples into the examples/ dir d17ca6d [Chris Fregly] per TD's feedback: updated docs, simplified the KinesisUtils api 912640c [Chris Fregly] changed the foundKinesis class to be a publically-avail class db3eefd [Chris Fregly] Merge remote-tracking branch 'upstream/master' 21de67f [Chris Fregly] Merge remote-tracking branch 'upstream/master' 6c39561 [Chris Fregly] parameterized the versions of the aws java sdk and kinesis client 338997e [Chris Fregly] improve build docs for kinesis 828f8ae [Chris Fregly] more cleanup e7c8978 [Chris Fregly] Merge remote-tracking branch 'upstream/master' cd68c0d [Chris Fregly] fixed typos and backward compatibility d18e680 [Chris Fregly] Merge remote-tracking branch 'upstream/master' b3b0ff1 [Chris Fregly] [SPARK-1981] Add AWS Kinesis streaming support --- bin/run-example | 3 +- bin/run-example2.cmd | 3 +- dev/audit-release/audit_release.py | 4 +- .../src/main/scala/SparkApp.scala | 7 + dev/audit-release/sbt_app_kinesis/build.sbt | 28 ++ .../src/main/scala/SparkApp.scala | 33 +++ dev/create-release/create-release.sh | 4 +- dev/run-tests | 3 + docs/streaming-custom-receivers.md | 4 +- docs/streaming-kinesis.md | 58 ++++ docs/streaming-programming-guide.md | 12 +- examples/pom.xml | 13 + extras/kinesis-asl/pom.xml | 96 ++++++ .../streaming/JavaKinesisWordCountASL.java | 180 ++++++++++++ .../src/main/resources/log4j.properties | 37 +++ .../streaming/KinesisWordCountASL.scala | 251 ++++++++++++++++ .../kinesis/KinesisCheckpointState.scala | 56 ++++ .../streaming/kinesis/KinesisReceiver.scala | 149 ++++++++++ .../kinesis/KinesisRecordProcessor.scala | 212 ++++++++++++++ .../streaming/kinesis/KinesisUtils.scala | 96 ++++++ .../kinesis/JavaKinesisStreamSuite.java | 41 +++ .../src/test/resources/log4j.properties | 26 ++ .../kinesis/KinesisReceiverSuite.scala | 275 ++++++++++++++++++ pom.xml | 10 + project/SparkBuild.scala | 6 +- 25 files changed, 1592 insertions(+), 15 deletions(-) create mode 100644 dev/audit-release/sbt_app_kinesis/build.sbt create mode 100644 dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala create mode 100644 docs/streaming-kinesis.md create mode 100644 extras/kinesis-asl/pom.xml create mode 100644 extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java create mode 100644 extras/kinesis-asl/src/main/resources/log4j.properties create mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala create mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala create mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala create mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala create mode 100644 extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala create mode 100644 extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java create mode 100644 extras/kinesis-asl/src/test/resources/log4j.properties create mode 100644 extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala diff --git a/bin/run-example b/bin/run-example index 942706d733122..68a35702eddd3 100755 --- a/bin/run-example +++ b/bin/run-example @@ -29,7 +29,8 @@ if [ -n "$1" ]; then else echo "Usage: ./bin/run-example [example-args]" 1>&2 echo " - set MASTER=XX to use a specific master" 1>&2 - echo " - can use abbreviated example class name (e.g. SparkPi, mllib.LinearRegression)" 1>&2 + echo " - can use abbreviated example class name relative to com.apache.spark.examples" 1>&2 + echo " (e.g. SparkPi, mllib.LinearRegression, streaming.KinesisWordCountASL)" 1>&2 exit 1 fi diff --git a/bin/run-example2.cmd b/bin/run-example2.cmd index eadedd7fa61ff..b29bf90c64e90 100644 --- a/bin/run-example2.cmd +++ b/bin/run-example2.cmd @@ -32,7 +32,8 @@ rem Test that an argument was given if not "x%1"=="x" goto arg_given echo Usage: run-example ^ [example-args] echo - set MASTER=XX to use a specific master - echo - can use abbreviated example class name (e.g. SparkPi, mllib.LinearRegression) + echo - can use abbreviated example class name relative to com.apache.spark.examples + echo (e.g. SparkPi, mllib.LinearRegression, streaming.KinesisWordCountASL) goto exit :arg_given diff --git a/dev/audit-release/audit_release.py b/dev/audit-release/audit_release.py index 230e900ecd4de..16ea1a71290dc 100755 --- a/dev/audit-release/audit_release.py +++ b/dev/audit-release/audit_release.py @@ -105,7 +105,7 @@ def get_url(url): "spark-core", "spark-bagel", "spark-mllib", "spark-streaming", "spark-repl", "spark-graphx", "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-mqtt", "spark-streaming-twitter", "spark-streaming-zeromq", - "spark-catalyst", "spark-sql", "spark-hive" + "spark-catalyst", "spark-sql", "spark-hive", "spark-streaming-kinesis-asl" ] modules = map(lambda m: "%s_%s" % (m, SCALA_BINARY_VERSION), modules) @@ -136,7 +136,7 @@ def ensure_path_not_present(x): os.chdir(original_dir) # SBT application tests -for app in ["sbt_app_core", "sbt_app_graphx", "sbt_app_streaming", "sbt_app_sql", "sbt_app_hive"]: +for app in ["sbt_app_core", "sbt_app_graphx", "sbt_app_streaming", "sbt_app_sql", "sbt_app_hive", "sbt_app_kinesis"]: os.chdir(app) ret = run_cmd("sbt clean run", exit_on_failure=False) test(ret == 0, "sbt application (%s)" % app) diff --git a/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala index 77bbd167b199a..fc03fec9866a6 100644 --- a/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala +++ b/dev/audit-release/sbt_app_core/src/main/scala/SparkApp.scala @@ -50,5 +50,12 @@ object SimpleApp { println("Ganglia sink was loaded via spark-core") System.exit(-1) } + + // Remove kinesis from default build due to ASL license issue + val foundKinesis = Try(Class.forName("org.apache.spark.streaming.kinesis.KinesisUtils")).isSuccess + if (foundKinesis) { + println("Kinesis was loaded via spark-core") + System.exit(-1) + } } } diff --git a/dev/audit-release/sbt_app_kinesis/build.sbt b/dev/audit-release/sbt_app_kinesis/build.sbt new file mode 100644 index 0000000000000..981bc7957b5ed --- /dev/null +++ b/dev/audit-release/sbt_app_kinesis/build.sbt @@ -0,0 +1,28 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +name := "Kinesis Test" + +version := "1.0" + +scalaVersion := System.getenv.get("SCALA_VERSION") + +libraryDependencies += "org.apache.spark" %% "spark-streaming-kinesis-asl" % System.getenv.get("SPARK_VERSION") + +resolvers ++= Seq( + "Spark Release Repository" at System.getenv.get("SPARK_RELEASE_REPOSITORY"), + "Spray Repository" at "http://repo.spray.cc/") diff --git a/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala b/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala new file mode 100644 index 0000000000000..9f85066501472 --- /dev/null +++ b/dev/audit-release/sbt_app_kinesis/src/main/scala/SparkApp.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main.scala + +import scala.util.Try + +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ + +object SimpleApp { + def main(args: Array[String]) { + val foundKinesis = Try(Class.forName("org.apache.spark.streaming.kinesis.KinesisUtils")).isSuccess + if (!foundKinesis) { + println("Kinesis not loaded via kinesis-asl") + System.exit(-1) + } + } +} diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index af46572e6602b..42473629d4f15 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -53,15 +53,15 @@ if [[ ! "$@" =~ --package-only ]]; then -Dusername=$GIT_USERNAME -Dpassword=$GIT_PASSWORD \ -Dmaven.javadoc.skip=true \ -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ - -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl\ -Dtag=$GIT_TAG -DautoVersionSubmodules=true \ + -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ --batch-mode release:prepare mvn -DskipTests \ -Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \ -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \ -Dmaven.javadoc.skip=true \ - -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl\ + -Pyarn -Phive -Phive-thriftserver -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \ release:perform cd .. diff --git a/dev/run-tests b/dev/run-tests index daa85bc750c07..d401c90f41d7b 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -36,6 +36,9 @@ fi if [ -z "$SBT_MAVEN_PROFILES_ARGS" ]; then export SBT_MAVEN_PROFILES_ARGS="-Pyarn -Phadoop-2.3 -Dhadoop.version=2.3.0" fi + +export SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Pkinesis-asl" + echo "SBT_MAVEN_PROFILES_ARGS=\"$SBT_MAVEN_PROFILES_ARGS\"" # Remove work directory diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md index a2dc3a8961dfc..1e045a3dd0ca9 100644 --- a/docs/streaming-custom-receivers.md +++ b/docs/streaming-custom-receivers.md @@ -4,7 +4,7 @@ title: Spark Streaming Custom Receivers --- Spark Streaming can receive streaming data from any arbitrary data source beyond -the one's for which it has in-built support (that is, beyond Flume, Kafka, files, sockets, etc.). +the one's for which it has in-built support (that is, beyond Flume, Kafka, Kinesis, files, sockets, etc.). This requires the developer to implement a *receiver* that is customized for receiving data from the concerned data source. This guide walks through the process of implementing a custom receiver and using it in a Spark Streaming application. @@ -174,7 +174,7 @@ val words = lines.flatMap(_.split(" ")) ... {% endhighlight %} -The full source code is in the example [CustomReceiver.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/streaming/examples/CustomReceiver.scala). +The full source code is in the example [CustomReceiver.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala).
diff --git a/docs/streaming-kinesis.md b/docs/streaming-kinesis.md new file mode 100644 index 0000000000000..801c905c88df8 --- /dev/null +++ b/docs/streaming-kinesis.md @@ -0,0 +1,58 @@ +--- +layout: global +title: Spark Streaming Kinesis Receiver +--- + +### Kinesis +Build notes: +
  • Spark supports a Kinesis Streaming Receiver which is not included in the default build due to licensing restrictions.
  • +
  • _**Note that by embedding this library you will include [ASL](https://aws.amazon.com/asl/)-licensed code in your Spark package**_.
  • +
  • The Spark Kinesis Streaming Receiver source code, examples, tests, and artifacts live in $SPARK_HOME/extras/kinesis-asl.
  • +
  • To build with Kinesis, you must run the maven or sbt builds with -Pkinesis-asl`.
  • +
  • Applications will need to link to the 'spark-streaming-kinesis-asl` artifact.
  • + +Kinesis examples notes: +
  • To build the Kinesis examples, you must run the maven or sbt builds with -Pkinesis-asl`.
  • +
  • These examples automatically determine the number of local threads and KinesisReceivers to spin up based on the number of shards for the stream.
  • +
  • KinesisWordCountProducerASL will generate random data to put onto the Kinesis stream for testing.
  • +
  • Checkpointing is disabled (no checkpoint dir is set). The examples as written will not recover from a driver failure.
  • + +Deployment and runtime notes: +
  • A single KinesisReceiver can process many shards of a stream.
  • +
  • Each shard of a stream is processed by one or more KinesisReceiver's managed by the Kinesis Client Library (KCL) Worker.
  • +
  • You never need more KinesisReceivers than the number of shards in your stream.
  • +
  • You can horizontally scale the receiving by creating more KinesisReceiver/DStreams (up to the number of shards for a given stream)
  • +
  • The Kinesis libraries must be present on all worker nodes, as they will need access to the Kinesis Client Library.
  • +
  • This code uses the DefaultAWSCredentialsProviderChain and searches for credentials in the following order of precedence:
    + 1) Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY
    + 2) Java System Properties - aws.accessKeyId and aws.secretKey
    + 3) Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs
    + 4) Instance profile credentials - delivered through the Amazon EC2 metadata service
    +
  • +
  • You need to setup a Kinesis stream with 1 or more shards per the following:
    + http://docs.aws.amazon.com/kinesis/latest/dev/step-one-create-stream.html
  • +
  • Valid Kinesis endpoint urls can be found here: Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region
  • +
  • When you first start up the KinesisReceiver, the Kinesis Client Library (KCL) needs ~30s to establish connectivity with the AWS Kinesis service, +retrieve any checkpoint data, and negotiate with other KCL's reading from the same stream.
  • +
  • Be careful when changing the app name. Kinesis maintains a mapping table in DynamoDB based on this app name (http://docs.aws.amazon.com/kinesis/latest/dev/kinesis-record-processor-implementation-app.html#kinesis-record-processor-initialization). +Changing the app name could lead to Kinesis errors as only 1 logical application can process a stream. In order to start fresh, +it's always best to delete the DynamoDB table that matches your app name. This DynamoDB table lives in us-east-1 regardless of the Kinesis endpoint URL.
  • + +Failure recovery notes: +
  • The combination of Spark Streaming and Kinesis creates 3 different checkpoints as follows:
    + 1) RDD data checkpoint (Spark Streaming) - frequency is configurable with DStream.checkpoint(Duration)
    + 2) RDD metadata checkpoint (Spark Streaming) - frequency is every DStream batch
    + 3) Kinesis checkpointing (Kinesis) - frequency is controlled by the developer calling ICheckpointer.checkpoint() directly
    +
  • +
  • Checkpointing too frequently will cause excess load on the AWS checkpoint storage layer and may lead to AWS throttling
  • +
  • Upon startup, a KinesisReceiver will begin processing records with sequence numbers greater than the last checkpoint sequence number recorded per shard.
  • +
  • If no checkpoint info exists, the worker will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) +or from the tip/latest (InitialPostitionInStream.LATEST). This is configurable.
  • +
  • When pulling from the stream tip (InitialPositionInStream.LATEST), only new stream data will be picked up after the KinesisReceiver starts.
  • +
  • InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no KinesisReceivers are running.
  • +
  • In production, you'll want to switch to InitialPositionInStream.TRIM_HORIZON which will read up to 24 hours (Kinesis limit) of previous stream data +depending on the checkpoint frequency.
  • +
  • InitialPositionInStream.TRIM_HORIZON may lead to duplicate processing of records depending on the checkpoint frequency.
  • +
  • Record processing should be idempotent when possible.
  • +
  • Failed or latent KinesisReceivers will be detected and automatically shutdown/load-balanced by the KCL.
  • +
  • If possible, explicitly shutdown the worker if a failure occurs in order to trigger the final checkpoint.
  • diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 7b8b7933434c4..9f331ed50d2a4 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -9,7 +9,7 @@ title: Spark Streaming Programming Guide # Overview Spark Streaming is an extension of the core Spark API that allows enables high-throughput, fault-tolerant stream processing of live data streams. Data can be ingested from many sources -like Kafka, Flume, Twitter, ZeroMQ or plain old TCP sockets and be processed using complex +like Kafka, Flume, Twitter, ZeroMQ, Kinesis or plain old TCP sockets and be processed using complex algorithms expressed with high-level functions like `map`, `reduce`, `join` and `window`. Finally, processed data can be pushed out to filesystems, databases, and live dashboards. In fact, you can apply Spark's in-built @@ -38,7 +38,7 @@ stream of results in batches. Spark Streaming provides a high-level abstraction called *discretized stream* or *DStream*, which represents a continuous stream of data. DStreams can be created either from input data -stream from sources such as Kafka and Flume, or by applying high-level +stream from sources such as Kafka, Flume, and Kinesis, or by applying high-level operations on other DStreams. Internally, a DStream is represented as a sequence of [RDDs](api/scala/index.html#org.apache.spark.rdd.RDD). @@ -313,7 +313,7 @@ To write your own Spark Streaming program, you will have to add the following de artifactId = spark-streaming_{{site.SCALA_BINARY_VERSION}} version = {{site.SPARK_VERSION}} -For ingesting data from sources like Kafka and Flume that are not present in the Spark +For ingesting data from sources like Kafka, Flume, and Kinesis that are not present in the Spark Streaming core API, you will have to add the corresponding artifact `spark-streaming-xyz_{{site.SCALA_BINARY_VERSION}}` to the dependencies. For example, @@ -327,6 +327,7 @@ some of the common ones are as follows. Twitter spark-streaming-twitter_{{site.SCALA_BINARY_VERSION}} ZeroMQ spark-streaming-zeromq_{{site.SCALA_BINARY_VERSION}} MQTT spark-streaming-mqtt_{{site.SCALA_BINARY_VERSION}} + Kinesis
    (built separately) kinesis-asl_{{site.SCALA_BINARY_VERSION}} @@ -442,7 +443,7 @@ see the API documentations of the relevant functions in Scala and [JavaStreamingContext](api/scala/index.html#org.apache.spark.streaming.api.java.JavaStreamingContext) for Java. -Additional functionality for creating DStreams from sources such as Kafka, Flume, and Twitter +Additional functionality for creating DStreams from sources such as Kafka, Flume, Kinesis, and Twitter can be imported by adding the right dependencies as explained in an [earlier](#linking) section. To take the case of Kafka, after adding the artifact `spark-streaming-kafka_{{site.SCALA_BINARY_VERSION}}` to the @@ -467,6 +468,9 @@ For more details on these additional sources, see the corresponding [API documen Furthermore, you can also implement your own custom receiver for your sources. See the [Custom Receiver Guide](streaming-custom-receivers.html). +### Kinesis +[Kinesis](streaming-kinesis.html) + ## Operations There are two kinds of DStream operations - _transformations_ and _output operations_. Similar to RDD transformations, DStream transformations operate on one or more DStreams to create new DStreams diff --git a/examples/pom.xml b/examples/pom.xml index c4ed0f5a6a02b..8c4c128bb484d 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -34,6 +34,19 @@ Spark Project Examples http://spark.apache.org/ + + + kinesis-asl + + + org.apache.spark + spark-streaming-kinesis-asl_${scala.binary.version} + ${project.version} + + + + + org.apache.spark diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml new file mode 100644 index 0000000000000..a54b34235dfb4 --- /dev/null +++ b/extras/kinesis-asl/pom.xml @@ -0,0 +1,96 @@ + + + + 4.0.0 + + org.apache.spark + spark-parent + 1.1.0-SNAPSHOT + ../../pom.xml + + + + org.apache.spark + spark-streaming-kinesis-asl_2.10 + jar + Spark Kinesis Integration + + + kinesis-asl + + + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + + + org.apache.spark + spark-streaming_${scala.binary.version} + ${project.version} + test-jar + test + + + com.amazonaws + amazon-kinesis-client + ${aws.kinesis.client.version} + + + com.amazonaws + aws-java-sdk + ${aws.java.sdk.version} + + + org.scalatest + scalatest_${scala.binary.version} + test + + + org.mockito + mockito-all + test + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.easymock + easymockclassextension + test + + + com.novocode + junit-interface + test + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.scalatest + scalatest-maven-plugin + + + + diff --git a/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java new file mode 100644 index 0000000000000..a8b907b241893 --- /dev/null +++ b/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.examples.streaming; + +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Pattern; + +import org.apache.log4j.Logger; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.apache.spark.streaming.kinesis.KinesisUtils; + +import scala.Tuple2; + +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; +import com.amazonaws.services.kinesis.AmazonKinesisClient; +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; +import com.google.common.collect.Lists; + +/** + * Java-friendly Kinesis Spark Streaming WordCount example + * + * See http://spark.apache.org/docs/latest/streaming-kinesis.html for more details + * on the Kinesis Spark Streaming integration. + * + * This example spins up 1 Kinesis Worker (Spark Streaming Receiver) per shard + * for the given stream. + * It then starts pulling from the last checkpointed sequence number of the given + * and . + * + * Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region + * + * This code uses the DefaultAWSCredentialsProviderChain and searches for credentials + * in the following order of precedence: + * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + * Java System Properties - aws.accessKeyId and aws.secretKey + * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs + * Instance profile credentials - delivered through the Amazon EC2 metadata service + * + * Usage: JavaKinesisWordCountASL + * is the name of the Kinesis stream (ie. mySparkStream) + * is the endpoint of the Kinesis service + * (ie. https://kinesis.us-east-1.amazonaws.com) + * + * Example: + * $ export AWS_ACCESS_KEY_ID= + * $ export AWS_SECRET_KEY= + * $ $SPARK_HOME/bin/run-example \ + * org.apache.spark.examples.streaming.JavaKinesisWordCountASL mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com + * + * There is a companion helper class called KinesisWordCountProducerASL which puts dummy data + * onto the Kinesis stream. + * Usage instructions for KinesisWordCountProducerASL are provided in the class definition. + */ +public final class JavaKinesisWordCountASL { + private static final Pattern WORD_SEPARATOR = Pattern.compile(" "); + private static final Logger logger = Logger.getLogger(JavaKinesisWordCountASL.class); + + /* Make the constructor private to enforce singleton */ + private JavaKinesisWordCountASL() { + } + + public static void main(String[] args) { + /* Check that all required args were passed in. */ + if (args.length < 2) { + System.err.println( + "|Usage: KinesisWordCount \n" + + "| is the name of the Kinesis stream\n" + + "| is the endpoint of the Kinesis service\n" + + "| (e.g. https://kinesis.us-east-1.amazonaws.com)\n"); + System.exit(1); + } + + StreamingExamples.setStreamingLogLevels(); + + /* Populate the appropriate variables from the given args */ + String streamName = args[0]; + String endpointUrl = args[1]; + /* Set the batch interval to a fixed 2000 millis (2 seconds) */ + Duration batchInterval = new Duration(2000); + + /* Create a Kinesis client in order to determine the number of shards for the given stream */ + AmazonKinesisClient kinesisClient = new AmazonKinesisClient( + new DefaultAWSCredentialsProviderChain()); + kinesisClient.setEndpoint(endpointUrl); + + /* Determine the number of shards from the stream */ + int numShards = kinesisClient.describeStream(streamName) + .getStreamDescription().getShards().size(); + + /* In this example, we're going to create 1 Kinesis Worker/Receiver/DStream for each shard */ + int numStreams = numShards; + + /* Must add 1 more thread than the number of receivers or the output won't show properly from the driver */ + int numSparkThreads = numStreams + 1; + + /* Setup the Spark config. */ + SparkConf sparkConfig = new SparkConf().setAppName("KinesisWordCount").setMaster( + "local[" + numSparkThreads + "]"); + + /* Kinesis checkpoint interval. Same as batchInterval for this example. */ + Duration checkpointInterval = batchInterval; + + /* Setup the StreamingContext */ + JavaStreamingContext jssc = new JavaStreamingContext(sparkConfig, batchInterval); + + /* Create the same number of Kinesis DStreams/Receivers as Kinesis stream's shards */ + List> streamsList = new ArrayList>(numStreams); + for (int i = 0; i < numStreams; i++) { + streamsList.add( + KinesisUtils.createStream(jssc, streamName, endpointUrl, checkpointInterval, + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2()) + ); + } + + /* Union all the streams if there is more than 1 stream */ + JavaDStream unionStreams; + if (streamsList.size() > 1) { + unionStreams = jssc.union(streamsList.get(0), streamsList.subList(1, streamsList.size())); + } else { + /* Otherwise, just use the 1 stream */ + unionStreams = streamsList.get(0); + } + + /* + * Split each line of the union'd DStreams into multiple words using flatMap to produce the collection. + * Convert lines of byte[] to multiple Strings by first converting to String, then splitting on WORD_SEPARATOR. + */ + JavaDStream words = unionStreams.flatMap(new FlatMapFunction() { + @Override + public Iterable call(byte[] line) { + return Lists.newArrayList(WORD_SEPARATOR.split(new String(line))); + } + }); + + /* Map each word to a (word, 1) tuple, then reduce/aggregate by word. */ + JavaPairDStream wordCounts = words.mapToPair( + new PairFunction() { + @Override + public Tuple2 call(String s) { + return new Tuple2(s, 1); + } + }).reduceByKey(new Function2() { + @Override + public Integer call(Integer i1, Integer i2) { + return i1 + i2; + } + }); + + /* Print the first 10 wordCounts */ + wordCounts.print(); + + /* Start the streaming context and await termination */ + jssc.start(); + jssc.awaitTermination(); + } +} diff --git a/extras/kinesis-asl/src/main/resources/log4j.properties b/extras/kinesis-asl/src/main/resources/log4j.properties new file mode 100644 index 0000000000000..97348fb5b6123 --- /dev/null +++ b/extras/kinesis-asl/src/main/resources/log4j.properties @@ -0,0 +1,37 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +log4j.rootCategory=WARN, console + +# File appender +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=false +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n + +# Console appender +log4j.appender.console=org.apache.log4j.ConsoleAppender +log4j.appender.console.target=System.out +log4j.appender.console.layout=org.apache.log4j.PatternLayout +log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n + +# Settings to quiet third party logs that are too verbose +log4j.logger.org.eclipse.jetty=WARN +log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO +log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO \ No newline at end of file diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala new file mode 100644 index 0000000000000..d03edf8b30a9f --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.streaming + +import java.nio.ByteBuffer +import scala.util.Random +import org.apache.spark.Logging +import org.apache.spark.SparkConf +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.Milliseconds +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions +import org.apache.spark.streaming.kinesis.KinesisUtils +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain +import com.amazonaws.services.kinesis.AmazonKinesisClient +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import com.amazonaws.services.kinesis.model.PutRecordRequest +import org.apache.log4j.Logger +import org.apache.log4j.Level + +/** + * Kinesis Spark Streaming WordCount example. + * + * See http://spark.apache.org/docs/latest/streaming-kinesis.html for more details on + * the Kinesis Spark Streaming integration. + * + * This example spins up 1 Kinesis Worker (Spark Streaming Receiver) per shard + * for the given stream. + * It then starts pulling from the last checkpointed sequence number of the given + * and . + * + * Valid endpoint urls: http://docs.aws.amazon.com/general/latest/gr/rande.html#ak_region + * + * This code uses the DefaultAWSCredentialsProviderChain and searches for credentials + * in the following order of precedence: + * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + * Java System Properties - aws.accessKeyId and aws.secretKey + * Credential profiles file - default location (~/.aws/credentials) shared by all AWS SDKs + * Instance profile credentials - delivered through the Amazon EC2 metadata service + * + * Usage: KinesisWordCountASL + * is the name of the Kinesis stream (ie. mySparkStream) + * is the endpoint of the Kinesis service + * (ie. https://kinesis.us-east-1.amazonaws.com) + * + * Example: + * $ export AWS_ACCESS_KEY_ID= + * $ export AWS_SECRET_KEY= + * $ $SPARK_HOME/bin/run-example \ + * org.apache.spark.examples.streaming.KinesisWordCountASL mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com + * + * There is a companion helper class below called KinesisWordCountProducerASL which puts + * dummy data onto the Kinesis stream. + * Usage instructions for KinesisWordCountProducerASL are provided in that class definition. + */ +object KinesisWordCountASL extends Logging { + def main(args: Array[String]) { + /* Check that all required args were passed in. */ + if (args.length < 2) { + System.err.println( + """ + |Usage: KinesisWordCount + | is the name of the Kinesis stream + | is the endpoint of the Kinesis service + | (e.g. https://kinesis.us-east-1.amazonaws.com) + """.stripMargin) + System.exit(1) + } + + StreamingExamples.setStreamingLogLevels() + + /* Populate the appropriate variables from the given args */ + val Array(streamName, endpointUrl) = args + + /* Determine the number of shards from the stream */ + val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()) + kinesisClient.setEndpoint(endpointUrl) + val numShards = kinesisClient.describeStream(streamName).getStreamDescription().getShards() + .size() + + /* In this example, we're going to create 1 Kinesis Worker/Receiver/DStream for each shard. */ + val numStreams = numShards + + /* + * numSparkThreads should be 1 more thread than the number of receivers. + * This leaves one thread available for actually processing the data. + */ + val numSparkThreads = numStreams + 1 + + /* Setup the and SparkConfig and StreamingContext */ + /* Spark Streaming batch interval */ + val batchInterval = Milliseconds(2000) + val sparkConfig = new SparkConf().setAppName("KinesisWordCount") + .setMaster(s"local[$numSparkThreads]") + val ssc = new StreamingContext(sparkConfig, batchInterval) + + /* Kinesis checkpoint interval. Same as batchInterval for this example. */ + val kinesisCheckpointInterval = batchInterval + + /* Create the same number of Kinesis DStreams/Receivers as Kinesis stream's shards */ + val kinesisStreams = (0 until numStreams).map { i => + KinesisUtils.createStream(ssc, streamName, endpointUrl, kinesisCheckpointInterval, + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2) + } + + /* Union all the streams */ + val unionStreams = ssc.union(kinesisStreams) + + /* Convert each line of Array[Byte] to String, split into words, and count them */ + val words = unionStreams.flatMap(byteArray => new String(byteArray) + .split(" ")) + + /* Map each word to a (word, 1) tuple so we can reduce/aggregate by key. */ + val wordCounts = words.map(word => (word, 1)).reduceByKey(_ + _) + + /* Print the first 10 wordCounts */ + wordCounts.print() + + /* Start the streaming context and await termination */ + ssc.start() + ssc.awaitTermination() + } +} + +/** + * Usage: KinesisWordCountProducerASL + * + * is the name of the Kinesis stream (ie. mySparkStream) + * is the endpoint of the Kinesis service + * (ie. https://kinesis.us-east-1.amazonaws.com) + * is the rate of records per second to put onto the stream + * is the rate of records per second to put onto the stream + * + * Example: + * $ export AWS_ACCESS_KEY_ID= + * $ export AWS_SECRET_KEY= + * $ $SPARK_HOME/bin/run-example \ + * org.apache.spark.examples.streaming.KinesisWordCountProducerASL mySparkStream \ + * https://kinesis.us-east-1.amazonaws.com 10 5 + */ +object KinesisWordCountProducerASL { + def main(args: Array[String]) { + if (args.length < 4) { + System.err.println("Usage: KinesisWordCountProducerASL " + + " ") + System.exit(1) + } + + StreamingExamples.setStreamingLogLevels() + + /* Populate the appropriate variables from the given args */ + val Array(stream, endpoint, recordsPerSecond, wordsPerRecord) = args + + /* Generate the records and return the totals */ + val totals = generate(stream, endpoint, recordsPerSecond.toInt, wordsPerRecord.toInt) + + /* Print the array of (index, total) tuples */ + println("Totals") + totals.foreach(total => println(total.toString())) + } + + def generate(stream: String, + endpoint: String, + recordsPerSecond: Int, + wordsPerRecord: Int): Seq[(Int, Int)] = { + + val MaxRandomInts = 10 + + /* Create the Kinesis client */ + val kinesisClient = new AmazonKinesisClient(new DefaultAWSCredentialsProviderChain()) + kinesisClient.setEndpoint(endpoint) + + println(s"Putting records onto stream $stream and endpoint $endpoint at a rate of" + + s" $recordsPerSecond records per second and $wordsPerRecord words per record"); + + val totals = new Array[Int](MaxRandomInts) + /* Put String records onto the stream per the given recordPerSec and wordsPerRecord */ + for (i <- 1 to 5) { + + /* Generate recordsPerSec records to put onto the stream */ + val records = (1 to recordsPerSecond.toInt).map { recordNum => + /* + * Randomly generate each wordsPerRec words between 0 (inclusive) + * and MAX_RANDOM_INTS (exclusive) + */ + val data = (1 to wordsPerRecord.toInt).map(x => { + /* Generate the random int */ + val randomInt = Random.nextInt(MaxRandomInts) + + /* Keep track of the totals */ + totals(randomInt) += 1 + + randomInt.toString() + }).mkString(" ") + + /* Create a partitionKey based on recordNum */ + val partitionKey = s"partitionKey-$recordNum" + + /* Create a PutRecordRequest with an Array[Byte] version of the data */ + val putRecordRequest = new PutRecordRequest().withStreamName(stream) + .withPartitionKey(partitionKey) + .withData(ByteBuffer.wrap(data.getBytes())); + + /* Put the record onto the stream and capture the PutRecordResult */ + val putRecordResult = kinesisClient.putRecord(putRecordRequest); + } + + /* Sleep for a second */ + Thread.sleep(1000) + println("Sent " + recordsPerSecond + " records") + } + + /* Convert the totals to (index, total) tuple */ + (0 to (MaxRandomInts - 1)).zip(totals) + } +} + +/** + * Utility functions for Spark Streaming examples. + * This has been lifted from the examples/ project to remove the circular dependency. + */ +object StreamingExamples extends Logging { + + /** Set reasonable logging levels for streaming if the user has not configured log4j. */ + def setStreamingLogLevels() { + val log4jInitialized = Logger.getRootLogger.getAllAppenders.hasMoreElements + if (!log4jInitialized) { + // We first log something to initialize Spark's default logging, then we override the + // logging level. + logInfo("Setting log level to [WARN] for streaming example." + + " To override add a custom log4j.properties to the classpath.") + Logger.getRootLogger.setLevel(Level.WARN) + } + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala new file mode 100644 index 0000000000000..0b80b611cdce7 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import org.apache.spark.Logging +import org.apache.spark.streaming.Duration +import org.apache.spark.streaming.util.Clock +import org.apache.spark.streaming.util.ManualClock +import org.apache.spark.streaming.util.SystemClock + +/** + * This is a helper class for managing checkpoint clocks. + * + * @param checkpointInterval + * @param currentClock. Default to current SystemClock if none is passed in (mocking purposes) + */ +private[kinesis] class KinesisCheckpointState( + checkpointInterval: Duration, + currentClock: Clock = new SystemClock()) + extends Logging { + + /* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */ + val checkpointClock = new ManualClock() + checkpointClock.setTime(currentClock.currentTime() + checkpointInterval.milliseconds) + + /** + * Check if it's time to checkpoint based on the current time and the derived time + * for the next checkpoint + * + * @return true if it's time to checkpoint + */ + def shouldCheckpoint(): Boolean = { + new SystemClock().currentTime() > checkpointClock.currentTime() + } + + /** + * Advance the checkpoint clock by the checkpoint interval. + */ + def advanceCheckpoint() = { + checkpointClock.addToTime(checkpointInterval.milliseconds) + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala new file mode 100644 index 0000000000000..1bd1f324298e7 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.net.InetAddress +import java.util.UUID + +import org.apache.spark.Logging +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.Duration +import org.apache.spark.streaming.receiver.Receiver + +import com.amazonaws.auth.AWSCredentialsProvider +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessor +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker + +/** + * Custom AWS Kinesis-specific implementation of Spark Streaming's Receiver. + * This implementation relies on the Kinesis Client Library (KCL) Worker as described here: + * https://github.com/awslabs/amazon-kinesis-client + * This is a custom receiver used with StreamingContext.receiverStream(Receiver) + * as described here: + * http://spark.apache.org/docs/latest/streaming-custom-receivers.html + * Instances of this class will get shipped to the Spark Streaming Workers + * to run within a Spark Executor. + * + * @param appName Kinesis application name. Kinesis Apps are mapped to Kinesis Streams + * by the Kinesis Client Library. If you change the App name or Stream name, + * the KCL will throw errors. This usually requires deleting the backing + * DynamoDB table with the same name this Kinesis application. + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects + * + * @return ReceiverInputDStream[Array[Byte]] + */ +private[kinesis] class KinesisReceiver( + appName: String, + streamName: String, + endpointUrl: String, + checkpointInterval: Duration, + initialPositionInStream: InitialPositionInStream, + storageLevel: StorageLevel) + extends Receiver[Array[Byte]](storageLevel) with Logging { receiver => + + /* + * The following vars are built in the onStart() method which executes in the Spark Worker after + * this code is serialized and shipped remotely. + */ + + /* + * workerId should be based on the ip address of the actual Spark Worker where this code runs + * (not the Driver's ip address.) + */ + var workerId: String = null + + /* + * This impl uses the DefaultAWSCredentialsProviderChain and searches for credentials + * in the following order of precedence: + * Environment Variables - AWS_ACCESS_KEY_ID and AWS_SECRET_KEY + * Java System Properties - aws.accessKeyId and aws.secretKey + * Credential profiles file at the default location (~/.aws/credentials) shared by all + * AWS SDKs and the AWS CLI + * Instance profile credentials delivered through the Amazon EC2 metadata service + */ + var credentialsProvider: AWSCredentialsProvider = null + + /* KCL config instance. */ + var kinesisClientLibConfiguration: KinesisClientLibConfiguration = null + + /* + * RecordProcessorFactory creates impls of IRecordProcessor. + * IRecordProcessor adapts the KCL to our Spark KinesisReceiver via the + * IRecordProcessor.processRecords() method. + * We're using our custom KinesisRecordProcessor in this case. + */ + var recordProcessorFactory: IRecordProcessorFactory = null + + /* + * Create a Kinesis Worker. + * This is the core client abstraction from the Kinesis Client Library (KCL). + * We pass the RecordProcessorFactory from above as well as the KCL config instance. + * A Kinesis Worker can process 1..* shards from the given stream - each with its + * own RecordProcessor. + */ + var worker: Worker = null + + /** + * This is called when the KinesisReceiver starts and must be non-blocking. + * The KCL creates and manages the receiving/processing thread pool through the Worker.run() + * method. + */ + override def onStart() { + workerId = InetAddress.getLocalHost.getHostAddress() + ":" + UUID.randomUUID() + credentialsProvider = new DefaultAWSCredentialsProviderChain() + kinesisClientLibConfiguration = new KinesisClientLibConfiguration(appName, streamName, + credentialsProvider, workerId).withKinesisEndpoint(endpointUrl) + .withInitialPositionInStream(initialPositionInStream).withTaskBackoffTimeMillis(500) + recordProcessorFactory = new IRecordProcessorFactory { + override def createProcessor: IRecordProcessor = new KinesisRecordProcessor(receiver, + workerId, new KinesisCheckpointState(checkpointInterval)) + } + worker = new Worker(recordProcessorFactory, kinesisClientLibConfiguration) + worker.run() + logInfo(s"Started receiver with workerId $workerId") + } + + /** + * This is called when the KinesisReceiver stops. + * The KCL worker.shutdown() method stops the receiving/processing threads. + * The KCL will do its best to drain and checkpoint any in-flight records upon shutdown. + */ + override def onStop() { + worker.shutdown() + logInfo(s"Shut down receiver with workerId $workerId") + workerId = null + credentialsProvider = null + kinesisClientLibConfiguration = null + recordProcessorFactory = null + worker = null + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala new file mode 100644 index 0000000000000..8ecc2d90160b1 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.util.List + +import scala.collection.JavaConversions.asScalaBuffer +import scala.util.Random + +import org.apache.spark.Logging + +import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.KinesisClientLibDependencyException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.ShutdownException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.ThrottlingException +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessor +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer +import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason +import com.amazonaws.services.kinesis.model.Record + +/** + * Kinesis-specific implementation of the Kinesis Client Library (KCL) IRecordProcessor. + * This implementation operates on the Array[Byte] from the KinesisReceiver. + * The Kinesis Worker creates an instance of this KinesisRecordProcessor upon startup. + * + * @param receiver Kinesis receiver + * @param workerId for logging purposes + * @param checkpointState represents the checkpoint state including the next checkpoint time. + * It's injected here for mocking purposes. + */ +private[kinesis] class KinesisRecordProcessor( + receiver: KinesisReceiver, + workerId: String, + checkpointState: KinesisCheckpointState) extends IRecordProcessor with Logging { + + /* shardId to be populated during initialize() */ + var shardId: String = _ + + /** + * The Kinesis Client Library calls this method during IRecordProcessor initialization. + * + * @param shardId assigned by the KCL to this particular RecordProcessor. + */ + override def initialize(shardId: String) { + logInfo(s"Initialize: Initializing workerId $workerId with shardId $shardId") + this.shardId = shardId + } + + /** + * This method is called by the KCL when a batch of records is pulled from the Kinesis stream. + * This is the record-processing bridge between the KCL's IRecordProcessor.processRecords() + * and Spark Streaming's Receiver.store(). + * + * @param batch list of records from the Kinesis stream shard + * @param checkpointer used to update Kinesis when this batch has been processed/stored + * in the DStream + */ + override def processRecords(batch: List[Record], checkpointer: IRecordProcessorCheckpointer) { + if (!receiver.isStopped()) { + try { + /* + * Note: If we try to store the raw ByteBuffer from record.getData(), the Spark Streaming + * Receiver.store(ByteBuffer) attempts to deserialize the ByteBuffer using the + * internally-configured Spark serializer (kryo, etc). + * This is not desirable, so we instead store a raw Array[Byte] and decouple + * ourselves from Spark's internal serialization strategy. + */ + batch.foreach(record => receiver.store(record.getData().array())) + + logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") + + /* + * Checkpoint the sequence number of the last record successfully processed/stored + * in the batch. + * In this implementation, we're checkpointing after the given checkpointIntervalMillis. + * Note that this logic requires that processRecords() be called AND that it's time to + * checkpoint. I point this out because there is no background thread running the + * checkpointer. Checkpointing is tested and trigger only when a new batch comes in. + * If the worker is shutdown cleanly, checkpoint will happen (see shutdown() below). + * However, if the worker dies unexpectedly, a checkpoint may not happen. + * This could lead to records being processed more than once. + */ + if (checkpointState.shouldCheckpoint()) { + /* Perform the checkpoint */ + KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100) + + /* Update the next checkpoint time */ + checkpointState.advanceCheckpoint() + + logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint of ${batch.size}" + + s" records for shardId $shardId") + logDebug(s"Checkpoint: Next checkpoint is at " + + s" ${checkpointState.checkpointClock.currentTime()} for shardId $shardId") + } + } catch { + case e: Throwable => { + /* + * If there is a failure within the batch, the batch will not be checkpointed. + * This will potentially cause records since the last checkpoint to be processed + * more than once. + */ + logError(s"Exception: WorkerId $workerId encountered and exception while storing " + + " or checkpointing a batch for workerId $workerId and shardId $shardId.", e) + + /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor.*/ + throw e + } + } + } else { + /* RecordProcessor has been stopped. */ + logInfo(s"Stopped: The Spark KinesisReceiver has stopped for workerId $workerId" + + s" and shardId $shardId. No more records will be processed.") + } + } + + /** + * Kinesis Client Library is shutting down this Worker for 1 of 2 reasons: + * 1) the stream is resharding by splitting or merging adjacent shards + * (ShutdownReason.TERMINATE) + * 2) the failed or latent Worker has stopped sending heartbeats for whatever reason + * (ShutdownReason.ZOMBIE) + * + * @param checkpointer used to perform a Kinesis checkpoint for ShutdownReason.TERMINATE + * @param reason for shutdown (ShutdownReason.TERMINATE or ShutdownReason.ZOMBIE) + */ + override def shutdown(checkpointer: IRecordProcessorCheckpointer, reason: ShutdownReason) { + logInfo(s"Shutdown: Shutting down workerId $workerId with reason $reason") + reason match { + /* + * TERMINATE Use Case. Checkpoint. + * Checkpoint to indicate that all records from the shard have been drained and processed. + * It's now OK to read from the new shards that resulted from a resharding event. + */ + case ShutdownReason.TERMINATE => + KinesisRecordProcessor.retryRandom(checkpointer.checkpoint(), 4, 100) + + /* + * ZOMBIE Use Case. NoOp. + * No checkpoint because other workers may have taken over and already started processing + * the same records. + * This may lead to records being processed more than once. + */ + case ShutdownReason.ZOMBIE => + + /* Unknown reason. NoOp */ + case _ => + } + } +} + +private[kinesis] object KinesisRecordProcessor extends Logging { + /** + * Retry the given amount of times with a random backoff time (millis) less than the + * given maxBackOffMillis + * + * @param expression expression to evalute + * @param numRetriesLeft number of retries left + * @param maxBackOffMillis: max millis between retries + * + * @return evaluation of the given expression + * @throws Unretryable exception, unexpected exception, + * or any exception that persists after numRetriesLeft reaches 0 + */ + @annotation.tailrec + def retryRandom[T](expression: => T, numRetriesLeft: Int, maxBackOffMillis: Int): T = { + util.Try { expression } match { + /* If the function succeeded, evaluate to x. */ + case util.Success(x) => x + /* If the function failed, either retry or throw the exception */ + case util.Failure(e) => e match { + /* Retry: Throttling or other Retryable exception has occurred */ + case _: ThrottlingException | _: KinesisClientLibDependencyException if numRetriesLeft > 1 + => { + val backOffMillis = Random.nextInt(maxBackOffMillis) + Thread.sleep(backOffMillis) + logError(s"Retryable Exception: Random backOffMillis=${backOffMillis}", e) + retryRandom(expression, numRetriesLeft - 1, maxBackOffMillis) + } + /* Throw: Shutdown has been requested by the Kinesis Client Library.*/ + case _: ShutdownException => { + logError(s"ShutdownException: Caught shutdown exception, skipping checkpoint.", e) + throw e + } + /* Throw: Non-retryable exception has occurred with the Kinesis Client Library */ + case _: InvalidStateException => { + logError(s"InvalidStateException: Cannot save checkpoint to the DynamoDB table used" + + s" by the Amazon Kinesis Client Library. Table likely doesn't exist.", e) + throw e + } + /* Throw: Unexpected exception has occurred */ + case _ => { + logError(s"Unexpected, non-retryable exception.", e) + throw e + } + } + } + } +} diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala new file mode 100644 index 0000000000000..713cac0e293c0 --- /dev/null +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisUtils.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import org.apache.spark.annotation.Experimental +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.Duration +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.api.java.JavaReceiverInputDStream +import org.apache.spark.streaming.api.java.JavaStreamingContext +import org.apache.spark.streaming.dstream.ReceiverInputDStream + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream + + +/** + * Helper class to create Amazon Kinesis Input Stream + * :: Experimental :: + */ +@Experimental +object KinesisUtils { + /** + * Create an InputDStream that pulls messages from a Kinesis stream. + * + * @param ssc StreamingContext object + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects + * + * @return ReceiverInputDStream[Array[Byte]] + */ + def createStream( + ssc: StreamingContext, + streamName: String, + endpointUrl: String, + checkpointInterval: Duration, + initialPositionInStream: InitialPositionInStream, + storageLevel: StorageLevel): ReceiverInputDStream[Array[Byte]] = { + ssc.receiverStream(new KinesisReceiver(ssc.sc.appName, streamName, endpointUrl, + checkpointInterval, initialPositionInStream, storageLevel)) + } + + /** + * Create a Java-friendly InputDStream that pulls messages from a Kinesis stream. + * + * @param jssc Java StreamingContext object + * @param ssc StreamingContext object + * @param streamName Kinesis stream name + * @param endpointUrl Url of Kinesis service (e.g., https://kinesis.us-east-1.amazonaws.com) + * @param checkpointInterval Checkpoint interval for Kinesis checkpointing. + * See the Kinesis Spark Streaming documentation for more + * details on the different types of checkpoints. + * @param initialPositionInStream In the absence of Kinesis checkpoint info, this is the + * worker's initial starting position in the stream. + * The values are either the beginning of the stream + * per Kinesis' limit of 24 hours + * (InitialPositionInStream.TRIM_HORIZON) or + * the tip of the stream (InitialPositionInStream.LATEST). + * @param storageLevel Storage level to use for storing the received objects + * + * @return JavaReceiverInputDStream[Array[Byte]] + */ + def createStream( + jssc: JavaStreamingContext, + streamName: String, + endpointUrl: String, + checkpointInterval: Duration, + initialPositionInStream: InitialPositionInStream, + storageLevel: StorageLevel): JavaReceiverInputDStream[Array[Byte]] = { + jssc.receiverStream(new KinesisReceiver(jssc.ssc.sc.appName, streamName, + endpointUrl, checkpointInterval, initialPositionInStream, storageLevel)) + } +} diff --git a/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java b/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java new file mode 100644 index 0000000000000..87954a31f60ce --- /dev/null +++ b/extras/kinesis-asl/src/test/java/org/apache/spark/streaming/kinesis/JavaKinesisStreamSuite.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.kinesis; + +import org.apache.spark.storage.StorageLevel; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.LocalJavaStreamingContext; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.junit.Test; + +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; + +/** + * Demonstrate the use of the KinesisUtils Java API + */ +public class JavaKinesisStreamSuite extends LocalJavaStreamingContext { + @Test + public void testKinesisStream() { + // Tests the API, does not actually test data receiving + JavaDStream kinesisStream = KinesisUtils.createStream(ssc, "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", new Duration(2000), + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2()); + + ssc.stop(); + } +} diff --git a/extras/kinesis-asl/src/test/resources/log4j.properties b/extras/kinesis-asl/src/test/resources/log4j.properties new file mode 100644 index 0000000000000..e01e049595475 --- /dev/null +++ b/extras/kinesis-asl/src/test/resources/log4j.properties @@ -0,0 +1,26 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +log4j.rootCategory=INFO, file +# log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file=org.apache.log4j.FileAppender +log4j.appender.file.append=false +log4j.appender.file.file=target/unit-tests.log +log4j.appender.file.layout=org.apache.log4j.PatternLayout +log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n + +# Ignore messages below warning level from Jetty, because it's a bit verbose +log4j.logger.org.eclipse.jetty=WARN diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala new file mode 100644 index 0000000000000..41dbd64c2b1fa --- /dev/null +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala @@ -0,0 +1,275 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.streaming.kinesis + +import java.nio.ByteBuffer + +import scala.collection.JavaConversions.seqAsJavaList + +import org.apache.spark.annotation.Experimental +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.Milliseconds +import org.apache.spark.streaming.Seconds +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.TestSuiteBase +import org.apache.spark.streaming.util.Clock +import org.apache.spark.streaming.util.ManualClock +import org.scalatest.BeforeAndAfter +import org.scalatest.Matchers +import org.scalatest.mock.EasyMockSugar + +import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.KinesisClientLibDependencyException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.ShutdownException +import com.amazonaws.services.kinesis.clientlibrary.exceptions.ThrottlingException +import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer +import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream +import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason +import com.amazonaws.services.kinesis.model.Record + +/** + * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor + */ +class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAfter + with EasyMockSugar { + + val app = "TestKinesisReceiver" + val stream = "mySparkStream" + val endpoint = "endpoint-url" + val workerId = "dummyWorkerId" + val shardId = "dummyShardId" + + val record1 = new Record() + record1.setData(ByteBuffer.wrap("Spark In Action".getBytes())) + val record2 = new Record() + record2.setData(ByteBuffer.wrap("Learning Spark".getBytes())) + val batch = List[Record](record1, record2) + + var receiverMock: KinesisReceiver = _ + var checkpointerMock: IRecordProcessorCheckpointer = _ + var checkpointClockMock: ManualClock = _ + var checkpointStateMock: KinesisCheckpointState = _ + var currentClockMock: Clock = _ + + override def beforeFunction() = { + receiverMock = mock[KinesisReceiver] + checkpointerMock = mock[IRecordProcessorCheckpointer] + checkpointClockMock = mock[ManualClock] + checkpointStateMock = mock[KinesisCheckpointState] + currentClockMock = mock[Clock] + } + + test("kinesis utils api") { + val ssc = new StreamingContext(master, framework, batchDuration) + // Tests the API, does not actually test data receiving + val kinesisStream = KinesisUtils.createStream(ssc, "mySparkStream", + "https://kinesis.us-west-2.amazonaws.com", Seconds(2), + InitialPositionInStream.LATEST, StorageLevel.MEMORY_AND_DISK_2); + ssc.stop() + } + + test("process records including store and checkpoint") { + val expectedCheckpointIntervalMillis = 10 + expecting { + receiverMock.isStopped().andReturn(false).once() + receiverMock.store(record1.getData().array()).once() + receiverMock.store(record2.getData().array()).once() + checkpointStateMock.shouldCheckpoint().andReturn(true).once() + checkpointerMock.checkpoint().once() + checkpointStateMock.advanceCheckpoint().once() + } + whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, + checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) + } + } + + test("shouldn't store and checkpoint when receiver is stopped") { + expecting { + receiverMock.isStopped().andReturn(true).once() + } + whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, + checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) + } + } + + test("shouldn't checkpoint when exception occurs during store") { + expecting { + receiverMock.isStopped().andReturn(false).once() + receiverMock.store(record1.getData().array()).andThrow(new RuntimeException()).once() + } + whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) { + intercept[RuntimeException] { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, + checkpointStateMock) + recordProcessor.processRecords(batch, checkpointerMock) + } + } + } + + test("should set checkpoint time to currentTime + checkpoint interval upon instantiation") { + expecting { + currentClockMock.currentTime().andReturn(0).once() + } + whenExecuting(currentClockMock) { + val checkpointIntervalMillis = 10 + val checkpointState = new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) + assert(checkpointState.checkpointClock.currentTime() == checkpointIntervalMillis) + } + } + + test("should checkpoint if we have exceeded the checkpoint interval") { + expecting { + currentClockMock.currentTime().andReturn(0).once() + } + whenExecuting(currentClockMock) { + val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MinValue), currentClockMock) + assert(checkpointState.shouldCheckpoint()) + } + } + + test("shouldn't checkpoint if we have not exceeded the checkpoint interval") { + expecting { + currentClockMock.currentTime().andReturn(0).once() + } + whenExecuting(currentClockMock) { + val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MaxValue), currentClockMock) + assert(!checkpointState.shouldCheckpoint()) + } + } + + test("should add to time when advancing checkpoint") { + expecting { + currentClockMock.currentTime().andReturn(0).once() + } + whenExecuting(currentClockMock) { + val checkpointIntervalMillis = 10 + val checkpointState = new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock) + assert(checkpointState.checkpointClock.currentTime() == checkpointIntervalMillis) + checkpointState.advanceCheckpoint() + assert(checkpointState.checkpointClock.currentTime() == (2 * checkpointIntervalMillis)) + } + } + + test("shutdown should checkpoint if the reason is TERMINATE") { + expecting { + checkpointerMock.checkpoint().once() + } + whenExecuting(checkpointerMock, checkpointStateMock) { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, + checkpointStateMock) + val reason = ShutdownReason.TERMINATE + recordProcessor.shutdown(checkpointerMock, reason) + } + } + + test("shutdown should not checkpoint if the reason is something other than TERMINATE") { + expecting { + } + whenExecuting(checkpointerMock, checkpointStateMock) { + val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, + checkpointStateMock) + recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE) + recordProcessor.shutdown(checkpointerMock, null) + } + } + + test("retry success on first attempt") { + val expectedIsStopped = false + expecting { + receiverMock.isStopped().andReturn(expectedIsStopped).once() + } + whenExecuting(receiverMock) { + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + } + } + + test("retry success on second attempt after a Kinesis throttling exception") { + val expectedIsStopped = false + expecting { + receiverMock.isStopped().andThrow(new ThrottlingException("error message")) + .andReturn(expectedIsStopped).once() + } + whenExecuting(receiverMock) { + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + } + } + + test("retry success on second attempt after a Kinesis dependency exception") { + val expectedIsStopped = false + expecting { + receiverMock.isStopped().andThrow(new KinesisClientLibDependencyException("error message")) + .andReturn(expectedIsStopped).once() + } + whenExecuting(receiverMock) { + val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100) + assert(actualVal == expectedIsStopped) + } + } + + test("retry failed after a shutdown exception") { + expecting { + checkpointerMock.checkpoint().andThrow(new ShutdownException("error message")).once() + } + whenExecuting(checkpointerMock) { + intercept[ShutdownException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) + } + } + } + + test("retry failed after an invalid state exception") { + expecting { + checkpointerMock.checkpoint().andThrow(new InvalidStateException("error message")).once() + } + whenExecuting(checkpointerMock) { + intercept[InvalidStateException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) + } + } + } + + test("retry failed after unexpected exception") { + expecting { + checkpointerMock.checkpoint().andThrow(new RuntimeException("error message")).once() + } + whenExecuting(checkpointerMock) { + intercept[RuntimeException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) + } + } + } + + test("retry failed after exhausing all retries") { + val expectedErrorMessage = "final try error message" + expecting { + checkpointerMock.checkpoint().andThrow(new ThrottlingException("error message")) + .andThrow(new ThrottlingException(expectedErrorMessage)).once() + } + whenExecuting(checkpointerMock) { + val exception = intercept[RuntimeException] { + KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100) + } + exception.getMessage().shouldBe(expectedErrorMessage) + } + } +} diff --git a/pom.xml b/pom.xml index 99ae4b8b33f94..a42759169149b 100644 --- a/pom.xml +++ b/pom.xml @@ -134,6 +134,8 @@ 3.0.0 1.7.6 0.7.1 + 1.8.3 + 1.1.0 64m 512m @@ -1011,6 +1013,14 @@ + + + kinesis-asl + + extras/kinesis-asl + + + java8-tests diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 1d7cc6dd6aef3..aac621fe53938 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -37,8 +37,8 @@ object BuildCommons { "spark", "sql", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", "streaming-zeromq").map(ProjectRef(buildLocation, _)) - val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl) = - Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl") + val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) = + Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl", "kinesis-asl") .map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(assembly, examples) = Seq("assembly", "examples") @@ -62,7 +62,7 @@ object SparkBuild extends PomBuild { var isAlphaYarn = false var profiles: mutable.Seq[String] = mutable.Seq.empty if (Properties.envOrNone("SPARK_GANGLIA_LGPL").isDefined) { - println("NOTE: SPARK_GANGLIA_LGPL is deprecated, please use -Pganglia-lgpl flag.") + println("NOTE: SPARK_GANGLIA_LGPL is deprecated, please use -Pspark-ganglia-lgpl flag.") profiles ++= Seq("spark-ganglia-lgpl") } if (Properties.envOrNone("SPARK_HIVE").isDefined) { From 4c477117bb1ffef463776c86f925d35036f96b7a Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Sat, 2 Aug 2014 13:55:28 -0700 Subject: [PATCH 05/13] SPARK-2804: Remove scalalogging-slf4j dependency This also Closes #1701. Author: GuoQiang Li Closes #1208 from witgo/SPARK-1470 and squashes the following commits: 422646b [GuoQiang Li] Remove scalalogging-slf4j dependency --- .../main/scala/org/apache/spark/Logging.scala | 10 ++++++--- sql/catalyst/pom.xml | 5 ----- .../sql/catalyst/analysis/Analyzer.scala | 4 ++-- .../catalyst/analysis/HiveTypeCoercion.scala | 8 +++---- .../catalyst/expressions/BoundAttribute.scala | 2 +- .../codegen/GenerateOrdering.scala | 4 ++-- .../apache/spark/sql/catalyst/package.scala | 1 - .../sql/catalyst/planning/QueryPlanner.scala | 2 +- .../sql/catalyst/planning/patterns.scala | 6 ++--- .../spark/sql/catalyst/rules/Rule.scala | 2 +- .../sql/catalyst/rules/RuleExecutor.scala | 12 +++++----- .../spark/sql/catalyst/trees/package.scala | 8 ++++--- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../CompressibleColumnBuilder.scala | 5 +++-- .../apache/spark/sql/execution/Exchange.scala | 2 +- .../org/apache/spark/sql/json/JsonRDD.scala | 2 +- .../scala/org/apache/spark/sql/package.scala | 2 -- .../spark/sql/columnar/ColumnTypeSuite.scala | 4 ++-- .../hive/thriftserver/HiveThriftServer2.scala | 12 +++++----- .../hive/thriftserver/SparkSQLCLIDriver.scala | 2 +- .../hive/thriftserver/SparkSQLDriver.scala | 6 ++--- .../sql/hive/thriftserver/SparkSQLEnv.scala | 6 ++--- .../server/SparkSQLOperationManager.scala | 13 ++++++----- .../thriftserver/HiveThriftServer2Suite.scala | 2 +- .../apache/spark/sql/hive/HiveContext.scala | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 3 ++- .../org/apache/spark/sql/hive/TestHive.scala | 10 ++++----- .../org/apache/spark/sql/hive/hiveUdfs.scala | 4 ++-- .../hive/execution/HiveComparisonTest.scala | 22 +++++++++---------- .../hive/execution/HiveQueryFileTest.scala | 2 +- 30 files changed, 83 insertions(+), 82 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala index 807ef3e9c9d60..d4f2624061e35 100644 --- a/core/src/main/scala/org/apache/spark/Logging.scala +++ b/core/src/main/scala/org/apache/spark/Logging.scala @@ -39,13 +39,17 @@ trait Logging { // be serialized and used on another machine @transient private var log_ : Logger = null + // Method to get the logger name for this object + protected def logName = { + // Ignore trailing $'s in the class names for Scala objects + this.getClass.getName.stripSuffix("$") + } + // Method to get or create the logger for this object protected def log: Logger = { if (log_ == null) { initializeIfNecessary() - var className = this.getClass.getName - // Ignore trailing $'s in the class names for Scala objects - log_ = LoggerFactory.getLogger(className.stripSuffix("$")) + log_ = LoggerFactory.getLogger(logName) } log_ } diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 54fa96baa1e18..58d44e7923bee 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -54,11 +54,6 @@ spark-core_${scala.binary.version} ${project.version} - - com.typesafe - scalalogging-slf4j_${scala.binary.version} - 1.0.1 - org.scalatest scalatest_${scala.binary.version} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 74c0104e5b17f..2ba68cab115fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -109,12 +109,12 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool object ResolveReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case q: LogicalPlan if q.childrenResolved => - logger.trace(s"Attempting to resolve ${q.simpleString}") + logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressions { case u @ UnresolvedAttribute(name) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = q.resolve(name).getOrElse(u) - logger.debug(s"Resolving $u to $result") + logDebug(s"Resolving $u to $result") result } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 47c7ad076ad07..e94f2a3bea63e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -75,7 +75,7 @@ trait HiveTypeCoercion { // Leave the same if the dataTypes match. case Some(newType) if a.dataType == newType.dataType => a case Some(newType) => - logger.debug(s"Promoting $a to $newType in ${q.simpleString}}") + logDebug(s"Promoting $a to $newType in ${q.simpleString}}") newType } } @@ -154,7 +154,7 @@ trait HiveTypeCoercion { (Alias(Cast(l, StringType), l.name)(), r) case (l, r) if l.dataType != r.dataType => - logger.debug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") + logDebug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") findTightestCommonType(l.dataType, r.dataType).map { widestType => val newLeft = if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)() @@ -170,7 +170,7 @@ trait HiveTypeCoercion { val newLeft = if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { - logger.debug(s"Widening numeric types in union $castedLeft ${left.output}") + logDebug(s"Widening numeric types in union $castedLeft ${left.output}") Project(castedLeft, left) } else { left @@ -178,7 +178,7 @@ trait HiveTypeCoercion { val newRight = if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { - logger.debug(s"Widening numeric types in union $castedRight ${right.output}") + logDebug(s"Widening numeric types in union $castedRight ${right.output}") Project(castedRight, right) } else { right diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index f38f99569f207..0913f15888780 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.trees diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 4211998f7511a..094ff14552283 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import com.typesafe.scalalogging.slf4j.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{StringType, NumericType} @@ -92,7 +92,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[Row]] wit } new $orderingName() """ - logger.debug(s"Generated Ordering: $code") + logDebug(s"Generated Ordering: $code") toolBox.eval(code).asInstanceOf[Ordering[Row]] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala index ca9642954eb27..bdd07bbeb2230 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala @@ -25,5 +25,4 @@ package object catalyst { */ protected[catalyst] object ScalaReflectionLock - protected[catalyst] type Logging = com.typesafe.scalalogging.slf4j.Logging } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 781ba489b44c6..5839c9f7c43ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.planning -import org.apache.spark.sql.catalyst.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index bc763a4e06e67..90923fe31a063 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.planning import scala.annotation.tailrec import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -184,7 +184,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { case join @ Join(left, right, joinType, condition) => - logger.debug(s"Considering join on: $condition") + logDebug(s"Considering join on: $condition") // Find equi-join predicates that can be evaluated before the join, and thus can be used // as join keys. val (joinPredicates, otherPredicates) = @@ -202,7 +202,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { val rightKeys = joinKeys.map(_._2) if (joinKeys.nonEmpty) { - logger.debug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}") + logDebug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}") Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right)) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala index f8960b3fe7a17..03414b2301e81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.rules -import org.apache.spark.sql.catalyst.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.trees.TreeNode abstract class Rule[TreeType <: TreeNode[_]] extends Logging { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 6aa407c836aec..d192b151ac1c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.rules -import org.apache.spark.sql.catalyst.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide @@ -60,7 +60,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { case (plan, rule) => val result = rule(plan) if (!result.fastEquals(plan)) { - logger.trace( + logTrace( s""" |=== Applying Rule ${rule.ruleName} === |${sideBySide(plan.treeString, result.treeString).mkString("\n")} @@ -73,26 +73,26 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { if (iteration > batch.strategy.maxIterations) { // Only log if this is a rule that is supposed to run more than once. if (iteration != 2) { - logger.info(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}") + logInfo(s"Max iterations (${iteration - 1}) reached for batch ${batch.name}") } continue = false } if (curPlan.fastEquals(lastPlan)) { - logger.trace(s"Fixed point reached for batch ${batch.name} after $iteration iterations.") + logTrace(s"Fixed point reached for batch ${batch.name} after $iteration iterations.") continue = false } lastPlan = curPlan } if (!batchStartPlan.fastEquals(curPlan)) { - logger.debug( + logDebug( s""" |=== Result of Batch ${batch.name} === |${sideBySide(plan.treeString, curPlan.treeString).mkString("\n")} """.stripMargin) } else { - logger.trace(s"Batch ${batch.name} has no effect.") + logTrace(s"Batch ${batch.name} has no effect.") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala index 9a28d035a10a3..d725a92c06f7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst +import org.apache.spark.Logging + /** * A library for easily manipulating trees of operators. Operators that extend TreeNode are * granted the following interface: @@ -31,8 +33,8 @@ package org.apache.spark.sql.catalyst *
  • debugging support - pretty printing, easy splicing of trees, etc.
  • * */ -package object trees { +package object trees extends Logging { // Since we want tree nodes to be lightweight, we create one logger for all treenode instances. - protected val logger = - com.typesafe.scalalogging.slf4j.Logger(org.slf4j.LoggerFactory.getLogger("catalyst.trees")) + protected override def logName = "catalyst.trees" + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index dad71079c29b9..00dd34aabc389 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.SparkStrategies import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation -import org.apache.spark.SparkContext +import org.apache.spark.{Logging, SparkContext} /** * :: AlphaComponent :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala index 4c6675c3c87bf..6ad12a0dcb64d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressibleColumnBuilder.scala @@ -19,7 +19,8 @@ package org.apache.spark.sql.columnar.compression import java.nio.{ByteBuffer, ByteOrder} -import org.apache.spark.sql.{Logging, Row} +import org.apache.spark.Logging +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar.{ColumnBuilder, NativeColumnBuilder} @@ -101,7 +102,7 @@ private[sql] trait CompressibleColumnBuilder[T <: NativeType] copyColumnHeader(rawBuffer, compressedBuffer) - logger.info(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") + logInfo(s"Compressor for [$columnName]: $encoder, ratio: ${encoder.compressionRatio}") encoder.compress(rawBuffer, compressedBuffer, columnType) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 30712f03cab4c..77dc2ad733215 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -101,7 +101,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl !operator.requiredChildDistribution.zip(operator.children).map { case (required, child) => val valid = child.outputPartitioning.satisfies(required) - logger.debug( + logDebug( s"${if (valid) "Valid" else "Invalid"} distribution," + s"required: $required current: ${child.outputPartitioning}") valid diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 70db1ebd3a3e1..a3d2a1c7a51f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.Logging +import org.apache.spark.Logging private[sql] object JsonRDD extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 0995a4eb6299f..f513eae9c2d13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -32,8 +32,6 @@ import org.apache.spark.annotation.DeveloperApi */ package object sql { - protected[sql] type Logging = com.typesafe.scalalogging.slf4j.Logging - /** * :: DeveloperApi :: * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 829342215e691..75f653f3280bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -22,7 +22,7 @@ import java.sql.Timestamp import org.scalatest.FunSuite -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -166,7 +166,7 @@ class ColumnTypeSuite extends FunSuite with Logging { buffer.rewind() seq.foreach { expected => - logger.info("buffer = " + buffer + ", expected = " + expected) + logInfo("buffer = " + buffer + ", expected = " + expected) val extracted = columnType.extract(buffer) assert( expected === extracted, diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index ddbc2a79fb512..08d3f983d9e71 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hive.service.cli.thrift.ThriftBinaryCLIService import org.apache.hive.service.server.{HiveServer2, ServerOptionsProcessor} -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ @@ -40,7 +40,7 @@ private[hive] object HiveThriftServer2 extends Logging { val optionsProcessor = new ServerOptionsProcessor("HiveThriftServer2") if (!optionsProcessor.process(args)) { - logger.warn("Error starting HiveThriftServer2 with given arguments") + logWarning("Error starting HiveThriftServer2 with given arguments") System.exit(-1) } @@ -49,12 +49,12 @@ private[hive] object HiveThriftServer2 extends Logging { // Set all properties specified via command line. val hiveConf: HiveConf = ss.getConf hiveConf.getAllProperties.toSeq.sortBy(_._1).foreach { case (k, v) => - logger.debug(s"HiveConf var: $k=$v") + logDebug(s"HiveConf var: $k=$v") } SessionState.start(ss) - logger.info("Starting SparkContext") + logInfo("Starting SparkContext") SparkSQLEnv.init() SessionState.start(ss) @@ -70,10 +70,10 @@ private[hive] object HiveThriftServer2 extends Logging { val server = new HiveThriftServer2(SparkSQLEnv.hiveContext) server.init(hiveConf) server.start() - logger.info("HiveThriftServer2 started") + logInfo("HiveThriftServer2 started") } catch { case e: Exception => - logger.error("Error starting HiveThriftServer2", e) + logError("Error starting HiveThriftServer2", e) System.exit(-1) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index cb17d7ce58ea0..4d0c506c5a397 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -37,7 +37,7 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.shims.ShimLoader import org.apache.thrift.transport.TSocket -import org.apache.spark.sql.Logging +import org.apache.spark.Logging private[hive] object SparkSQLCLIDriver { private var prompt = "spark-sql" diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index a56b19a4bcda0..d362d599d08ca 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveContext) @@ -40,7 +40,7 @@ private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveCo private def getResultSetSchema(query: context.QueryExecution): Schema = { val analyzed = query.analyzed - logger.debug(s"Result Schema: ${analyzed.output}") + logDebug(s"Result Schema: ${analyzed.output}") if (analyzed.output.size == 0) { new Schema(new FieldSchema("Response code", "string", "") :: Nil, null) } else { @@ -61,7 +61,7 @@ private[hive] class SparkSQLDriver(val context: HiveContext = SparkSQLEnv.hiveCo new CommandProcessorResponse(0) } catch { case cause: Throwable => - logger.error(s"Failed in [$command]", cause) + logError(s"Failed in [$command]", cause) new CommandProcessorResponse(-3, ExceptionUtils.getFullStackTrace(cause), null) } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 451c3bd7b9352..582264eb59f83 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.hive.thriftserver import org.apache.hadoop.hive.ql.session.SessionState import org.apache.spark.scheduler.{SplitInfo, StatsReportListener} -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{SparkConf, SparkContext} /** A singleton object for the master program. The slaves should not access this. */ private[hive] object SparkSQLEnv extends Logging { - logger.debug("Initializing SparkSQLEnv") + logDebug("Initializing SparkSQLEnv") var hiveContext: HiveContext = _ var sparkContext: SparkContext = _ @@ -47,7 +47,7 @@ private[hive] object SparkSQLEnv extends Logging { /** Cleans up and shuts down the Spark SQL environments. */ def stop() { - logger.debug("Shutting down Spark SQL Environment") + logDebug("Shutting down Spark SQL Environment") // Stop the SparkContext if (SparkSQLEnv.sparkContext != null) { sparkContext.stop() diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index a4e1f3e762e89..d4dadfd21d13f 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -30,10 +30,11 @@ import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager} import org.apache.hive.service.cli.session.HiveSession +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.hive.thriftserver.ReflectionUtils import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} -import org.apache.spark.sql.{Logging, SchemaRDD, Row => SparkRow} +import org.apache.spark.sql.{SchemaRDD, Row => SparkRow} /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. @@ -55,7 +56,7 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage def close(): Unit = { // RDDs will be cleaned automatically upon garbage collection. - logger.debug("CLOSING") + logDebug("CLOSING") } def getNextRowSet(order: FetchOrientation, maxRowsL: Long): RowSet = { @@ -112,7 +113,7 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage } def getResultSetSchema: TableSchema = { - logger.warn(s"Result Schema: ${result.queryExecution.analyzed.output}") + logWarning(s"Result Schema: ${result.queryExecution.analyzed.output}") if (result.queryExecution.analyzed.output.size == 0) { new TableSchema(new FieldSchema("Result", "string", "") :: Nil) } else { @@ -124,11 +125,11 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage } def run(): Unit = { - logger.info(s"Running query '$statement'") + logInfo(s"Running query '$statement'") setState(OperationState.RUNNING) try { result = hiveContext.hql(statement) - logger.debug(result.queryExecution.toString()) + logDebug(result.queryExecution.toString()) val groupId = round(random * 1000000).toString hiveContext.sparkContext.setJobGroup(groupId, statement) iter = result.queryExecution.toRdd.toLocalIterator @@ -138,7 +139,7 @@ class SparkSQLOperationManager(hiveContext: HiveContext) extends OperationManage // Actually do need to catch Throwable as some failures don't inherit from Exception and // HiveServer will silently swallow them. case e: Throwable => - logger.error("Error executing query:",e) + logError("Error executing query:",e) throw new HiveSQLException(e.toString) } setState(OperationState.FINISHED) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala index fe3403b3292ec..b7b7c9957ac34 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala @@ -27,7 +27,7 @@ import java.sql.{Connection, DriverManager, Statement} import org.scalatest.{BeforeAndAfterAll, FunSuite} -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.util.getTempFilePath /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 7e3b8727bebed..2c7270d9f83a9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -207,7 +207,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } } catch { case e: Exception => - logger.error( + logError( s""" |====================== |HIVE FAILURE OUTPUT diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index fa4e78439c26c..df3604439e483 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -28,7 +28,8 @@ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.serde2.Deserializer import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.{SQLContext, Logging} +import org.apache.spark.Logging +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, Catalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index c50e8c4b5c5d3..728452a25a00e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -148,7 +148,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { describedTables ++ logical.collect { case UnresolvedRelation(databaseName, name, _) => name } val referencedTestTables = referencedTables.filter(testTables.contains) - logger.debug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") + logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") referencedTestTables.foreach(loadTestTable) // Proceed with analysis. analyzer(logical) @@ -273,7 +273,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { if (!(loadedTables contains name)) { // Marks the table as loaded first to prevent infite mutually recursive table loading. loadedTables += name - logger.info(s"Loading test table $name") + logInfo(s"Loading test table $name") val createCmds = testTables.get(name).map(_.commands).getOrElse(sys.error(s"Unknown test table $name")) createCmds.foreach(_()) @@ -312,7 +312,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { loadedTables.clear() catalog.client.getAllTables("default").foreach { t => - logger.debug(s"Deleting table $t") + logDebug(s"Deleting table $t") val table = catalog.client.getTable("default", t) catalog.client.getIndexes("default", t, 255).foreach { index => @@ -325,7 +325,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { } catalog.client.getAllDatabases.filterNot(_ == "default").foreach { db => - logger.debug(s"Dropping Database: $db") + logDebug(s"Dropping Database: $db") catalog.client.dropDatabase(db, true, false, true) } @@ -347,7 +347,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { loadTestTable("srcpart") } catch { case e: Exception => - logger.error(s"FATAL ERROR: Failed to reset TestDB state. $e") + logError(s"FATAL ERROR: Failed to reset TestDB state. $e") // At this point there is really no reason to continue, but the test framework traps exits. // So instead we just pause forever so that at least the developer can see where things // started to go wrong. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 7582b4743d404..d181921269b56 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} import org.apache.hadoop.hive.ql.udf.generic._ -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ @@ -119,7 +119,7 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[ sys.error(s"No matching wrapper found, options: ${argClass.getConstructors.toSeq}.")) (a: Any) => { - logger.debug( + logDebug( s"Wrapping $a of type ${if (a == null) "null" else a.getClass.getName} using $constructor.") // We must make sure that primitives get boxed java style. if (a == null) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 6c8fe4b196dea..83cfbc6b4a002 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -21,7 +21,7 @@ import java.io._ import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen} -import org.apache.spark.sql.Logging +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.{NativeCommand => LogicalNativeCommand} @@ -197,7 +197,7 @@ abstract class HiveComparisonTest // If test sharding is enable, skip tests that are not in the correct shard. shardInfo.foreach { case (shardId, numShards) if testCaseName.hashCode % numShards != shardId => return - case (shardId, _) => logger.debug(s"Shard $shardId includes test '$testCaseName'") + case (shardId, _) => logDebug(s"Shard $shardId includes test '$testCaseName'") } // Skip tests found in directories specified by user. @@ -213,13 +213,13 @@ abstract class HiveComparisonTest .map(new File(_, testCaseName)) .filter(_.exists) if (runOnlyDirectories.nonEmpty && runIndicators.isEmpty) { - logger.debug( + logDebug( s"Skipping test '$testCaseName' not found in ${runOnlyDirectories.map(_.getCanonicalPath)}") return } test(testCaseName) { - logger.debug(s"=== HIVE TEST: $testCaseName ===") + logDebug(s"=== HIVE TEST: $testCaseName ===") // Clear old output for this testcase. outputDirectories.map(new File(_, testCaseName)).filter(_.exists()).foreach(_.delete()) @@ -235,7 +235,7 @@ abstract class HiveComparisonTest .filterNot(_ contains "hive.outerjoin.supports.filters") if (allQueries != queryList) - logger.warn(s"Simplifications made on unsupported operations for test $testCaseName") + logWarning(s"Simplifications made on unsupported operations for test $testCaseName") lazy val consoleTestCase = { val quotes = "\"\"\"" @@ -257,11 +257,11 @@ abstract class HiveComparisonTest } val hiveCachedResults = hiveCacheFiles.flatMap { cachedAnswerFile => - logger.debug(s"Looking for cached answer file $cachedAnswerFile.") + logDebug(s"Looking for cached answer file $cachedAnswerFile.") if (cachedAnswerFile.exists) { Some(fileToString(cachedAnswerFile)) } else { - logger.debug(s"File $cachedAnswerFile not found") + logDebug(s"File $cachedAnswerFile not found") None } }.map { @@ -272,7 +272,7 @@ abstract class HiveComparisonTest val hiveResults: Seq[Seq[String]] = if (hiveCachedResults.size == queryList.size) { - logger.info(s"Using answer cache for test: $testCaseName") + logInfo(s"Using answer cache for test: $testCaseName") hiveCachedResults } else { @@ -287,7 +287,7 @@ abstract class HiveComparisonTest if (installHooksCommand.findAllMatchIn(queryString).nonEmpty) sys.error("hive exec hooks not supported for tests.") - logger.warn(s"Running query ${i+1}/${queryList.size} with hive.") + logWarning(s"Running query ${i+1}/${queryList.size} with hive.") // Analyze the query with catalyst to ensure test tables are loaded. val answer = hiveQuery.analyzed match { case _: ExplainCommand => Nil // No need to execute EXPLAIN queries as we don't check the output. @@ -351,7 +351,7 @@ abstract class HiveComparisonTest val resultComparison = sideBySide(hivePrintOut, catalystPrintOut).mkString("\n") if (recomputeCache) { - logger.warn(s"Clearing cache files for failed test $testCaseName") + logWarning(s"Clearing cache files for failed test $testCaseName") hiveCacheFiles.foreach(_.delete()) } @@ -380,7 +380,7 @@ abstract class HiveComparisonTest TestHive.runSqlHive("SELECT key FROM src") } catch { case e: Exception => - logger.error(s"FATAL ERROR: Canary query threw $e This implies that the testing environment has likely been corrupted.") + logError(s"FATAL ERROR: Canary query threw $e This implies that the testing environment has likely been corrupted.") // The testing setup traps exits so wait here for a long time so the developer can see when things started // to go wrong. Thread.sleep(1000000) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala index 50ab71a9003d3..02518d516261b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQueryFileTest.scala @@ -53,7 +53,7 @@ abstract class HiveQueryFileTest extends HiveComparisonTest { testCases.sorted.foreach { case (testCaseName, testCaseFile) => if (blackList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_)) { - logger.debug(s"Blacklisted test skipped $testCaseName") + logDebug(s"Blacklisted test skipped $testCaseName") } else if (realWhiteList.map(_.r.pattern.matcher(testCaseName).matches()).reduceLeft(_||_) || runAll) { // Build a test case and submit it to scala test framework... val queriesString = fileToString(testCaseFile) From 158ad0bba9382fd494b4789b5628a9cec00cfa19 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sat, 2 Aug 2014 16:33:48 -0700 Subject: [PATCH 06/13] [SPARK-2097][SQL] UDF Support This patch adds the ability to register lambda functions written in Python, Java or Scala as UDFs for use in SQL or HiveQL. Scala: ```scala registerFunction("strLenScala", (_: String).length) sql("SELECT strLenScala('test')") ``` Python: ```python sqlCtx.registerFunction("strLenPython", lambda x: len(x), IntegerType()) sqlCtx.sql("SELECT strLenPython('test')") ``` Java: ```java sqlContext.registerFunction("stringLengthJava", new UDF1() { Override public Integer call(String str) throws Exception { return str.length(); } }, DataType.IntegerType); sqlContext.sql("SELECT stringLengthJava('test')"); ``` Author: Michael Armbrust Closes #1063 from marmbrus/udfs and squashes the following commits: 9eda0fe [Michael Armbrust] newline 747c05e [Michael Armbrust] Add some scala UDF tests. d92727d [Michael Armbrust] Merge remote-tracking branch 'apache/master' into udfs 005d684 [Michael Armbrust] Fix naming and formatting. d14dac8 [Michael Armbrust] Fix last line of autogened java files. 8135c48 [Michael Armbrust] Move UDF unit tests to pyspark. 40b0ffd [Michael Armbrust] Merge remote-tracking branch 'apache/master' into udfs 6a36890 [Michael Armbrust] Switch logging so that SQLContext can be serializable. 7a83101 [Michael Armbrust] Drop toString 795fd15 [Michael Armbrust] Try to avoid capturing SQLContext. e54fb45 [Michael Armbrust] Docs and tests. 437cbe3 [Michael Armbrust] Update use of dataTypes, fix some python tests, address review comments. 01517d6 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into udfs 8e6c932 [Michael Armbrust] WIP 3f96a52 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into udfs 6237c8d [Michael Armbrust] WIP 2766f0b [Michael Armbrust] Move udfs support to SQL from hive. Add support for Java UDFs. 0f7d50c [Michael Armbrust] Draft of native Spark SQL UDFs for Scala and Python. --- python/pyspark/sql.py | 39 ++- .../catalyst/analysis/FunctionRegistry.scala | 32 ++ .../sql/catalyst/expressions/ScalaUdf.scala | 307 ++++++++++++++++++ .../org/apache/spark/sql/api/java/UDF1.java | 32 ++ .../org/apache/spark/sql/api/java/UDF10.java | 32 ++ .../org/apache/spark/sql/api/java/UDF11.java | 32 ++ .../org/apache/spark/sql/api/java/UDF12.java | 32 ++ .../org/apache/spark/sql/api/java/UDF13.java | 32 ++ .../org/apache/spark/sql/api/java/UDF14.java | 32 ++ .../org/apache/spark/sql/api/java/UDF15.java | 32 ++ .../org/apache/spark/sql/api/java/UDF16.java | 32 ++ .../org/apache/spark/sql/api/java/UDF17.java | 32 ++ .../org/apache/spark/sql/api/java/UDF18.java | 32 ++ .../org/apache/spark/sql/api/java/UDF19.java | 32 ++ .../org/apache/spark/sql/api/java/UDF2.java | 32 ++ .../org/apache/spark/sql/api/java/UDF20.java | 32 ++ .../org/apache/spark/sql/api/java/UDF21.java | 32 ++ .../org/apache/spark/sql/api/java/UDF22.java | 32 ++ .../org/apache/spark/sql/api/java/UDF3.java | 32 ++ .../org/apache/spark/sql/api/java/UDF4.java | 32 ++ .../org/apache/spark/sql/api/java/UDF5.java | 32 ++ .../org/apache/spark/sql/api/java/UDF6.java | 32 ++ .../org/apache/spark/sql/api/java/UDF7.java | 32 ++ .../org/apache/spark/sql/api/java/UDF8.java | 32 ++ .../org/apache/spark/sql/api/java/UDF9.java | 32 ++ .../org/apache/spark/sql/SQLContext.scala | 11 +- .../apache/spark/sql/UdfRegistration.scala | 196 +++++++++++ .../spark/sql/api/java/JavaSQLContext.scala | 5 +- .../spark/sql/api/java/UDFRegistration.scala | 252 ++++++++++++++ .../spark/sql/execution/SparkStrategies.scala | 2 + .../spark/sql/execution/pythonUdfs.scala | 177 ++++++++++ .../spark/sql/api/java/JavaAPISuite.java | 90 +++++ .../apache/spark/sql/InsertIntoSuite.scala | 2 +- .../scala/org/apache/spark/sql/UDFSuite.scala | 36 ++ .../apache/spark/sql/hive/HiveContext.scala | 13 +- .../org/apache/spark/sql/hive/TestHive.scala | 4 +- .../org/apache/spark/sql/hive/hiveUdfs.scala | 6 +- .../org/apache/spark/sql/QueryTest.scala | 4 +- 38 files changed, 1861 insertions(+), 19 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/api/java/UDFRegistration.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala create mode 100644 sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index f840475ffaf70..e7c35ac1ffe02 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -28,9 +28,13 @@ from operator import itemgetter from pyspark.rdd import RDD, PipelinedRDD -from pyspark.serializers import BatchedSerializer, PickleSerializer +from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer + +from itertools import chain, ifilter, imap from py4j.protocol import Py4JError +from py4j.java_collections import ListConverter, MapConverter + __all__ = [ "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType", @@ -932,6 +936,39 @@ def _ssql_ctx(self): self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc()) return self._scala_SQLContext + def registerFunction(self, name, f, returnType=StringType()): + """Registers a lambda function as a UDF so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not given it default to a string and conversion will automatically + be done. For any other return type, the produced object must match the specified type. + + >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) + >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() + [Row(c0=u'4')] + >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) + >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() + [Row(c0=4)] + >>> sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) + >>> sqlCtx.sql("SELECT twoArgs('test', 1)").collect() + [Row(c0=5)] + """ + func = lambda _, it: imap(lambda x: f(*x), it) + command = (func, + BatchedSerializer(PickleSerializer(), 1024), + BatchedSerializer(PickleSerializer(), 1024)) + env = MapConverter().convert(self._sc.environment, + self._sc._gateway._gateway_client) + includes = ListConverter().convert(self._sc._python_includes, + self._sc._gateway._gateway_client) + self._ssql_ctx.registerPython(name, + bytearray(CloudPickleSerializer().dumps(command)), + env, + includes, + self._sc.pythonExec, + self._sc._javaAccumulator, + str(returnType)) + def inferSchema(self, rdd): """Infer and apply a schema to an RDD of L{Row}s. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index c0255701b7ba5..760c49fbca4a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -18,17 +18,49 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.Expression +import scala.collection.mutable /** A catalog for looking up user defined functions, used by an [[Analyzer]]. */ trait FunctionRegistry { + type FunctionBuilder = Seq[Expression] => Expression + + def registerFunction(name: String, builder: FunctionBuilder): Unit + def lookupFunction(name: String, children: Seq[Expression]): Expression } +trait OverrideFunctionRegistry extends FunctionRegistry { + + val functionBuilders = new mutable.HashMap[String, FunctionBuilder]() + + def registerFunction(name: String, builder: FunctionBuilder) = { + functionBuilders.put(name, builder) + } + + abstract override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + functionBuilders.get(name).map(_(children)).getOrElse(super.lookupFunction(name,children)) + } +} + +class SimpleFunctionRegistry extends FunctionRegistry { + val functionBuilders = new mutable.HashMap[String, FunctionBuilder]() + + def registerFunction(name: String, builder: FunctionBuilder) = { + functionBuilders.put(name, builder) + } + + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + functionBuilders(name)(children) + } +} + /** * A trivial catalog that returns an error when a function is requested. Used for testing when all * functions are already filled in and the analyser needs only to resolve attribute references. */ object EmptyFunctionRegistry extends FunctionRegistry { + def registerFunction(name: String, builder: FunctionBuilder) = ??? + def lookupFunction(name: String, children: Seq[Expression]): Expression = { throw new UnsupportedOperationException } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index acddf5e9c7004..95633dd0c9870 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -27,6 +27,22 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi def references = children.flatMap(_.references).toSet def nullable = true + /** This method has been generated by this script + + (1 to 22).map { x => + val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _) + val evals = (0 to x - 1).map(x => s"children($x).eval(input)").reduce(_ + ",\n " + _) + + s""" + case $x => + function.asInstanceOf[($anys) => Any]( + $evals) + """ + } + + */ + + // scalastyle:off override def eval(input: Row): Any = { children.size match { case 0 => function.asInstanceOf[() => Any]() @@ -35,6 +51,297 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi function.asInstanceOf[(Any, Any) => Any]( children(0).eval(input), children(1).eval(input)) + case 3 => + function.asInstanceOf[(Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input)) + case 4 => + function.asInstanceOf[(Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input)) + case 5 => + function.asInstanceOf[(Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input)) + case 6 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input)) + case 7 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input)) + case 8 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input)) + case 9 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input)) + case 10 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input)) + case 11 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input)) + case 12 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input)) + case 13 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input)) + case 14 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input)) + case 15 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input)) + case 16 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input)) + case 17 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input)) + case 18 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input), + children(17).eval(input)) + case 19 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input), + children(17).eval(input), + children(18).eval(input)) + case 20 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input), + children(17).eval(input), + children(18).eval(input), + children(19).eval(input)) + case 21 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input), + children(17).eval(input), + children(18).eval(input), + children(19).eval(input), + children(20).eval(input)) + case 22 => + function.asInstanceOf[(Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any) => Any]( + children(0).eval(input), + children(1).eval(input), + children(2).eval(input), + children(3).eval(input), + children(4).eval(input), + children(5).eval(input), + children(6).eval(input), + children(7).eval(input), + children(8).eval(input), + children(9).eval(input), + children(10).eval(input), + children(11).eval(input), + children(12).eval(input), + children(13).eval(input), + children(14).eval(input), + children(15).eval(input), + children(16).eval(input), + children(17).eval(input), + children(18).eval(input), + children(19).eval(input), + children(20).eval(input), + children(21).eval(input)) } + // scalastyle:on } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java new file mode 100644 index 0000000000000..ef959e35e1027 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF1.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 1 arguments. + */ +public interface UDF1 extends Serializable { + public R call(T1 t1) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java new file mode 100644 index 0000000000000..96ab3a96c3d5e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF10.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 10 arguments. + */ +public interface UDF10 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java new file mode 100644 index 0000000000000..58ae8edd6d817 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF11.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 11 arguments. + */ +public interface UDF11 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java new file mode 100644 index 0000000000000..d9da0f6eddd94 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF12.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 12 arguments. + */ +public interface UDF12 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java new file mode 100644 index 0000000000000..095fc1a8076b5 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF13.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 13 arguments. + */ +public interface UDF13 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java new file mode 100644 index 0000000000000..eb27eaa180086 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF14.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 14 arguments. + */ +public interface UDF14 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java new file mode 100644 index 0000000000000..1fbcff56332b6 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF15.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 15 arguments. + */ +public interface UDF15 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java new file mode 100644 index 0000000000000..1133561787a69 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF16.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 16 arguments. + */ +public interface UDF16 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java new file mode 100644 index 0000000000000..dfae7922c9b63 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF17.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 17 arguments. + */ +public interface UDF17 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java new file mode 100644 index 0000000000000..e9d1c6d52d4ea --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF18.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 18 arguments. + */ +public interface UDF18 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java new file mode 100644 index 0000000000000..46b9d2d3c9457 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF19.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 19 arguments. + */ +public interface UDF19 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java new file mode 100644 index 0000000000000..cd3fde8da419e --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF2.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 2 arguments. + */ +public interface UDF2 extends Serializable { + public R call(T1 t1, T2 t2) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java new file mode 100644 index 0000000000000..113d3d26be4a7 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF20.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 20 arguments. + */ +public interface UDF20 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java new file mode 100644 index 0000000000000..74118f2cf8da7 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF21.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 21 arguments. + */ +public interface UDF21 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java new file mode 100644 index 0000000000000..0e7cc40be45ec --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF22.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 22 arguments. + */ +public interface UDF22 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9, T10 t10, T11 t11, T12 t12, T13 t13, T14 t14, T15 t15, T16 t16, T17 t17, T18 t18, T19 t19, T20 t20, T21 t21, T22 t22) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java new file mode 100644 index 0000000000000..6a880f16be47a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF3.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 3 arguments. + */ +public interface UDF3 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java new file mode 100644 index 0000000000000..fcad2febb18e6 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF4.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 4 arguments. + */ +public interface UDF4 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java new file mode 100644 index 0000000000000..ce0cef43a2144 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF5.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 5 arguments. + */ +public interface UDF5 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java new file mode 100644 index 0000000000000..f56b806684e61 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF6.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 6 arguments. + */ +public interface UDF6 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java new file mode 100644 index 0000000000000..25bd6d3241bd4 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF7.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 7 arguments. + */ +public interface UDF7 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java new file mode 100644 index 0000000000000..a3b7ac5f94ce7 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF8.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 8 arguments. + */ +public interface UDF8 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java new file mode 100644 index 0000000000000..205e72a1522fc --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/UDF9.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +// ************************************************** +// THIS FILE IS AUTOGENERATED BY CODE IN +// org.apache.spark.sql.api.java.FunctionRegistration +// ************************************************** + +/** + * A Spark SQL UDF that has 9 arguments. + */ +public interface UDF9 extends Serializable { + public R call(T1 t1, T2 t2, T3 t3, T4 t4, T5 t5, T6 t6, T7 t7, T8 t8, T9 t9) throws Exception; +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 00dd34aabc389..33931e5d996f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -48,18 +48,23 @@ import org.apache.spark.{Logging, SparkContext} */ @AlphaComponent class SQLContext(@transient val sparkContext: SparkContext) - extends Logging + extends org.apache.spark.Logging with SQLConf with ExpressionConversions + with UDFRegistration with Serializable { self => @transient protected[sql] lazy val catalog: Catalog = new SimpleCatalog(true) + + @transient + protected[sql] lazy val functionRegistry: FunctionRegistry = new SimpleFunctionRegistry + @transient protected[sql] lazy val analyzer: Analyzer = - new Analyzer(catalog, EmptyFunctionRegistry, caseSensitive = true) + new Analyzer(catalog, functionRegistry, caseSensitive = true) @transient protected[sql] val optimizer = Optimizer @transient @@ -379,7 +384,7 @@ class SQLContext(@transient val sparkContext: SparkContext) protected abstract class QueryExecution { def logical: LogicalPlan - lazy val analyzed = analyzer(logical) + lazy val analyzed = ExtractPythonUdfs(analyzer(logical)) lazy val optimizedPlan = optimizer(analyzed) // TODO: Don't just pick the first one... lazy val sparkPlan = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala new file mode 100644 index 0000000000000..0b48e9e659faa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.{List => JList, Map => JMap} + +import org.apache.spark.Accumulator +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf} +import org.apache.spark.sql.execution.PythonUDF + +import scala.reflect.runtime.universe.{TypeTag, typeTag} + +/** + * Functions for registering scala lambda functions as UDFs in a SQLContext. + */ +protected[sql] trait UDFRegistration { + self: SQLContext => + + private[spark] def registerPython( + name: String, + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + accumulator: Accumulator[JList[Array[Byte]]], + stringDataType: String): Unit = { + log.debug( + s""" + | Registering new PythonUDF: + | name: $name + | command: ${command.toSeq} + | envVars: $envVars + | pythonIncludes: $pythonIncludes + | pythonExec: $pythonExec + | dataType: $stringDataType + """.stripMargin) + + + val dataType = parseDataType(stringDataType) + + def builder(e: Seq[Expression]) = + PythonUDF( + name, + command, + envVars, + pythonIncludes, + pythonExec, + accumulator, + dataType, + e) + + functionRegistry.registerFunction(name, builder) + } + + /** registerFunction 1-22 were generated by this script + + (1 to 22).map { x => + val types = (1 to x).map(x => "_").reduce(_ + ", " + _) + s""" + def registerFunction[T: TypeTag](name: String, func: Function$x[$types, T]): Unit = { + def builder(e: Seq[Expression]) = + ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + """ + } + */ + + // scalastyle:off + def registerFunction[T: TypeTag](name: String, func: Function1[_, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function2[_, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function3[_, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function4[_, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function5[_, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function6[_, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function7[_, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function8[_, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function9[_, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function10[_, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function11[_, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function12[_, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + + def registerFunction[T: TypeTag](name: String, func: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor(typeTag[T]).dataType, e) + functionRegistry.registerFunction(name, builder) + } + // scalastyle:on +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 809dd038f94aa..ae45193ed15d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -28,14 +28,13 @@ import org.apache.spark.sql.{SQLContext, StructType => SStructType} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} -import org.apache.spark.sql.types.util.DataTypeConversions -import DataTypeConversions.asScalaDataType; +import org.apache.spark.sql.types.util.DataTypeConversions.asScalaDataType import org.apache.spark.util.Utils /** * The entry point for executing Spark SQL queries from a Java program. */ -class JavaSQLContext(val sqlContext: SQLContext) { +class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { def this(sparkContext: JavaSparkContext) = this(new SQLContext(sparkContext.sc)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDFRegistration.scala new file mode 100644 index 0000000000000..158f26e3d445f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/UDFRegistration.scala @@ -0,0 +1,252 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.api.java + +import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf} +import org.apache.spark.sql.types.util.DataTypeConversions._ + +/** + * A collection of functions that allow Java users to register UDFs. In order to handle functions + * of varying airities with minimal boilerplate for our users, we generate classes and functions + * for each airity up to 22. The code for this generation can be found in comments in this trait. + */ +private[java] trait UDFRegistration { + self: JavaSQLContext => + + /* The following functions and required interfaces are generated with these code fragments: + + (1 to 22).foreach { i => + val extTypeArgs = (1 to i).map(_ => "_").mkString(", ") + val anyTypeArgs = (1 to i).map(_ => "Any").mkString(", ") + val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs, Any]]" + val anyParams = (1 to i).map(_ => "_: Any").mkString(", ") + println(s""" + |def registerFunction( + | name: String, f: UDF$i[$extTypeArgs, _], @transient dataType: DataType) = { + | val scalaType = asScalaDataType(dataType) + | sqlContext.functionRegistry.registerFunction( + | name, + | (e: Seq[Expression]) => ScalaUdf(f$anyCast.call($anyParams), scalaType, e)) + |} + """.stripMargin) + } + + import java.io.File + import org.apache.spark.sql.catalyst.util.stringToFile + val directory = new File("sql/core/src/main/java/org/apache/spark/sql/api/java/") + (1 to 22).foreach { i => + val typeArgs = (1 to i).map(i => s"T$i").mkString(", ") + val args = (1 to i).map(i => s"T$i t$i").mkString(", ") + + val contents = + s"""/* + | * Licensed to the Apache Software Foundation (ASF) under one or more + | * contributor license agreements. See the NOTICE file distributed with + | * this work for additional information regarding copyright ownership. + | * The ASF licenses this file to You under the Apache License, Version 2.0 + | * (the "License"); you may not use this file except in compliance with + | * the License. You may obtain a copy of the License at + | * + | * http://www.apache.org/licenses/LICENSE-2.0 + | * + | * Unless required by applicable law or agreed to in writing, software + | * distributed under the License is distributed on an "AS IS" BASIS, + | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + | * See the License for the specific language governing permissions and + | * limitations under the License. + | */ + | + |package org.apache.spark.sql.api.java; + | + |import java.io.Serializable; + | + |// ************************************************** + |// THIS FILE IS AUTOGENERATED BY CODE IN + |// org.apache.spark.sql.api.java.FunctionRegistration + |// ************************************************** + | + |/** + | * A Spark SQL UDF that has $i arguments. + | */ + |public interface UDF$i<$typeArgs, R> extends Serializable { + | public R call($args) throws Exception; + |} + |""".stripMargin + + stringToFile(new File(directory, s"UDF$i.java"), contents) + } + + */ + + // scalastyle:off + def registerFunction(name: String, f: UDF1[_, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF1[Any, Any]].call(_: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF2[_, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF3[_, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF4[_, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF5[_, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF6[_, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF7[_, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + def registerFunction(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], dataType: DataType) = { + val scalaType = asScalaDataType(dataType) + sqlContext.functionRegistry.registerFunction( + name, + (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), scalaType, e)) + } + + // scalastyle:on +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 8bec015c7b465..f0c958fdb537f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -286,6 +286,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.ExistingRdd(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil + case e @ EvaluatePython(udf, child) => + BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case SparkLogicalPlan(existingPlan) => existingPlan :: Nil case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala new file mode 100644 index 0000000000000..b92091b560b1c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -0,0 +1,177 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution + +import java.util.{List => JList, Map => JMap} + +import net.razorvine.pickle.{Pickler, Unpickler} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.python.PythonRDD +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.{Accumulator, Logging => SparkLogging} + +import scala.collection.JavaConversions._ + +/** + * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. + */ +private[spark] case class PythonUDF( + name: String, + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + accumulator: Accumulator[JList[Array[Byte]]], + dataType: DataType, + children: Seq[Expression]) extends Expression with SparkLogging { + + override def toString = s"PythonUDF#$name(${children.mkString(",")})" + + def nullable: Boolean = true + def references: Set[Attribute] = children.flatMap(_.references).toSet + + override def eval(input: Row) = sys.error("PythonUDFs can not be directly evaluated.") +} + +/** + * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated + * alone in a batch. + * + * This has the limitation that the input to the Python UDF is not allowed include attributes from + * multiple child operators. + */ +private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan) = plan transform { + // Skip EvaluatePython nodes. + case p: EvaluatePython => p + + case l: LogicalPlan => + // Extract any PythonUDFs from the current operator. + val udfs = l.expressions.flatMap(_.collect { case udf: PythonUDF => udf}) + if (udfs.isEmpty) { + // If there aren't any, we are done. + l + } else { + // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time) + // If there is more than one, we will add another evaluation operator in a subsequent pass. + val udf = udfs.head + + var evaluation: EvaluatePython = null + + // Rewrite the child that has the input required for the UDF + val newChildren = l.children.map { child => + // Check to make sure that the UDF can be evaluated with only the input of this child. + // Other cases are disallowed as they are ambiguous or would require a cartisian product. + if (udf.references.subsetOf(child.outputSet)) { + evaluation = EvaluatePython(udf, child) + evaluation + } else if (udf.references.intersect(child.outputSet).nonEmpty) { + sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") + } else { + child + } + } + + assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.") + + // Trim away the new UDF value if it was only used for filtering or something. + logical.Project( + l.output, + l.transformExpressions { + case p: PythonUDF if p.id == udf.id => evaluation.resultAttribute + }.withNewChildren(newChildren)) + } + } +} + +/** + * :: DeveloperApi :: + * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple. + */ +@DeveloperApi +case class EvaluatePython(udf: PythonUDF, child: LogicalPlan) extends logical.UnaryNode { + val resultAttribute = AttributeReference("pythonUDF", udf.dataType, nullable=true)() + + def references = Set.empty + def output = child.output :+ resultAttribute +} + +/** + * :: DeveloperApi :: + * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time. The input + * data is cached and zipped with the result of the udf evaluation. + */ +@DeveloperApi +case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) + extends SparkPlan { + def children = child :: Nil + + def execute() = { + // TODO: Clean up after ourselves? + val childResults = child.execute().map(_.copy()).cache() + + val parent = childResults.mapPartitions { iter => + val pickle = new Pickler + val currentRow = newMutableProjection(udf.children, child.output)() + iter.grouped(1000).map { inputRows => + val toBePickled = inputRows.map(currentRow(_).toArray).toArray + pickle.dumps(toBePickled) + } + } + + val pyRDD = new PythonRDD( + parent, + udf.command, + udf.envVars, + udf.pythonIncludes, + false, + udf.pythonExec, + Seq[Broadcast[Array[Byte]]](), + udf.accumulator + ).mapPartitions { iter => + val pickle = new Unpickler + iter.flatMap { pickedResult => + val unpickledBatch = pickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]] + } + }.mapPartitions { iter => + val row = new GenericMutableRow(1) + iter.map { result => + row(0) = udf.dataType match { + case StringType => result.toString + case other => result + } + row: Row + } + } + + childResults.zip(pyRDD).mapPartitions { iter => + val joinedRow = new JoinedRow() + iter.map { + case (row, udfResult) => + joinedRow(row, udfResult) + } + } + } +} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java new file mode 100644 index 0000000000000..a9a11285def54 --- /dev/null +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.Serializable; + +import org.apache.spark.sql.api.java.UDF1; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runners.Suite; +import org.junit.runner.RunWith; + +import org.apache.spark.api.java.JavaSparkContext; + +// The test suite itself is Serializable so that anonymous Function implementations can be +// serialized, as an alternative to converting these anonymous classes to static inner classes; +// see http://stackoverflow.com/questions/758570/. +public class JavaAPISuite implements Serializable { + private transient JavaSparkContext sc; + private transient JavaSQLContext sqlContext; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaAPISuite"); + sqlContext = new JavaSQLContext(sc); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @SuppressWarnings("unchecked") + @Test + public void udf1Test() { + // With Java 8 lambdas: + // sqlContext.registerFunction( + // "stringLengthTest", (String str) -> str.length(), DataType.IntegerType); + + sqlContext.registerFunction("stringLengthTest", new UDF1() { + @Override + public Integer call(String str) throws Exception { + return str.length(); + } + }, DataType.IntegerType); + + // TODO: Why do we need this cast? + Row result = (Row) sqlContext.sql("SELECT stringLengthTest('test')").first(); + assert(result.getInt(0) == 4); + } + + @SuppressWarnings("unchecked") + @Test + public void udf2Test() { + // With Java 8 lambdas: + // sqlContext.registerFunction( + // "stringLengthTest", + // (String str1, String str2) -> str1.length() + str2.length, + // DataType.IntegerType); + + sqlContext.registerFunction("stringLengthTest", new UDF2() { + @Override + public Integer call(String str1, String str2) throws Exception { + return str1.length() + str2.length(); + } + }, DataType.IntegerType); + + // TODO: Why do we need this cast? + Row result = (Row) sqlContext.sql("SELECT stringLengthTest('test', 'test2')").first(); + assert(result.getInt(0) == 9); + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala index 4f0b85f26254b..23a711d08c58b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.io.File +import _root_.java.io.File /* Implicits */ import org.apache.spark.sql.test.TestSQLContext._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala new file mode 100644 index 0000000000000..76aa9b0081d7e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test._ + +/* Implicits */ +import TestSQLContext._ + +class UDFSuite extends QueryTest { + + test("Simple UDF") { + registerFunction("strLenScala", (_: String).length) + assert(sql("SELECT strLenScala('test')").first().getInt(0) === 4) + } + + test("TwoArgument UDF") { + registerFunction("strLenScala", (_: String).length + (_:Int)) + assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 2c7270d9f83a9..3c70b3f0921a5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -23,7 +23,7 @@ import java.util.{ArrayList => JArrayList} import scala.collection.JavaConversions._ import scala.language.implicitConversions -import scala.reflect.runtime.universe.TypeTag +import scala.reflect.runtime.universe.{TypeTag, typeTag} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver @@ -35,8 +35,9 @@ import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis.{Analyzer, OverrideCatalog} +import org.apache.spark.sql.catalyst.analysis.{OverrideFunctionRegistry, Analyzer, OverrideCatalog} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.ExtractPythonUdfs import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.{Command => PhysicalCommand} import org.apache.spark.sql.hive.execution.DescribeHiveTableCommand @@ -155,10 +156,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } } + // Note that HiveUDFs will be overridden by functions registered in this context. + override protected[sql] lazy val functionRegistry = + new HiveFunctionRegistry with OverrideFunctionRegistry + /* An analyzer that uses the Hive metastore. */ @transient override protected[sql] lazy val analyzer = - new Analyzer(catalog, HiveFunctionRegistry, caseSensitive = false) + new Analyzer(catalog, functionRegistry, caseSensitive = false) /** * Runs the specified SQL query using Hive. @@ -250,7 +255,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { protected[sql] abstract class QueryExecution extends super.QueryExecution { // TODO: Create mixin for the analyzer instead of overriding things here. override lazy val optimizedPlan = - optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed))) + optimizer(ExtractPythonUdfs(catalog.PreInsertionCasts(catalog.CreateTables(analyzed)))) override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy()) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 728452a25a00e..c605e8adcfb0f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -297,8 +297,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { def reset() { try { // HACK: Hive is too noisy by default. - org.apache.log4j.LogManager.getCurrentLoggers.foreach { logger => - logger.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) + org.apache.log4j.LogManager.getCurrentLoggers.foreach { log => + log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) } // It is important that we RESET first as broken hooks that might have been set could break diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index d181921269b56..179aac5cbd5cd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -34,7 +34,8 @@ import org.apache.spark.util.Utils.getContextOrSparkClassLoader /* Implicit conversions */ import scala.collection.JavaConversions._ -private[hive] object HiveFunctionRegistry extends analysis.FunctionRegistry with HiveInspectors { +private[hive] abstract class HiveFunctionRegistry + extends analysis.FunctionRegistry with HiveInspectors { def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name) @@ -92,9 +93,8 @@ private[hive] abstract class HiveUdf extends Expression with Logging with HiveFu } private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[Expression]) - extends HiveUdf { + extends HiveUdf with HiveInspectors { - import org.apache.spark.sql.hive.HiveFunctionRegistry._ type UDFType = UDF @transient diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala index 11d8b1f0a3d96..95921c3d7ae09 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -51,9 +51,9 @@ class QueryTest extends FunSuite { fail( s""" |Exception thrown while executing query: - |${rdd.logicalPlan} + |${rdd.queryExecution} |== Exception == - |$e + |${stackTraceToString(e)} """.stripMargin) } From 198df11f1a9f419f820f47eba0e9f2ab371a824b Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sat, 2 Aug 2014 16:48:07 -0700 Subject: [PATCH 07/13] [SPARK-2785][SQL] Remove assertions that throw when users try unsupported Hive commands. Author: Michael Armbrust Closes #1742 from marmbrus/asserts and squashes the following commits: 5182d54 [Michael Armbrust] Remove assertions that throw when users try unsupported Hive commands. --- .../scala/org/apache/spark/sql/hive/HiveQl.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 3d2eb1eefaeda..bc2fefafd58c8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -297,8 +297,11 @@ private[hive] object HiveQl { matches.headOption } - assert(remainingNodes.isEmpty, - s"Unhandled clauses: ${remainingNodes.map(dumpTree(_)).mkString("\n")}") + if (remainingNodes.nonEmpty) { + sys.error( + s"""Unhandled clauses: ${remainingNodes.map(dumpTree(_)).mkString("\n")}. + |You are likely trying to use an unsupported Hive feature."""".stripMargin) + } clauses } @@ -748,7 +751,10 @@ private[hive] object HiveQl { case Token(allJoinTokens(joinToken), relation1 :: relation2 :: other) => - assert(other.size <= 1, s"Unhandled join child $other") + if (!(other.size <= 1)) { + sys.error(s"Unsupported join operation: $other") + } + val joinType = joinToken match { case "TOK_JOIN" => Inner case "TOK_RIGHTOUTERJOIN" => RightOuter @@ -756,7 +762,6 @@ private[hive] object HiveQl { case "TOK_FULLOUTERJOIN" => FullOuter case "TOK_LEFTSEMIJOIN" => LeftSemi } - assert(other.size <= 1, "Unhandled join clauses.") Join(nodeToRelation(relation1), nodeToRelation(relation2), joinType, From 866cf1f822cfda22294054be026ef2d96307eb75 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sat, 2 Aug 2014 17:12:49 -0700 Subject: [PATCH 08/13] [SPARK-2729][SQL] Added test case for SPARK-2729 This is a follow up of #1636. Author: Cheng Lian Closes #1738 from liancheng/test-for-spark-2729 and squashes the following commits: b13692a [Cheng Lian] Added test case for SPARK-2729 --- .../test/scala/org/apache/spark/sql/TestData.scala | 12 ++++++++++-- .../sql/columnar/InMemoryColumnarQuerySuite.scala | 12 ++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 58cee21e8ad4c..088e6e3c843aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql +import java.sql.Timestamp + import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.test._ /* Implicits */ -import TestSQLContext._ +import org.apache.spark.sql.test.TestSQLContext._ case class TestData(key: Int, value: String) @@ -40,7 +42,7 @@ object TestData { LargeAndSmallInts(2147483646, 1) :: LargeAndSmallInts(3, 2) :: Nil) largeAndSmallInts.registerAsTable("largeAndSmallInts") - + case class TestData2(a: Int, b: Int) val testData2: SchemaRDD = TestSQLContext.sparkContext.parallelize( @@ -143,4 +145,10 @@ object TestData { "2, B2, false, null" :: "3, C3, true, null" :: "4, D4, true, 2147483644" :: Nil) + + case class TimestampField(time: Timestamp) + val timestamps = TestSQLContext.sparkContext.parallelize((1 to 3).map { i => + TimestampField(new Timestamp(i)) + }) + timestamps.registerAsTable("timestamps") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index 86727b93f3659..b561b44ad7ee2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -73,4 +73,16 @@ class InMemoryColumnarQuerySuite extends QueryTest { sql("SELECT * FROM nullableRepeatedData"), nullableRepeatedData.collect().toSeq) } + + test("SPARK-2729 regression: timestamp data type") { + checkAnswer( + sql("SELECT time FROM timestamps"), + timestamps.collect().toSeq) + + TestSQLContext.cacheTable("timestamps") + + checkAnswer( + sql("SELECT time FROM timestamps"), + timestamps.collect().toSeq) + } } From d210022e96804e59e42ab902e53637e50884a9ab Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sat, 2 Aug 2014 17:55:22 -0700 Subject: [PATCH 09/13] [SPARK-2797] [SQL] SchemaRDDs don't support unpersist() The cause is explained in https://issues.apache.org/jira/browse/SPARK-2797. Author: Yin Huai Closes #1745 from yhuai/SPARK-2797 and squashes the following commits: 7b1627d [Yin Huai] The unpersist method of the Scala RDD cannot be called without the input parameter (blocking) from PySpark. --- python/pyspark/sql.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index e7c35ac1ffe02..36e50e49c9a9c 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -1589,9 +1589,9 @@ def persist(self, storageLevel): self._jschema_rdd.persist(javaStorageLevel) return self - def unpersist(self): + def unpersist(self, blocking=True): self.is_cached = False - self._jschema_rdd.unpersist() + self._jschema_rdd.unpersist(blocking) return self def checkpoint(self): From 1a8043739dc1d9435def6ea3c6341498ba52b708 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Sat, 2 Aug 2014 18:27:04 -0700 Subject: [PATCH 10/13] [SPARK-2739][SQL] Rename registerAsTable to registerTempTable There have been user complaints that the difference between `registerAsTable` and `saveAsTable` is too subtle. This PR addresses this by renaming `registerAsTable` to `registerTempTable`, which more clearly reflects what is happening. `registerAsTable` remains, but will cause a deprecation warning. Author: Michael Armbrust Closes #1743 from marmbrus/registerTempTable and squashes the following commits: d031348 [Michael Armbrust] Merge remote-tracking branch 'apache/master' into registerTempTable 4dff086 [Michael Armbrust] Fix .java files too 89a2f12 [Michael Armbrust] Merge remote-tracking branch 'apache/master' into registerTempTable 0b7b71e [Michael Armbrust] Rename registerAsTable to registerTempTable --- .../sbt_app_sql/src/main/scala/SqlApp.scala | 2 +- docs/sql-programming-guide.md | 18 ++++++------ .../spark/examples/sql/JavaSparkSQL.java | 8 +++--- .../spark/examples/sql/RDDRelation.scala | 4 +-- .../examples/sql/hive/HiveFromSpark.scala | 2 +- python/pyspark/sql.py | 12 +++++--- .../org/apache/spark/sql/SQLContext.scala | 4 +-- .../org/apache/spark/sql/SchemaRDD.scala | 2 +- .../org/apache/spark/sql/SchemaRDDLike.scala | 5 +++- .../spark/sql/api/java/JavaSQLContext.scala | 2 +- .../sql/api/java/JavaApplySchemaSuite.java | 6 ++-- .../apache/spark/sql/CachedTableSuite.scala | 2 +- .../apache/spark/sql/InsertIntoSuite.scala | 4 +-- .../org/apache/spark/sql/JoinSuite.scala | 4 +-- .../org/apache/spark/sql/SQLQuerySuite.scala | 6 ++-- .../sql/ScalaReflectionRelationSuite.scala | 8 +++--- .../scala/org/apache/spark/sql/TestData.scala | 28 +++++++++---------- .../spark/sql/api/java/JavaSQLSuite.scala | 10 +++---- .../org/apache/spark/sql/json/JsonSuite.scala | 22 +++++++-------- .../spark/sql/parquet/ParquetQuerySuite.scala | 26 ++++++++--------- .../sql/hive/InsertIntoHiveTableSuite.scala | 2 +- .../sql/hive/api/java/JavaHiveQLSuite.scala | 4 +-- .../sql/hive/execution/HiveQuerySuite.scala | 6 ++-- .../hive/execution/HiveResolutionSuite.scala | 4 +-- .../spark/sql/parquet/HiveParquetSuite.scala | 8 +++--- 25 files changed, 103 insertions(+), 96 deletions(-) diff --git a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala index 50af90c213b5a..d888de929fdda 100644 --- a/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala +++ b/dev/audit-release/sbt_app_sql/src/main/scala/SqlApp.scala @@ -38,7 +38,7 @@ object SparkSqlExample { import sqlContext._ val people = sc.makeRDD(1 to 100, 10).map(x => Person(s"Name$x", x)) - people.registerAsTable("people") + people.registerTempTable("people") val teenagers = sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") val teenagerNames = teenagers.map(t => "Name: " + t(0)).collect() teenagerNames.foreach(println) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 7261badd411a9..0465468084cee 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -142,7 +142,7 @@ case class Person(name: String, age: Int) // Create an RDD of Person objects and register it as a table. val people = sc.textFile("examples/src/main/resources/people.txt").map(_.split(",")).map(p => Person(p(0), p(1).trim.toInt)) -people.registerAsTable("people") +people.registerTempTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") @@ -210,7 +210,7 @@ JavaRDD people = sc.textFile("examples/src/main/resources/people.txt").m // Apply a schema to an RDD of JavaBeans and register it as a table. JavaSchemaRDD schemaPeople = sqlContext.applySchema(people, Person.class); -schemaPeople.registerAsTable("people"); +schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") @@ -248,7 +248,7 @@ people = parts.map(lambda p: {"name": p[0], "age": int(p[1])}) # In future versions of PySpark we would like to add support for registering RDDs with other # datatypes as tables schemaPeople = sqlContext.inferSchema(people) -schemaPeople.registerAsTable("people") +schemaPeople.registerTempTable("people") # SQL can be run over SchemaRDDs that have been registered as a table. teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") @@ -292,7 +292,7 @@ people.saveAsParquetFile("people.parquet") val parquetFile = sqlContext.parquetFile("people.parquet") //Parquet files can also be registered as tables and then used in SQL statements. -parquetFile.registerAsTable("parquetFile") +parquetFile.registerTempTable("parquetFile") val teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") teenagers.map(t => "Name: " + t(0)).collect().foreach(println) {% endhighlight %} @@ -314,7 +314,7 @@ schemaPeople.saveAsParquetFile("people.parquet"); JavaSchemaRDD parquetFile = sqlContext.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. -parquetFile.registerAsTable("parquetFile"); +parquetFile.registerTempTable("parquetFile"); JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); List teenagerNames = teenagers.map(new Function() { public String call(Row row) { @@ -340,7 +340,7 @@ schemaPeople.saveAsParquetFile("people.parquet") parquetFile = sqlContext.parquetFile("people.parquet") # Parquet files can also be registered as tables and then used in SQL statements. -parquetFile.registerAsTable("parquetFile"); +parquetFile.registerTempTable("parquetFile"); teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19") teenNames = teenagers.map(lambda p: "Name: " + p.name) for teenName in teenNames.collect(): @@ -378,7 +378,7 @@ people.printSchema() // |-- name: StringType // Register this SchemaRDD as a table. -people.registerAsTable("people") +people.registerTempTable("people") // SQL statements can be run by using the sql methods provided by sqlContext. val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") @@ -416,7 +416,7 @@ people.printSchema(); // |-- name: StringType // Register this JavaSchemaRDD as a table. -people.registerAsTable("people"); +people.registerTempTable("people"); // SQL statements can be run by using the sql methods provided by sqlContext. JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); @@ -455,7 +455,7 @@ people.printSchema() # |-- name: StringType # Register this SchemaRDD as a table. -people.registerAsTable("people") +people.registerTempTable("people") # SQL statements can be run by using the sql methods provided by sqlContext. teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19") diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index 607df3eddd550..898297dc658ba 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -74,7 +74,7 @@ public Person call(String line) throws Exception { // Apply a schema to an RDD of Java Beans and register it as a table. JavaSchemaRDD schemaPeople = sqlCtx.applySchema(people, Person.class); - schemaPeople.registerAsTable("people"); + schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. JavaSchemaRDD teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); @@ -100,7 +100,7 @@ public String call(Row row) { JavaSchemaRDD parquetFile = sqlCtx.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. - parquetFile.registerAsTable("parquetFile"); + parquetFile.registerTempTable("parquetFile"); JavaSchemaRDD teenagers2 = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); teenagerNames = teenagers2.map(new Function() { @@ -128,7 +128,7 @@ public String call(Row row) { // |-- name: StringType // Register this JavaSchemaRDD as a table. - peopleFromJsonFile.registerAsTable("people"); + peopleFromJsonFile.registerTempTable("people"); // SQL statements can be run by using the sql methods provided by sqlCtx. JavaSchemaRDD teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); @@ -158,7 +158,7 @@ public String call(Row row) { // | |-- state: StringType // |-- name: StringType - peopleFromJsonRDD.registerAsTable("people2"); + peopleFromJsonRDD.registerTempTable("people2"); JavaSchemaRDD peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); List nameAndCity = peopleWithCity.map(new Function() { diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index 63db688bfb8c0..d56d64c564200 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -36,7 +36,7 @@ object RDDRelation { val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) // Any RDD containing case classes can be registered as a table. The schema of the table is // automatically inferred using scala reflection. - rdd.registerAsTable("records") + rdd.registerTempTable("records") // Once tables have been registered, you can run SQL queries over them. println("Result of SELECT *:") @@ -66,7 +66,7 @@ object RDDRelation { parquetFile.where('key === 1).select('value as 'a).collect().foreach(println) // These files can also be registered as tables. - parquetFile.registerAsTable("parquetFile") + parquetFile.registerTempTable("parquetFile") sql("SELECT * FROM parquetFile").collect().foreach(println) } } diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index dc5290fb4f10e..12530c8490b09 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -56,7 +56,7 @@ object HiveFromSpark { // You can also register RDDs as temporary tables within a HiveContext. val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) - rdd.registerAsTable("records") + rdd.registerTempTable("records") // Queries can then join RDD data with data stored in Hive. println("Result of SELECT *:") diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 36e50e49c9a9c..42b738e112809 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -909,7 +909,7 @@ def __init__(self, sparkContext, sqlContext=None): ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), ... time=datetime(2014, 8, 1, 14, 1, 5))]) >>> srdd = sqlCtx.inferSchema(allTypes) - >>> srdd.registerAsTable("allTypes") + >>> srdd.registerTempTable("allTypes") >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 'from allTypes where b and i > 0').collect() [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] @@ -1486,19 +1486,23 @@ def saveAsParquetFile(self, path): """ self._jschema_rdd.saveAsParquetFile(path) - def registerAsTable(self, name): + def registerTempTable(self, name): """Registers this RDD as a temporary table using the given name. The lifetime of this temporary table is tied to the L{SQLContext} that was used to create this SchemaRDD. >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.registerAsTable("test") + >>> srdd.registerTempTable("test") >>> srdd2 = sqlCtx.sql("select * from test") >>> sorted(srdd.collect()) == sorted(srdd2.collect()) True """ - self._jschema_rdd.registerAsTable(name) + self._jschema_rdd.registerTempTable(name) + + def registerAsTable(self, name): + warnings.warn("Use registerTempTable instead of registerAsTable.", DeprecationWarning) + self.registerTempTable(name) def insertInto(self, tableName, overwrite=False): """Inserts the contents of this SchemaRDD into the specified table. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 33931e5d996f5..567f4dca991b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -116,7 +116,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * // |-- name: string (nullable = false) * // |-- age: integer (nullable = true) * - * peopleSchemaRDD.registerAsTable("people") + * peopleSchemaRDD.registerTempTable("people") * sqlContext.sql("select name from people").collect.foreach(println) * }}} * @@ -212,7 +212,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * import sqlContext._ * * case class Person(name: String, age: Int) - * createParquetFile[Person]("path/to/file.parquet").registerAsTable("people") + * createParquetFile[Person]("path/to/file.parquet").registerTempTable("people") * sql("INSERT INTO people SELECT 'michael', 29") * }}} * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index d34f62dc8865e..57df79321b35d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -67,7 +67,7 @@ import org.apache.spark.api.java.JavaRDD * val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) * // Any RDD containing case classes can be registered as a table. The schema of the table is * // automatically inferred using scala reflection. - * rdd.registerAsTable("records") + * rdd.registerTempTable("records") * * val results: SchemaRDD = sql("SELECT * FROM records") * }}} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index 6a20def475822..2f3033a5f94f0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -83,10 +83,13 @@ private[sql] trait SchemaRDDLike { * * @group schema */ - def registerAsTable(tableName: String): Unit = { + def registerTempTable(tableName: String): Unit = { sqlContext.registerRDDAsTable(baseSchemaRDD, tableName) } + @deprecated("Use registerTempTable instead of registerAsTable.", "1.1") + def registerAsTable(tableName: String): Unit = registerTempTable(tableName) + /** * :: Experimental :: * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index ae45193ed15d3..dbaa16e8b0c68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -52,7 +52,7 @@ class JavaSQLContext(val sqlContext: SQLContext) extends UDFRegistration { * {{{ * JavaSQLContext sqlCtx = new JavaSQLContext(...) * - * sqlCtx.createParquetFile(Person.class, "path/to/file.parquet").registerAsTable("people") + * sqlCtx.createParquetFile(Person.class, "path/to/file.parquet").registerTempTable("people") * sqlCtx.sql("INSERT INTO people SELECT 'michael', 29") * }}} * diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java index 3c92906d82864..33e5020bc636a 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java @@ -98,7 +98,7 @@ public Row call(Person person) throws Exception { StructType schema = DataType.createStructType(fields); JavaSchemaRDD schemaRDD = javaSqlCtx.applySchema(rowRDD, schema); - schemaRDD.registerAsTable("people"); + schemaRDD.registerTempTable("people"); List actual = javaSqlCtx.sql("SELECT * FROM people").collect(); List expected = new ArrayList(2); @@ -149,14 +149,14 @@ public void applySchemaToJSON() { JavaSchemaRDD schemaRDD1 = javaSqlCtx.jsonRDD(jsonRDD); StructType actualSchema1 = schemaRDD1.schema(); Assert.assertEquals(expectedSchema, actualSchema1); - schemaRDD1.registerAsTable("jsonTable1"); + schemaRDD1.registerTempTable("jsonTable1"); List actual1 = javaSqlCtx.sql("select * from jsonTable1").collect(); Assert.assertEquals(expectedResult, actual1); JavaSchemaRDD schemaRDD2 = javaSqlCtx.jsonRDD(jsonRDD, expectedSchema); StructType actualSchema2 = schemaRDD2.schema(); Assert.assertEquals(expectedSchema, actualSchema2); - schemaRDD1.registerAsTable("jsonTable2"); + schemaRDD1.registerTempTable("jsonTable2"); List actual2 = javaSqlCtx.sql("select * from jsonTable2").collect(); Assert.assertEquals(expectedResult, actual2); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index c3c0dcb1aa00b..fbf9bd9dbcdea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -78,7 +78,7 @@ class CachedTableSuite extends QueryTest { } test("SELECT Star Cached Table") { - TestSQLContext.sql("SELECT * FROM testData").registerAsTable("selectStar") + TestSQLContext.sql("SELECT * FROM testData").registerTempTable("selectStar") TestSQLContext.cacheTable("selectStar") TestSQLContext.sql("SELECT * FROM selectStar WHERE key = 1").collect() TestSQLContext.uncacheTable("selectStar") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala index 23a711d08c58b..c87d762751e6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/InsertIntoSuite.scala @@ -31,7 +31,7 @@ class InsertIntoSuite extends QueryTest { testFilePath.delete() testFilePath.deleteOnExit() val testFile = createParquetFile[TestData](testFilePath.getCanonicalPath) - testFile.registerAsTable("createAndInsertTest") + testFile.registerTempTable("createAndInsertTest") // Add some data. testData.insertInto("createAndInsertTest") @@ -86,7 +86,7 @@ class InsertIntoSuite extends QueryTest { testFilePath.delete() testFilePath.deleteOnExit() val testFile = createParquetFile[TestData](testFilePath.getCanonicalPath) - testFile.registerAsTable("createAndInsertSQLTest") + testFile.registerTempTable("createAndInsertSQLTest") sql("INSERT INTO createAndInsertSQLTest SELECT * FROM testData") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 2fc80588182d9..6c7697ece8c56 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -285,8 +285,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("full outer join") { - upperCaseData.where('N <= 4).registerAsTable("left") - upperCaseData.where('N >= 3).registerAsTable("right") + upperCaseData.where('N <= 4).registerTempTable("left") + upperCaseData.where('N >= 3).registerTempTable("right") val left = UnresolvedRelation(None, "left", None) val right = UnresolvedRelation(None, "right", None) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5c571d35d1bb9..9b2a36d33fca7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -461,7 +461,7 @@ class SQLQuerySuite extends QueryTest { } val schemaRDD1 = applySchema(rowRDD1, schema1) - schemaRDD1.registerAsTable("applySchema1") + schemaRDD1.registerTempTable("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), (1, "A1", true, null) :: @@ -491,7 +491,7 @@ class SQLQuerySuite extends QueryTest { } val schemaRDD2 = applySchema(rowRDD2, schema2) - schemaRDD2.registerAsTable("applySchema2") + schemaRDD2.registerTempTable("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), (Seq(1, true), Map("A1" -> null)) :: @@ -516,7 +516,7 @@ class SQLQuerySuite extends QueryTest { } val schemaRDD3 = applySchema(rowRDD3, schema2) - schemaRDD3.registerAsTable("applySchema3") + schemaRDD3.registerTempTable("applySchema3") checkAnswer( sql("SELECT f1.f11, f2['D4'] FROM applySchema3"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index f2934da9a031d..5b84c658db942 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -61,7 +61,7 @@ class ScalaReflectionRelationSuite extends FunSuite { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, BigDecimal(1), new Timestamp(12345), Seq(1,2,3)) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerAsTable("reflectData") + rdd.registerTempTable("reflectData") assert(sql("SELECT * FROM reflectData").collect().head === data.productIterator.toSeq) } @@ -69,7 +69,7 @@ class ScalaReflectionRelationSuite extends FunSuite { test("query case class RDD with nulls") { val data = NullReflectData(null, null, null, null, null, null, null) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerAsTable("reflectNullData") + rdd.registerTempTable("reflectNullData") assert(sql("SELECT * FROM reflectNullData").collect().head === Seq.fill(7)(null)) } @@ -77,7 +77,7 @@ class ScalaReflectionRelationSuite extends FunSuite { test("query case class RDD with Nones") { val data = OptionalReflectData(None, None, None, None, None, None, None) val rdd = sparkContext.parallelize(data :: Nil) - rdd.registerAsTable("reflectOptionalData") + rdd.registerTempTable("reflectOptionalData") assert(sql("SELECT * FROM reflectOptionalData").collect().head === Seq.fill(7)(null)) } @@ -85,7 +85,7 @@ class ScalaReflectionRelationSuite extends FunSuite { // Equality is broken for Arrays, so we test that separately. test("query binary data") { val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil) - rdd.registerAsTable("reflectBinary") + rdd.registerTempTable("reflectBinary") val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 088e6e3c843aa..c3ec82fb69778 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -30,7 +30,7 @@ case class TestData(key: Int, value: String) object TestData { val testData: SchemaRDD = TestSQLContext.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))) - testData.registerAsTable("testData") + testData.registerTempTable("testData") case class LargeAndSmallInts(a: Int, b: Int) val largeAndSmallInts: SchemaRDD = @@ -41,7 +41,7 @@ object TestData { LargeAndSmallInts(2, 2) :: LargeAndSmallInts(2147483646, 1) :: LargeAndSmallInts(3, 2) :: Nil) - largeAndSmallInts.registerAsTable("largeAndSmallInts") + largeAndSmallInts.registerTempTable("largeAndSmallInts") case class TestData2(a: Int, b: Int) val testData2: SchemaRDD = @@ -52,7 +52,7 @@ object TestData { TestData2(2, 2) :: TestData2(3, 1) :: TestData2(3, 2) :: Nil) - testData2.registerAsTable("testData2") + testData2.registerTempTable("testData2") // TODO: There is no way to express null primitives as case classes currently... val testData3 = @@ -71,7 +71,7 @@ object TestData { UpperCaseData(4, "D") :: UpperCaseData(5, "E") :: UpperCaseData(6, "F") :: Nil) - upperCaseData.registerAsTable("upperCaseData") + upperCaseData.registerTempTable("upperCaseData") case class LowerCaseData(n: Int, l: String) val lowerCaseData = @@ -80,14 +80,14 @@ object TestData { LowerCaseData(2, "b") :: LowerCaseData(3, "c") :: LowerCaseData(4, "d") :: Nil) - lowerCaseData.registerAsTable("lowerCaseData") + lowerCaseData.registerTempTable("lowerCaseData") case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) val arrayData = TestSQLContext.sparkContext.parallelize( ArrayData(Seq(1,2,3), Seq(Seq(1,2,3))) :: ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil) - arrayData.registerAsTable("arrayData") + arrayData.registerTempTable("arrayData") case class MapData(data: Map[Int, String]) val mapData = @@ -97,18 +97,18 @@ object TestData { MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: MapData(Map(1 -> "a4", 2 -> "b4")) :: MapData(Map(1 -> "a5")) :: Nil) - mapData.registerAsTable("mapData") + mapData.registerTempTable("mapData") case class StringData(s: String) val repeatedData = TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) - repeatedData.registerAsTable("repeatedData") + repeatedData.registerTempTable("repeatedData") val nullableRepeatedData = TestSQLContext.sparkContext.parallelize( List.fill(2)(StringData(null)) ++ List.fill(2)(StringData("test"))) - nullableRepeatedData.registerAsTable("nullableRepeatedData") + nullableRepeatedData.registerTempTable("nullableRepeatedData") case class NullInts(a: Integer) val nullInts = @@ -118,7 +118,7 @@ object TestData { NullInts(3) :: NullInts(null) :: Nil ) - nullInts.registerAsTable("nullInts") + nullInts.registerTempTable("nullInts") val allNulls = TestSQLContext.sparkContext.parallelize( @@ -126,7 +126,7 @@ object TestData { NullInts(null) :: NullInts(null) :: NullInts(null) :: Nil) - allNulls.registerAsTable("allNulls") + allNulls.registerTempTable("allNulls") case class NullStrings(n: Int, s: String) val nullStrings = @@ -134,10 +134,10 @@ object TestData { NullStrings(1, "abc") :: NullStrings(2, "ABC") :: NullStrings(3, null) :: Nil) - nullStrings.registerAsTable("nullStrings") + nullStrings.registerTempTable("nullStrings") case class TableName(tableName: String) - TestSQLContext.sparkContext.parallelize(TableName("test") :: Nil).registerAsTable("tableName") + TestSQLContext.sparkContext.parallelize(TableName("test") :: Nil).registerTempTable("tableName") val unparsedStrings = TestSQLContext.sparkContext.parallelize( @@ -150,5 +150,5 @@ object TestData { val timestamps = TestSQLContext.sparkContext.parallelize((1 to 3).map { i => TimestampField(new Timestamp(i)) }) - timestamps.registerAsTable("timestamps") + timestamps.registerTempTable("timestamps") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala index 020baf0c7ec6f..203ff847e94cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala @@ -59,7 +59,7 @@ class JavaSQLSuite extends FunSuite { val rdd = javaCtx.parallelize(person :: Nil) val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[PersonBean]) - schemaRDD.registerAsTable("people") + schemaRDD.registerTempTable("people") javaSqlCtx.sql("SELECT * FROM people").collect() } @@ -76,7 +76,7 @@ class JavaSQLSuite extends FunSuite { val rdd = javaCtx.parallelize(bean :: Nil) val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean]) - schemaRDD.registerAsTable("allTypes") + schemaRDD.registerTempTable("allTypes") assert( javaSqlCtx.sql( @@ -101,7 +101,7 @@ class JavaSQLSuite extends FunSuite { val rdd = javaCtx.parallelize(bean :: Nil) val schemaRDD = javaSqlCtx.applySchema(rdd, classOf[AllTypesBean]) - schemaRDD.registerAsTable("allTypes") + schemaRDD.registerTempTable("allTypes") assert( javaSqlCtx.sql( @@ -127,7 +127,7 @@ class JavaSQLSuite extends FunSuite { var schemaRDD = javaSqlCtx.jsonRDD(rdd) - schemaRDD.registerAsTable("jsonTable1") + schemaRDD.registerTempTable("jsonTable1") assert( javaSqlCtx.sql("select * from jsonTable1").collect.head.row === @@ -144,7 +144,7 @@ class JavaSQLSuite extends FunSuite { rdd.saveAsTextFile(path) schemaRDD = javaSqlCtx.jsonFile(path) - schemaRDD.registerAsTable("jsonTable2") + schemaRDD.registerTempTable("jsonTable2") assert( javaSqlCtx.sql("select * from jsonTable2").collect.head.row === diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 9d9cfdd7c92e3..75c0589eb208e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -183,7 +183,7 @@ class JsonSuite extends QueryTest { assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -223,7 +223,7 @@ class JsonSuite extends QueryTest { assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") // Access elements of a primitive array. checkAnswer( @@ -291,7 +291,7 @@ class JsonSuite extends QueryTest { ignore("Complex field and type inferring (Ignored)") { val jsonSchemaRDD = jsonRDD(complexFieldAndType) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") // Right now, "field1" and "field2" are treated as aliases. We should fix it. checkAnswer( @@ -320,7 +320,7 @@ class JsonSuite extends QueryTest { assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -374,7 +374,7 @@ class JsonSuite extends QueryTest { ignore("Type conflict in primitive field values (Ignored)") { val jsonSchemaRDD = jsonRDD(primitiveFieldValueTypeConflict) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") // Right now, the analyzer does not promote strings in a boolean expreesion. // Number and Boolean conflict: resolve the type as boolean in this query. @@ -445,7 +445,7 @@ class JsonSuite extends QueryTest { assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -466,7 +466,7 @@ class JsonSuite extends QueryTest { assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -494,7 +494,7 @@ class JsonSuite extends QueryTest { assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") } test("Loading a JSON dataset from a text file") { @@ -514,7 +514,7 @@ class JsonSuite extends QueryTest { assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerAsTable("jsonTable") + jsonSchemaRDD.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -546,7 +546,7 @@ class JsonSuite extends QueryTest { assert(schema === jsonSchemaRDD1.schema) - jsonSchemaRDD1.registerAsTable("jsonTable1") + jsonSchemaRDD1.registerTempTable("jsonTable1") checkAnswer( sql("select * from jsonTable1"), @@ -563,7 +563,7 @@ class JsonSuite extends QueryTest { assert(schema === jsonSchemaRDD2.schema) - jsonSchemaRDD2.registerAsTable("jsonTable2") + jsonSchemaRDD2.registerTempTable("jsonTable2") checkAnswer( sql("select * from jsonTable2"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 8955455ec98c7..9933575038bd3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -101,9 +101,9 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA ParquetTestData.writeNestedFile3() ParquetTestData.writeNestedFile4() testRDD = parquetFile(ParquetTestData.testDir.toString) - testRDD.registerAsTable("testsource") + testRDD.registerTempTable("testsource") parquetFile(ParquetTestData.testFilterDir.toString) - .registerAsTable("testfiltersource") + .registerTempTable("testfiltersource") } override def afterAll() { @@ -247,7 +247,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA test("Creating case class RDD table") { TestSQLContext.sparkContext.parallelize((1 to 100)) .map(i => TestRDDEntry(i, s"val_$i")) - .registerAsTable("tmp") + .registerTempTable("tmp") val rdd = sql("SELECT * FROM tmp").collect().sortBy(_.getInt(0)) var counter = 1 rdd.foreach { @@ -266,7 +266,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA .map(i => TestRDDEntry(i, s"val_$i")) rdd.saveAsParquetFile(path) val readFile = parquetFile(path) - readFile.registerAsTable("tmpx") + readFile.registerTempTable("tmpx") val rdd_copy = sql("SELECT * FROM tmpx").collect() val rdd_orig = rdd.collect() for(i <- 0 to 99) { @@ -280,9 +280,9 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val dirname = Utils.createTempDir() val source_rdd = TestSQLContext.sparkContext.parallelize((1 to 100)) .map(i => TestRDDEntry(i, s"val_$i")) - source_rdd.registerAsTable("source") + source_rdd.registerTempTable("source") val dest_rdd = createParquetFile[TestRDDEntry](dirname.toString) - dest_rdd.registerAsTable("dest") + dest_rdd.registerTempTable("dest") sql("INSERT OVERWRITE INTO dest SELECT * FROM source").collect() val rdd_copy1 = sql("SELECT * FROM dest").collect() assert(rdd_copy1.size === 100) @@ -547,7 +547,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val data = nestedParserSqlContext .parquetFile(ParquetTestData.testNestedDir1.toString) .toSchemaRDD - data.registerAsTable("data") + data.registerTempTable("data") val query = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM data") val tmp = query.collect() assert(tmp.size === 2) @@ -562,7 +562,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val data = nestedParserSqlContext .parquetFile(ParquetTestData.testNestedDir2.toString) .toSchemaRDD - data.registerAsTable("data") + data.registerTempTable("data") val result1 = nestedParserSqlContext.sql("SELECT entries[0].value FROM data").collect() assert(result1.size === 1) assert(result1(0).size === 1) @@ -589,7 +589,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val data = nestedParserSqlContext .parquetFile(ParquetTestData.testNestedDir3.toString) .toSchemaRDD - data.registerAsTable("data") + data.registerTempTable("data") val result1 = nestedParserSqlContext.sql("SELECT booleanNumberPairs[0].value[0].truth FROM data").collect() assert(result1.size === 1) assert(result1(0).size === 1) @@ -608,7 +608,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val data = TestSQLContext .parquetFile(ParquetTestData.testNestedDir4.toString) .toSchemaRDD - data.registerAsTable("mapTable") + data.registerTempTable("mapTable") val result1 = sql("SELECT data1 FROM mapTable").collect() assert(result1.size === 1) assert(result1(0)(0) @@ -625,7 +625,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA val data = nestedParserSqlContext .parquetFile(ParquetTestData.testNestedDir4.toString) .toSchemaRDD - data.registerAsTable("mapTable") + data.registerTempTable("mapTable") val result1 = nestedParserSqlContext.sql("SELECT data2 FROM mapTable").collect() assert(result1.size === 1) val entry1 = result1(0)(0) @@ -658,7 +658,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA nestedParserSqlContext .parquetFile(tmpdir.toString) .toSchemaRDD - .registerAsTable("tmpcopy") + .registerTempTable("tmpcopy") val tmpdata = nestedParserSqlContext.sql("SELECT owner, contacts[1].name FROM tmpcopy").collect() assert(tmpdata.size === 2) assert(tmpdata(0).size === 2) @@ -679,7 +679,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA nestedParserSqlContext .parquetFile(tmpdir.toString) .toSchemaRDD - .registerAsTable("tmpmapcopy") + .registerTempTable("tmpmapcopy") val result1 = nestedParserSqlContext.sql("""SELECT data1["key2"] FROM tmpmapcopy""").collect() assert(result1.size === 1) assert(result1(0)(0) === 2) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 833f3502154f3..7e323146f9da2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -28,7 +28,7 @@ case class TestData(key: Int, value: String) class InsertIntoHiveTableSuite extends QueryTest { val testData = TestHive.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))) - testData.registerAsTable("testData") + testData.registerTempTable("testData") test("insertInto() HiveTable") { createTable[TestData]("createAndInsertTest") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala index 10c8069a624e6..578f27574ad2f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/api/java/JavaHiveQLSuite.scala @@ -63,7 +63,7 @@ class JavaHiveQLSuite extends FunSuite { javaHiveCtx.hql(s"CREATE TABLE $tableName(key INT, value STRING)").count() } - javaHiveCtx.hql("SHOW TABLES").registerAsTable("show_tables") + javaHiveCtx.hql("SHOW TABLES").registerTempTable("show_tables") assert( javaHiveCtx @@ -73,7 +73,7 @@ class JavaHiveQLSuite extends FunSuite { .contains(tableName)) assertResult(Array(Array("key", "int", "None"), Array("value", "string", "None"))) { - javaHiveCtx.hql(s"DESCRIBE $tableName").registerAsTable("describe_table") + javaHiveCtx.hql(s"DESCRIBE $tableName").registerTempTable("describe_table") javaHiveCtx .hql("SELECT result FROM describe_table") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 89cc589fb8001..4ed41550cf530 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -247,7 +247,7 @@ class HiveQuerySuite extends HiveComparisonTest { TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(2, "str2") :: Nil) - testData.registerAsTable("REGisteredTABle") + testData.registerTempTable("REGisteredTABle") assertResult(Array(Array(2, "str2"))) { hql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " + @@ -272,7 +272,7 @@ class HiveQuerySuite extends HiveComparisonTest { test("SPARK-2180: HAVING support in GROUP BY clauses (positive)") { val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3)) .zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)} - TestHive.sparkContext.parallelize(fixture).registerAsTable("having_test") + TestHive.sparkContext.parallelize(fixture).registerTempTable("having_test") val results = hql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3") .collect() @@ -401,7 +401,7 @@ class HiveQuerySuite extends HiveComparisonTest { TestHive.sparkContext.parallelize( TestData(1, "str1") :: TestData(1, "str2") :: Nil) - testData.registerAsTable("test_describe_commands2") + testData.registerTempTable("test_describe_commands2") assertResult( Array( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index fb03db12a0b01..2455c18925dfa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -54,14 +54,14 @@ class HiveResolutionSuite extends HiveComparisonTest { test("case insensitivity with scala reflection") { // Test resolution with Scala Reflection TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) - .registerAsTable("caseSensitivityTest") + .registerTempTable("caseSensitivityTest") hql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest") } test("nested repeated resolution") { TestHive.sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil) - .registerAsTable("nestedRepeatedTest") + .registerTempTable("nestedRepeatedTest") assert(hql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala index 47526e3596e44..6545e8d7dcb69 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/HiveParquetSuite.scala @@ -41,7 +41,7 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft // write test data ParquetTestData.writeFile() testRDD = parquetFile(ParquetTestData.testDir.toString) - testRDD.registerAsTable("testsource") + testRDD.registerTempTable("testsource") } override def afterAll() { @@ -67,7 +67,7 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft .map(i => Cases(i, i)) .saveAsParquetFile(tempFile.getCanonicalPath) - parquetFile(tempFile.getCanonicalPath).registerAsTable("cases") + parquetFile(tempFile.getCanonicalPath).registerTempTable("cases") hql("SELECT upper FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString) hql("SELECT LOWER FROM cases").collect().map(_.getString(0)) === (1 to 10).map(_.toString) } @@ -86,7 +86,7 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft test("Converting Hive to Parquet Table via saveAsParquetFile") { hql("SELECT * FROM src").saveAsParquetFile(dirname.getAbsolutePath) - parquetFile(dirname.getAbsolutePath).registerAsTable("ptable") + parquetFile(dirname.getAbsolutePath).registerTempTable("ptable") val rddOne = hql("SELECT * FROM src").collect().sortBy(_.getInt(0)) val rddTwo = hql("SELECT * from ptable").collect().sortBy(_.getInt(0)) compareRDDs(rddOne, rddTwo, "src (Hive)", Seq("key:Int", "value:String")) @@ -94,7 +94,7 @@ class HiveParquetSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAft test("INSERT OVERWRITE TABLE Parquet table") { hql("SELECT * FROM testsource").saveAsParquetFile(dirname.getAbsolutePath) - parquetFile(dirname.getAbsolutePath).registerAsTable("ptable") + parquetFile(dirname.getAbsolutePath).registerTempTable("ptable") // let's do three overwrites for good measure hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() hql("INSERT OVERWRITE TABLE ptable SELECT * FROM testsource").collect() From 33f167d762483b55d5d874dcc1e3075f661d4375 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 2 Aug 2014 21:44:19 -0700 Subject: [PATCH 11/13] SPARK-2602 [BUILD] Tests steal focus under Java 6 As per https://issues.apache.org/jira/browse/SPARK-2602 , this may be resolved for Java 6 with the java.awt.headless system property, which never hurt anyone running a command line app. I tested it and seemed to get rid of focus stealing. Author: Sean Owen Closes #1747 from srowen/SPARK-2602 and squashes the following commits: b141018 [Sean Owen] Set java.awt.headless during tests --- pom.xml | 1 + 1 file changed, 1 insertion(+) diff --git a/pom.xml b/pom.xml index a42759169149b..cc9377cec2a07 100644 --- a/pom.xml +++ b/pom.xml @@ -871,6 +871,7 @@ -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m + true ${session.executionRootDirectory} 1 From 9cf429aaf529e91f619910c33cfe46bf33a66982 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Sat, 2 Aug 2014 21:55:56 -0700 Subject: [PATCH 12/13] SPARK-2414 [BUILD] Add LICENSE entry for jquery The JIRA concerned removing jquery, and this does not remove jquery. While it is distributed by Spark it should have an accompanying line in LICENSE, very technically, as per http://www.apache.org/dev/licensing-howto.html Author: Sean Owen Closes #1748 from srowen/SPARK-2414 and squashes the following commits: 2fdb03c [Sean Owen] Add LICENSE entry for jquery --- LICENSE | 1 + 1 file changed, 1 insertion(+) diff --git a/LICENSE b/LICENSE index 76a3601c66918..e9a1153fdc5db 100644 --- a/LICENSE +++ b/LICENSE @@ -549,3 +549,4 @@ The following components are provided under the MIT License. See project link fo (MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/) (MIT License) scopt (com.github.scopt:scopt_2.10:3.2.0 - https://github.com/scopt/scopt) (The MIT License) Mockito (org.mockito:mockito-all:1.8.5 - http://www.mockito.org) + (MIT License) jquery (https://jquery.org/license/) From 3dc55fdf450b4237f7c592fce56d1467fd206366 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Sat, 2 Aug 2014 22:00:46 -0700 Subject: [PATCH 13/13] [Minor] Fixes on top of #1679 Minor fixes on top of #1679. Author: Andrew Or Closes #1736 from andrewor14/amend-#1679 and squashes the following commits: 3b46f5e [Andrew Or] Minor fixes --- .../org/apache/spark/storage/BlockManagerSource.scala | 5 ++--- .../scala/org/apache/spark/storage/StorageUtils.scala | 11 ++++------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala index e939318a029dd..3f14c40ec61cb 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSource.scala @@ -46,9 +46,8 @@ private[spark] class BlockManagerSource(val blockManager: BlockManager, sc: Spar metricRegistry.register(MetricRegistry.name("memory", "memUsed_MB"), new Gauge[Long] { override def getValue: Long = { val storageStatusList = blockManager.master.getStorageStatus - val maxMem = storageStatusList.map(_.maxMem).sum - val remainingMem = storageStatusList.map(_.memRemaining).sum - (maxMem - remainingMem) / 1024 / 1024 + val memUsed = storageStatusList.map(_.memUsed).sum + memUsed / 1024 / 1024 } }) diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 0a0a448baa2ef..2bd6b749be261 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -172,16 +172,13 @@ class StorageStatus(val blockManagerId: BlockManagerId, val maxMem: Long) { def memRemaining: Long = maxMem - memUsed /** Return the memory used by this block manager. */ - def memUsed: Long = - _nonRddStorageInfo._1 + _rddBlocks.keys.toSeq.map(memUsedByRdd).sum + def memUsed: Long = _nonRddStorageInfo._1 + _rddBlocks.keys.toSeq.map(memUsedByRdd).sum /** Return the disk space used by this block manager. */ - def diskUsed: Long = - _nonRddStorageInfo._2 + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum + def diskUsed: Long = _nonRddStorageInfo._2 + _rddBlocks.keys.toSeq.map(diskUsedByRdd).sum /** Return the off-heap space used by this block manager. */ - def offHeapUsed: Long = - _nonRddStorageInfo._3 + _rddBlocks.keys.toSeq.map(offHeapUsedByRdd).sum + def offHeapUsed: Long = _nonRddStorageInfo._3 + _rddBlocks.keys.toSeq.map(offHeapUsedByRdd).sum /** Return the memory used by the given RDD in this block manager in O(1) time. */ def memUsedByRdd(rddId: Int): Long = _rddStorageInfo.get(rddId).map(_._1).getOrElse(0L) @@ -246,7 +243,7 @@ private[spark] object StorageUtils { val rddId = rddInfo.id // Assume all blocks belonging to the same RDD have the same storage level val storageLevel = statuses - .map(_.rddStorageLevel(rddId)).flatMap(s => s).headOption.getOrElse(StorageLevel.NONE) + .flatMap(_.rddStorageLevel(rddId)).headOption.getOrElse(StorageLevel.NONE) val numCachedPartitions = statuses.map(_.numRddBlocksById(rddId)).sum val memSize = statuses.map(_.memUsedByRdd(rddId)).sum val diskSize = statuses.map(_.diskUsedByRdd(rddId)).sum