From 28ccf5ee769a1df019e38985112065c01724fbd9 Mon Sep 17 00:00:00 2001 From: Alexander Ulanov Date: Mon, 23 Feb 2015 12:09:40 -0800 Subject: [PATCH 01/28] [MLLIB] SPARK-5912 Programming guide for feature selection Added description of ChiSqSelector and few words about feature selection in general. I could add a code example, however it would not look reasonable in the absence of feature discretizer or a dataset in the `data` folder that has redundant features. Author: Alexander Ulanov Closes #4709 from avulanov/SPARK-5912 and squashes the following commits: 19a8a4e [Alexander Ulanov] Addressing reviewers comments @jkbradley 58d9e4d [Alexander Ulanov] Addressing reviewers comments @jkbradley eb6b9fe [Alexander Ulanov] Typo 2921a1d [Alexander Ulanov] ChiSqSelector example of use c845350 [Alexander Ulanov] ChiSqSelector docs --- docs/mllib-feature-extraction.md | 54 ++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index d4a61a7fbf3d7..d588b9cb46697 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -375,3 +375,57 @@ data2 = labels.zip(normalizer2.transform(features)) {% endhighlight %} + +## Feature selection +[Feature selection](http://en.wikipedia.org/wiki/Feature_selection) allows selecting the most relevant features for use in model construction. The number of features to select can be determined using the validation set. Feature selection is usually applied on sparse data, for example in text classification. Feature selection reduces the size of the vector space and, in turn, the complexity of any subsequent operation with vectors. + +### ChiSqSelector +ChiSqSelector stands for Chi-Squared feature selection. It operates on the labeled data. ChiSqSelector orders categorical features based on their values of Chi-Squared test on independence from class and filters (selects) top given features. + +#### Model Fitting + +[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) has the +following parameters in the constructor: + +* `numTopFeatures` number of top features that selector will select (filter). + +We provide a [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method in +`ChiSqSelector` which can take an input of `RDD[LabeledPoint]` with categorical features, learn the summary statistics, and then +return a model which can transform the input dataset into the reduced feature space. + +This model implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) +which can apply the Chi-Squared feature selection on a `Vector` to produce a reduced `Vector` or on +an `RDD[Vector]` to produce a reduced `RDD[Vector]`. + +Note that the model that performs actual feature filtering can be instantiated independently with array of feature indices that has to be sorted ascending. + +#### Example + +The following example shows the basic use of ChiSqSelector. + +
+
+{% highlight scala %} +import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLUtils + +// load some data in libsvm format, each point is in the range 0..255 +val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") +// discretize data in 16 equal bins +val discretizedData = data.map { lp => + LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => x / 16 } ) ) +} +// create ChiSqSelector that will select 50 features +val selector = new ChiSqSelector(50) +// create ChiSqSelector model +val transformer = selector.fit(disctetizedData) +// filter top 50 features from each feature vector +val filteredData = disctetizedData.map { lp => + LabeledPoint(lp.label, transformer.transform(lp.features)) +} +{% endhighlight %} +
+
+ From 59536cc87e10e5011560556729dd901280958f43 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Mon, 23 Feb 2015 16:15:57 -0800 Subject: [PATCH 02/28] [SPARK-5912] [docs] [mllib] Small fixes to ChiSqSelector docs Fixes: * typo in Scala example * Removed comment "usually applied on sparse data" since that is debatable * small edits to text for clarity CC: avulanov I noticed a typo post-hoc and ended up making a few small edits. Do the changes look OK? Author: Joseph K. Bradley Closes #4732 from jkbradley/chisqselector-docs and squashes the following commits: 9656a3b [Joseph K. Bradley] added Java example for ChiSqSelector to guide 3f3f9f4 [Joseph K. Bradley] small fixes to ChiSqSelector docs --- docs/mllib-feature-extraction.md | 72 ++++++++++++++++++++++++++------ 1 file changed, 60 insertions(+), 12 deletions(-) diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index d588b9cb46697..80842b27effd8 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -377,27 +377,27 @@ data2 = labels.zip(normalizer2.transform(features)) ## Feature selection -[Feature selection](http://en.wikipedia.org/wiki/Feature_selection) allows selecting the most relevant features for use in model construction. The number of features to select can be determined using the validation set. Feature selection is usually applied on sparse data, for example in text classification. Feature selection reduces the size of the vector space and, in turn, the complexity of any subsequent operation with vectors. +[Feature selection](http://en.wikipedia.org/wiki/Feature_selection) allows selecting the most relevant features for use in model construction. Feature selection reduces the size of the vector space and, in turn, the complexity of any subsequent operation with vectors. The number of features to select can be tuned using a held-out validation set. ### ChiSqSelector -ChiSqSelector stands for Chi-Squared feature selection. It operates on the labeled data. ChiSqSelector orders categorical features based on their values of Chi-Squared test on independence from class and filters (selects) top given features. +[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) stands for Chi-Squared feature selection. It operates on labeled data with categorical features. `ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, and then filters (selects) the top features which are most closely related to the label. #### Model Fitting [`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) has the following parameters in the constructor: -* `numTopFeatures` number of top features that selector will select (filter). +* `numTopFeatures` number of top features that the selector will select (filter). We provide a [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method in `ChiSqSelector` which can take an input of `RDD[LabeledPoint]` with categorical features, learn the summary statistics, and then -return a model which can transform the input dataset into the reduced feature space. +return a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space. This model implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) which can apply the Chi-Squared feature selection on a `Vector` to produce a reduced `Vector` or on an `RDD[Vector]` to produce a reduced `RDD[Vector]`. -Note that the model that performs actual feature filtering can be instantiated independently with array of feature indices that has to be sorted ascending. +Note that the user can also construct a `ChiSqSelectorModel` by hand by providing an array of selected feature indices (which must be sorted in ascending order). #### Example @@ -411,21 +411,69 @@ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils -// load some data in libsvm format, each point is in the range 0..255 +// Load some data in libsvm format val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") -// discretize data in 16 equal bins +// Discretize data in 16 equal bins since ChiSqSelector requires categorical features val discretizedData = data.map { lp => LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => x / 16 } ) ) } -// create ChiSqSelector that will select 50 features +// Create ChiSqSelector that will select 50 features val selector = new ChiSqSelector(50) -// create ChiSqSelector model -val transformer = selector.fit(disctetizedData) -// filter top 50 features from each feature vector -val filteredData = disctetizedData.map { lp => +// Create ChiSqSelector model (selecting features) +val transformer = selector.fit(discretizedData) +// Filter the top 50 features from each feature vector +val filteredData = discretizedData.map { lp => LabeledPoint(lp.label, transformer.transform(lp.features)) } {% endhighlight %} + +
+{% highlight java %} +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.feature.ChiSqSelector; +import org.apache.spark.mllib.feature.ChiSqSelectorModel; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.mllib.regression.LabeledPoint; +import org.apache.spark.mllib.util.MLUtils; + +SparkConf sparkConf = new SparkConf().setAppName("JavaChiSqSelector"); +JavaSparkContext sc = new JavaSparkContext(sparkConf); +JavaRDD points = MLUtils.loadLibSVMFile(sc.sc(), + "data/mllib/sample_libsvm_data.txt").toJavaRDD().cache(); + +// Discretize data in 16 equal bins since ChiSqSelector requires categorical features +JavaRDD discretizedData = points.map( + new Function() { + @Override + public LabeledPoint call(LabeledPoint lp) { + final double[] discretizedFeatures = new double[lp.features().size()]; + for (int i = 0; i < lp.features().size(); ++i) { + discretizedFeatures[i] = lp.features().apply(i) / 16; + } + return new LabeledPoint(lp.label(), Vectors.dense(discretizedFeatures)); + } + }); + +// Create ChiSqSelector that will select 50 features +ChiSqSelector selector = new ChiSqSelector(50); +// Create ChiSqSelector model (selecting features) +final ChiSqSelectorModel transformer = selector.fit(discretizedData.rdd()); +// Filter the top 50 features from each feature vector +JavaRDD filteredData = discretizedData.map( + new Function() { + @Override + public LabeledPoint call(LabeledPoint lp) { + return new LabeledPoint(lp.label(), transformer.transform(lp.features())); + } + } +); + +sc.stop(); +{% endhighlight %} +
From 48376bfe9c97bf31279918def6c6615849c88f4d Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 23 Feb 2015 17:16:34 -0800 Subject: [PATCH 03/28] [SPARK-5935][SQL] Accept MapType in the schema provided to a JSON dataset. JIRA: https://issues.apache.org/jira/browse/SPARK-5935 Author: Yin Huai Author: Yin Huai Closes #4710 from yhuai/jsonMapType and squashes the following commits: 3e40390 [Yin Huai] Remove unnecessary changes. f8e6267 [Yin Huai] Fix test. baa36e3 [Yin Huai] Accept MapType in the schema provided to jsonFile/jsonRDD. --- .../org/apache/spark/sql/json/JsonRDD.scala | 3 + .../org/apache/spark/sql/json/JsonSuite.scala | 56 +++++++++++++++++++ .../apache/spark/sql/json/TestJsonData.scala | 17 ++++++ 3 files changed, 76 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 3b8dde1823370..d83bdc2f7ff9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -416,6 +416,9 @@ private[sql] object JsonRDD extends Logging { case NullType => null case ArrayType(elementType, _) => value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) + case MapType(StringType, valueType, _) => + val map = value.asInstanceOf[Map[String, Any]] + map.mapValues(enforceCorrectType(_, valueType)).map(identity) case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct) case DateType => toDate(value) case TimestampType => toTimestamp(value) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index c94e44bd7c397..005f20b96df79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -657,6 +657,62 @@ class JsonSuite extends QueryTest { ) } + test("Applying schemas with MapType") { + val schemaWithSimpleMap = StructType( + StructField("map", MapType(StringType, IntegerType, true), false) :: Nil) + val jsonWithSimpleMap = jsonRDD(mapType1, schemaWithSimpleMap) + + jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap") + + checkAnswer( + sql("select map from jsonWithSimpleMap"), + Row(Map("a" -> 1)) :: + Row(Map("b" -> 2)) :: + Row(Map("c" -> 3)) :: + Row(Map("c" -> 1, "d" -> 4)) :: + Row(Map("e" -> null)) :: Nil + ) + + checkAnswer( + sql("select map['c'] from jsonWithSimpleMap"), + Row(null) :: + Row(null) :: + Row(3) :: + Row(1) :: + Row(null) :: Nil + ) + + val innerStruct = StructType( + StructField("field1", ArrayType(IntegerType, true), true) :: + StructField("field2", IntegerType, true) :: Nil) + val schemaWithComplexMap = StructType( + StructField("map", MapType(StringType, innerStruct, true), false) :: Nil) + + val jsonWithComplexMap = jsonRDD(mapType2, schemaWithComplexMap) + + jsonWithComplexMap.registerTempTable("jsonWithComplexMap") + + checkAnswer( + sql("select map from jsonWithComplexMap"), + Row(Map("a" -> Row(Seq(1, 2, 3, null), null))) :: + Row(Map("b" -> Row(null, 2))) :: + Row(Map("c" -> Row(Seq(), 4))) :: + Row(Map("c" -> Row(null, 3), "d" -> Row(Seq(null), null))) :: + Row(Map("e" -> null)) :: + Row(Map("f" -> Row(null, null))) :: Nil + ) + + checkAnswer( + sql("select map['a'].field1, map['c'].field2 from jsonWithComplexMap"), + Row(Seq(1, 2, 3, null), null) :: + Row(null, null) :: + Row(null, 4) :: + Row(null, 3) :: + Row(null, null) :: + Row(null, null) :: Nil + ) + } + test("SPARK-2096 Correctly parse dot notations") { val jsonDF = jsonRDD(complexFieldAndType2) jsonDF.registerTempTable("jsonTable") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala index 3370b3c98b4be..15698f61e0837 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala @@ -146,6 +146,23 @@ object TestJsonData { ]] }""" :: Nil) + val mapType1 = + TestSQLContext.sparkContext.parallelize( + """{"map": {"a": 1}}""" :: + """{"map": {"b": 2}}""" :: + """{"map": {"c": 3}}""" :: + """{"map": {"c": 1, "d": 4}}""" :: + """{"map": {"e": null}}""" :: Nil) + + val mapType2 = + TestSQLContext.sparkContext.parallelize( + """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" :: + """{"map": {"b": {"field2": 2}}}""" :: + """{"map": {"c": {"field1": [], "field2": 4}}}""" :: + """{"map": {"c": {"field2": 3}, "d": {"field1": [null]}}}""" :: + """{"map": {"e": null}}""" :: + """{"map": {"f": {"field1": null}}}""" :: Nil) + val nullsInArrays = TestSQLContext.sparkContext.parallelize( """{"field1":[[null], [[["Test"]]]]}""" :: From 1ed57086d402c38d95cda6c3d9d7aea806609bf9 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Mon, 23 Feb 2015 17:34:54 -0800 Subject: [PATCH 04/28] [SPARK-5873][SQL] Allow viewing of partially analyzed plans in queryExecution Author: Michael Armbrust Closes #4684 from marmbrus/explainAnalysis and squashes the following commits: afbaa19 [Michael Armbrust] fix python d93278c [Michael Armbrust] fix hive e5fa0a4 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into explainAnalysis 52119f2 [Michael Armbrust] more tests 82a5431 [Michael Armbrust] fix tests 25753d2 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into explainAnalysis aee1e6a [Michael Armbrust] fix hive b23a844 [Michael Armbrust] newline de8dc51 [Michael Armbrust] more comments acf620a [Michael Armbrust] [SPARK-5873][SQL] Show partially analyzed plans in query execution --- python/pyspark/sql/context.py | 30 ++--- .../apache/spark/sql/catalyst/SqlParser.scala | 2 + .../sql/catalyst/analysis/Analyzer.scala | 83 -------------- .../sql/catalyst/analysis/CheckAnalysis.scala | 105 ++++++++++++++++++ .../sql/catalyst/analysis/AnalysisSuite.scala | 35 +++--- .../org/apache/spark/sql/DataFrame.scala | 2 +- .../scala/org/apache/spark/sql/SQLConf.scala | 5 +- .../org/apache/spark/sql/SQLContext.scala | 14 ++- .../org/apache/spark/sql/sources/rules.scala | 10 +- .../spark/sql/sources/DataSourceTest.scala | 1 - .../spark/sql/sources/InsertSuite.scala | 2 +- .../apache/spark/sql/hive/HiveContext.scala | 1 - 12 files changed, 164 insertions(+), 126 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 313f15e6d9b6f..125933c9d3ae0 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -267,20 +267,20 @@ def applySchema(self, rdd, schema): ... StructField("byte2", ByteType(), False), ... StructField("short1", ShortType(), False), ... StructField("short2", ShortType(), False), - ... StructField("int", IntegerType(), False), - ... StructField("float", FloatType(), False), - ... StructField("date", DateType(), False), - ... StructField("time", TimestampType(), False), - ... StructField("map", + ... StructField("int1", IntegerType(), False), + ... StructField("float1", FloatType(), False), + ... StructField("date1", DateType(), False), + ... StructField("time1", TimestampType(), False), + ... StructField("map1", ... MapType(StringType(), IntegerType(), False), False), - ... StructField("struct", + ... StructField("struct1", ... StructType([StructField("b", ShortType(), False)]), False), - ... StructField("list", ArrayType(ByteType(), False), False), - ... StructField("null", DoubleType(), True)]) + ... StructField("list1", ArrayType(ByteType(), False), False), + ... StructField("null1", DoubleType(), True)]) >>> df = sqlCtx.applySchema(rdd, schema) >>> results = df.map( - ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date, - ... x.time, x.map["a"], x.struct.b, x.list, x.null)) + ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1, + ... x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1)) >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1), datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) @@ -288,20 +288,20 @@ def applySchema(self, rdd, schema): >>> df.registerTempTable("table2") >>> sqlCtx.sql( ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + - ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + - ... "float + 1.5 as float FROM table2").collect() - [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)] + ... "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " + + ... "float1 + 1.5 as float1 FROM table2").collect() + [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int1=2147483646, float1=2.5)] >>> from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type >>> rdd = sc.parallelize([(127, -32768, 1.0, ... datetime(2010, 1, 1, 1, 1, 1), ... {"a": 1}, (2,), [1, 2, 3])]) - >>> abstract = "byte short float time map{} struct(b) list[]" + >>> abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]" >>> schema = _parse_schema_abstract(abstract) >>> typedSchema = _infer_schema_type(rdd.first(), schema) >>> df = sqlCtx.applySchema(rdd, typedSchema) >>> df.collect() - [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])] + [Row(byte1=127, short1=-32768, float1=1.0, time1=..., list1=[1, 2, 3])] """ if isinstance(rdd, DataFrame): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 124f083669358..b16aff99af1c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -78,6 +78,7 @@ class SqlParser extends AbstractSparkSQLParser { protected val IF = Keyword("IF") protected val IN = Keyword("IN") protected val INNER = Keyword("INNER") + protected val INT = Keyword("INT") protected val INSERT = Keyword("INSERT") protected val INTERSECT = Keyword("INTERSECT") protected val INTO = Keyword("INTO") @@ -394,6 +395,7 @@ class SqlParser extends AbstractSparkSQLParser { | fixedDecimalType | DECIMAL ^^^ DecimalType.Unlimited | DATE ^^^ DateType + | INT ^^^ IntegerType ) protected lazy val fixedDecimalType: Parser[DataType] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index fc37b8cde0806..e4e542562f22d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -52,12 +52,6 @@ class Analyzer(catalog: Catalog, */ val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Nil - /** - * Override to provide additional rules for the "Check Analysis" batch. - * These rules will be evaluated after our built-in check rules. - */ - val extendedCheckRules: Seq[Rule[LogicalPlan]] = Nil - lazy val batches: Seq[Batch] = Seq( Batch("Resolution", fixedPoint, ResolveRelations :: @@ -71,87 +65,10 @@ class Analyzer(catalog: Catalog, TrimGroupingAliases :: typeCoercionRules ++ extendedResolutionRules : _*), - Batch("Check Analysis", Once, - CheckResolution +: - extendedCheckRules: _*), Batch("Remove SubQueries", fixedPoint, EliminateSubQueries) ) - /** - * Makes sure all attributes and logical plans have been resolved. - */ - object CheckResolution extends Rule[LogicalPlan] { - def failAnalysis(msg: String) = { throw new AnalysisException(msg) } - - def apply(plan: LogicalPlan): LogicalPlan = { - // We transform up and order the rules so as to catch the first possible failure instead - // of the result of cascading resolution failures. - plan.foreachUp { - case operator: LogicalPlan => - operator transformExpressionsUp { - case a: Attribute if !a.resolved => - val from = operator.inputSet.map(_.name).mkString(", ") - a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") - - case c: Cast if !c.resolved => - failAnalysis( - s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") - - case b: BinaryExpression if !b.resolved => - failAnalysis( - s"invalid expression ${b.prettyString} " + - s"between ${b.left.simpleString} and ${b.right.simpleString}") - } - - operator match { - case f: Filter if f.condition.dataType != BooleanType => - failAnalysis( - s"filter expression '${f.condition.prettyString}' " + - s"of type ${f.condition.dataType.simpleString} is not a boolean.") - - case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) => - def checkValidAggregateExpression(expr: Expression): Unit = expr match { - case _: AggregateExpression => // OK - case e: Attribute if !groupingExprs.contains(e) => - failAnalysis( - s"expression '${e.prettyString}' is neither present in the group by, " + - s"nor is it an aggregate function. " + - "Add to group by or wrap in first() if you don't care which value you get.") - case e if groupingExprs.contains(e) => // OK - case e if e.references.isEmpty => // OK - case e => e.children.foreach(checkValidAggregateExpression) - } - - val cleaned = aggregateExprs.map(_.transform { - // Should trim aliases around `GetField`s. These aliases are introduced while - // resolving struct field accesses, because `GetField` is not a `NamedExpression`. - // (Should we just turn `GetField` into a `NamedExpression`?) - case Alias(g, _) => g - }) - - cleaned.foreach(checkValidAggregateExpression) - - case o if o.children.nonEmpty && - !o.references.filter(_.name != "grouping__id").subsetOf(o.inputSet) => - val missingAttributes = (o.references -- o.inputSet).map(_.prettyString).mkString(",") - val input = o.inputSet.map(_.prettyString).mkString(",") - - failAnalysis(s"resolved attributes $missingAttributes missing from $input") - - // Catch all - case o if !o.resolved => - failAnalysis( - s"unresolved operator ${operator.simpleString}") - - case _ => // Analysis successful! - } - } - - plan - } - } - /** * Removes no-op Alias expressions from the plan. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala new file mode 100644 index 0000000000000..4e8fc892f3eea --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types._ + +/** + * Throws user facing errors when passed invalid queries that fail to analyze. + */ +class CheckAnalysis { + + /** + * Override to provide additional checks for correct analysis. + * These rules will be evaluated after our built-in check rules. + */ + val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil + + def failAnalysis(msg: String) = { + throw new AnalysisException(msg) + } + + def apply(plan: LogicalPlan): Unit = { + // We transform up and order the rules so as to catch the first possible failure instead + // of the result of cascading resolution failures. + plan.foreachUp { + case operator: LogicalPlan => + operator transformExpressionsUp { + case a: Attribute if !a.resolved => + val from = operator.inputSet.map(_.name).mkString(", ") + a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") + + case c: Cast if !c.resolved => + failAnalysis( + s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") + + case b: BinaryExpression if !b.resolved => + failAnalysis( + s"invalid expression ${b.prettyString} " + + s"between ${b.left.simpleString} and ${b.right.simpleString}") + } + + operator match { + case f: Filter if f.condition.dataType != BooleanType => + failAnalysis( + s"filter expression '${f.condition.prettyString}' " + + s"of type ${f.condition.dataType.simpleString} is not a boolean.") + + case aggregatePlan@Aggregate(groupingExprs, aggregateExprs, child) => + def checkValidAggregateExpression(expr: Expression): Unit = expr match { + case _: AggregateExpression => // OK + case e: Attribute if !groupingExprs.contains(e) => + failAnalysis( + s"expression '${e.prettyString}' is neither present in the group by, " + + s"nor is it an aggregate function. " + + "Add to group by or wrap in first() if you don't care which value you get.") + case e if groupingExprs.contains(e) => // OK + case e if e.references.isEmpty => // OK + case e => e.children.foreach(checkValidAggregateExpression) + } + + val cleaned = aggregateExprs.map(_.transform { + // Should trim aliases around `GetField`s. These aliases are introduced while + // resolving struct field accesses, because `GetField` is not a `NamedExpression`. + // (Should we just turn `GetField` into a `NamedExpression`?) + case Alias(g, _) => g + }) + + cleaned.foreach(checkValidAggregateExpression) + + case o if o.children.nonEmpty && + !o.references.filter(_.name != "grouping__id").subsetOf(o.inputSet) => + val missingAttributes = (o.references -- o.inputSet).map(_.prettyString).mkString(",") + val input = o.inputSet.map(_.prettyString).mkString(",") + + failAnalysis(s"resolved attributes $missingAttributes missing from $input") + + // Catch all + case o if !o.resolved => + failAnalysis( + s"unresolved operator ${operator.simpleString}") + + case _ => // Analysis successful! + } + } + extendedCheckRules.foreach(_(plan)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index aec7847356cd4..c1dd5aa913ddc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -30,11 +30,21 @@ import org.apache.spark.sql.catalyst.dsl.plans._ class AnalysisSuite extends FunSuite with BeforeAndAfter { val caseSensitiveCatalog = new SimpleCatalog(true) val caseInsensitiveCatalog = new SimpleCatalog(false) - val caseSensitiveAnalyze = + + val caseSensitiveAnalyzer = new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true) - val caseInsensitiveAnalyze = + val caseInsensitiveAnalyzer = new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false) + val checkAnalysis = new CheckAnalysis + + + def caseSensitiveAnalyze(plan: LogicalPlan) = + checkAnalysis(caseSensitiveAnalyzer(plan)) + + def caseInsensitiveAnalyze(plan: LogicalPlan) = + checkAnalysis(caseInsensitiveAnalyzer(plan)) + val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) val testRelation2 = LocalRelation( AttributeReference("a", StringType)(), @@ -55,7 +65,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None))) } - assert(caseInsensitiveAnalyze(plan).resolved) + assert(caseInsensitiveAnalyzer(plan).resolved) } test("check project's resolved") { @@ -71,11 +81,11 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { test("analyze project") { assert( - caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) === + caseSensitiveAnalyzer(Project(Seq(UnresolvedAttribute("a")), testRelation)) === Project(testRelation.output, testRelation)) assert( - caseSensitiveAnalyze( + caseSensitiveAnalyzer( Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === Project(testRelation.output, testRelation)) @@ -88,13 +98,13 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { assert(e.getMessage().toLowerCase.contains("cannot resolve")) assert( - caseInsensitiveAnalyze( + caseInsensitiveAnalyzer( Project(Seq(UnresolvedAttribute("TbL.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === Project(testRelation.output, testRelation)) assert( - caseInsensitiveAnalyze( + caseInsensitiveAnalyzer( Project(Seq(UnresolvedAttribute("tBl.a")), UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) === Project(testRelation.output, testRelation)) @@ -107,16 +117,13 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { assert(e.getMessage == "Table Not Found: tAbLe") assert( - caseSensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) === - testRelation) + caseSensitiveAnalyzer(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) assert( - caseInsensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None)) === - testRelation) + caseInsensitiveAnalyzer(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation) assert( - caseInsensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) === - testRelation) + caseInsensitiveAnalyzer(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation) } def errorTest( @@ -177,7 +184,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { AttributeReference("d", DecimalType.Unlimited)(), AttributeReference("e", ShortType)()) - val plan = caseInsensitiveAnalyze( + val plan = caseInsensitiveAnalyzer( testRelation2.select( 'a / Literal(2) as 'div1, 'a / 'b as 'div2, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 69e5f6a07da7f..27ac398063d43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -117,7 +117,7 @@ class DataFrame protected[sql]( this(sqlContext, { val qe = sqlContext.executePlan(logicalPlan) if (sqlContext.conf.dataFrameEagerAnalysis) { - qe.analyzed // This should force analysis and throw errors if there are any + qe.assertAnalyzed() // This should force analysis and throw errors if there are any } qe }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 39f6c2f4bc8b4..a08c0f5ce3ff4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -52,8 +52,9 @@ private[spark] object SQLConf { // This is used to set the default data source val DEFAULT_DATA_SOURCE_NAME = "spark.sql.sources.default" - // Whether to perform eager analysis on a DataFrame. - val DATAFRAME_EAGER_ANALYSIS = "spark.sql.dataframe.eagerAnalysis" + // Whether to perform eager analysis when constructing a dataframe. + // Set to false when debugging requires the ability to look at invalid query plans. + val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis" object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4bdaa023914b8..ce800e0754559 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -114,7 +114,6 @@ class SQLContext(@transient val sparkContext: SparkContext) new Analyzer(catalog, functionRegistry, caseSensitive = true) { override val extendedResolutionRules = ExtractPythonUdfs :: - sources.PreWriteCheck(catalog) :: sources.PreInsertCastAndRename :: Nil } @@ -1057,6 +1056,13 @@ class SQLContext(@transient val sparkContext: SparkContext) Batch("Add exchange", Once, AddExchange(self)) :: Nil } + @transient + protected[sql] lazy val checkAnalysis = new CheckAnalysis { + override val extendedCheckRules = Seq( + sources.PreWriteCheck(catalog) + ) + } + /** * :: DeveloperApi :: * The primary workflow for executing relational queries using Spark. Designed to allow easy @@ -1064,9 +1070,13 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @DeveloperApi protected[sql] class QueryExecution(val logical: LogicalPlan) { + def assertAnalyzed(): Unit = checkAnalysis(analyzed) lazy val analyzed: LogicalPlan = analyzer(logical) - lazy val withCachedData: LogicalPlan = cacheManager.useCachedData(analyzed) + lazy val withCachedData: LogicalPlan = { + assertAnalyzed + cacheManager.useCachedData(analyzed) + } lazy val optimizedPlan: LogicalPlan = optimizer(withCachedData) // TODO: Don't just pick the first one... diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala index 36a9c0bdc41e6..8440581074877 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala @@ -78,10 +78,10 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] { /** * A rule to do various checks before inserting into or writing to a data source table. */ -private[sql] case class PreWriteCheck(catalog: Catalog) extends Rule[LogicalPlan] { +private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => Unit) { def failAnalysis(msg: String) = { throw new AnalysisException(msg) } - def apply(plan: LogicalPlan): LogicalPlan = { + def apply(plan: LogicalPlan): Unit = { plan.foreach { case i @ logical.InsertIntoTable( l @ LogicalRelation(t: InsertableRelation), partition, query, overwrite) => @@ -93,7 +93,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends Rule[LogicalPlan val srcRelations = query.collect { case LogicalRelation(src: BaseRelation) => src } - if (srcRelations.exists(src => src == t)) { + if (srcRelations.contains(t)) { failAnalysis( "Cannot insert overwrite into table that is also being read from.") } else { @@ -119,7 +119,7 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends Rule[LogicalPlan val srcRelations = query.collect { case LogicalRelation(src: BaseRelation) => src } - if (srcRelations.exists(src => src == dest)) { + if (srcRelations.contains(dest)) { failAnalysis( s"Cannot overwrite table $tableName that is also being read from.") } else { @@ -134,7 +134,5 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends Rule[LogicalPlan case _ => // OK } - - plan } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 0ec6881d7afe6..91c6367371f15 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -30,7 +30,6 @@ abstract class DataSourceTest extends QueryTest with BeforeAndAfter { override protected[sql] lazy val analyzer: Analyzer = new Analyzer(catalog, functionRegistry, caseSensitive = false) { override val extendedResolutionRules = - PreWriteCheck(catalog) :: PreInsertCastAndRename :: Nil } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 5682e5a2bcea9..b5b16f9546691 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -205,7 +205,7 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll { val message = intercept[AnalysisException] { sql( s""" - |INSERT OVERWRITE TABLE oneToTen SELECT a FROM jt + |INSERT OVERWRITE TABLE oneToTen SELECT CAST(a AS INT) FROM jt """.stripMargin) }.getMessage assert( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 2e205e67c0fdd..c439dfe0a71f8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -268,7 +268,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.PreInsertionCasts :: ExtractPythonUdfs :: ResolveUdtfsAlias :: - sources.PreWriteCheck(catalog) :: sources.PreInsertCastAndRename :: Nil } From cf2e41653de778dc8db8b03385a053aae1152e19 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 23 Feb 2015 22:08:44 -0800 Subject: [PATCH 05/28] [SPARK-5958][MLLIB][DOC] update block matrix user guide * Removed SVD code from examples. * Corrected Java API doc link. * Updated variable names: `AtransposeA` -> `ata`. * Minor changes. brkyvz Author: Xiangrui Meng Closes #4737 from mengxr/update-block-matrix-user-guide and squashes the following commits: 70f53ac [Xiangrui Meng] update block matrix user guide --- docs/mllib-data-types.md | 41 +++++++++++++++------------------------- 1 file changed, 15 insertions(+), 26 deletions(-) diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md index 24d22b9bcdfa4..fe6c1bf7bfd99 100644 --- a/docs/mllib-data-types.md +++ b/docs/mllib-data-types.md @@ -298,23 +298,22 @@ In general the use of non-deterministic RDDs can lead to errors. ### BlockMatrix -A `BlockMatrix` is a distributed matrix backed by an RDD of `MatrixBlock`s, where `MatrixBlock` is +A `BlockMatrix` is a distributed matrix backed by an RDD of `MatrixBlock`s, where a `MatrixBlock` is a tuple of `((Int, Int), Matrix)`, where the `(Int, Int)` is the index of the block, and `Matrix` is the sub-matrix at the given index with size `rowsPerBlock` x `colsPerBlock`. -`BlockMatrix` supports methods such as `.add` and `.multiply` with another `BlockMatrix`. -`BlockMatrix` also has a helper function `.validate` which can be used to debug whether the +`BlockMatrix` supports methods such as `add` and `multiply` with another `BlockMatrix`. +`BlockMatrix` also has a helper function `validate` which can be used to check whether the `BlockMatrix` is set up properly.
A [`BlockMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.BlockMatrix) can be -most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` using `.toBlockMatrix()`. -`.toBlockMatrix()` will create blocks of size 1024 x 1024. Users may change the sizes of their blocks -by supplying the values through `.toBlockMatrix(rowsPerBlock, colsPerBlock)`. +most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`. +`toBlockMatrix` creates blocks of size 1024 x 1024 by default. +Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. {% highlight scala %} -import org.apache.spark.mllib.linalg.SingularValueDecomposition import org.apache.spark.mllib.linalg.distributed.{BlockMatrix, CoordinateMatrix, MatrixEntry} val entries: RDD[MatrixEntry] = ... // an RDD of (i, j, v) matrix entries @@ -323,29 +322,24 @@ val coordMat: CoordinateMatrix = new CoordinateMatrix(entries) // Transform the CoordinateMatrix to a BlockMatrix val matA: BlockMatrix = coordMat.toBlockMatrix().cache() -// validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. +// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. // Nothing happens if it is valid. -matA.validate +matA.validate() // Calculate A^T A. -val AtransposeA = matA.transpose.multiply(matA) - -// get SVD of 2 * A -val A2 = matA.add(matA) -val svd = A2.toIndexedRowMatrix().computeSVD(20, false, 1e-9) +val ata = matA.transpose.multiply(matA) {% endhighlight %}
-A [`BlockMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.BlockMatrix) can be -most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` using `.toBlockMatrix()`. -`.toBlockMatrix()` will create blocks of size 1024 x 1024. Users may change the sizes of their blocks -by supplying the values through `.toBlockMatrix(rowsPerBlock, colsPerBlock)`. +A [`BlockMatrix`](api/java/org/apache/spark/mllib/linalg/distributed/BlockMatrix.html) can be +most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`. +`toBlockMatrix` creates blocks of size 1024 x 1024 by default. +Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`. {% highlight java %} import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.mllib.linalg.SingularValueDecomposition; import org.apache.spark.mllib.linalg.distributed.BlockMatrix; import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix; import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix; @@ -356,17 +350,12 @@ CoordinateMatrix coordMat = new CoordinateMatrix(entries.rdd()); // Transform the CoordinateMatrix to a BlockMatrix BlockMatrix matA = coordMat.toBlockMatrix().cache(); -// validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. +// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid. // Nothing happens if it is valid. matA.validate(); // Calculate A^T A. -BlockMatrix AtransposeA = matA.transpose().multiply(matA); - -// get SVD of 2 * A -BlockMatrix A2 = matA.add(matA); -SingularValueDecomposition svd = - A2.toIndexedRowMatrix().computeSVD(20, false, 1e-9); +BlockMatrix ata = matA.transpose().multiply(matA); {% endhighlight %}
From 840333133396d443e747f62fce9967f7681fb276 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 24 Feb 2015 10:45:38 -0800 Subject: [PATCH 06/28] [SPARK-5968] [SQL] Suppresses ParquetOutputCommitter WARN logs Please refer to the [JIRA ticket] [1] for the motivation. [1]: https://issues.apache.org/jira/browse/SPARK-5968 [Review on Reviewable](https://reviewable.io/reviews/apache/spark/4744) Author: Cheng Lian Closes #4744 from liancheng/spark-5968 and squashes the following commits: caac6a8 [Cheng Lian] Suppresses ParquetOutputCommitter WARN logs --- .../apache/spark/sql/parquet/ParquetRelation.scala | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index b0db9943a506c..a0d1005c0cae3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -18,11 +18,12 @@ package org.apache.spark.sql.parquet import java.io.IOException +import java.util.logging.Level import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.fs.permission.FsAction -import parquet.hadoop.ParquetOutputFormat +import parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat} import parquet.hadoop.metadata.CompressionCodecName import parquet.schema.MessageType @@ -91,7 +92,7 @@ private[sql] object ParquetRelation { // checks first to see if there's any handlers already set // and if not it creates them. If this method executes prior // to that class being loaded then: - // 1) there's no handlers installed so there's none to + // 1) there's no handlers installed so there's none to // remove. But when it IS finally loaded the desired affect // of removing them is circumvented. // 2) The parquet.Log static initializer calls setUseParentHanders(false) @@ -99,7 +100,7 @@ private[sql] object ParquetRelation { // // Therefore we need to force the class to be loaded. // This should really be resolved by Parquet. - Class.forName(classOf[parquet.Log].getName()) + Class.forName(classOf[parquet.Log].getName) // Note: Logger.getLogger("parquet") has a default logger // that appends to Console which needs to be cleared. @@ -108,6 +109,11 @@ private[sql] object ParquetRelation { // TODO(witgo): Need to set the log level ? // if(parquetLogger.getLevel != null) parquetLogger.setLevel(null) if (!parquetLogger.getUseParentHandlers) parquetLogger.setUseParentHandlers(true) + + // Disables WARN log message in ParquetOutputCommitter. + // See https://issues.apache.org/jira/browse/SPARK-5968 for details + Class.forName(classOf[ParquetOutputCommitter].getName) + java.util.logging.Logger.getLogger(classOf[ParquetOutputCommitter].getName).setLevel(Level.OFF) } // The element type for the RDDs that this relation maps to. From 0a59e45e2f2e6f00ccd5f10c79f629fb796fd8d0 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 24 Feb 2015 10:49:51 -0800 Subject: [PATCH 07/28] [SPARK-5910][SQL] Support for as in selectExpr Author: Michael Armbrust Closes #4736 from marmbrus/asExprs and squashes the following commits: 5ba97e4 [Michael Armbrust] [SPARK-5910][SQL] Support for as in selectExpr --- .../scala/org/apache/spark/sql/catalyst/SqlParser.scala | 2 +- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index b16aff99af1c5..c363a5efacde8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -40,7 +40,7 @@ class SqlParser extends AbstractSparkSQLParser { def parseExpression(input: String): Expression = { // Initialize the Keywords. lexical.initialize(reservedWords) - phrase(expression)(new lexical.Scanner(input)) match { + phrase(projection)(new lexical.Scanner(input)) match { case Success(plan, _) => plan case failureOrError => sys.error(failureOrError.toString) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 6b9b3a8425964..e71e9bee3a6d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -130,6 +130,12 @@ class DataFrameSuite extends QueryTest { testData.collect().map(row => Row(math.abs(row.getInt(0)), row.getString(1))).toSeq) } + test("selectExpr with alias") { + checkAnswer( + testData.selectExpr("key as k").select("k"), + testData.select("key").collect().toSeq) + } + test("filterExpr") { checkAnswer( testData.filter("key > 90"), From 201236628a344194f7c20ba8e9afeeaefbe9318c Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 24 Feb 2015 10:52:18 -0800 Subject: [PATCH 08/28] [SPARK-5532][SQL] Repartition should not use external rdd representation Author: Michael Armbrust Closes #4738 from marmbrus/udtRepart and squashes the following commits: c06d7b5 [Michael Armbrust] fix compilation 91c8829 [Michael Armbrust] [SQL][SPARK-5532] Repartition should not use external rdd representation --- .../scala/org/apache/spark/sql/DataFrame.scala | 5 +++-- .../spark/sql/execution/debug/package.scala | 1 + .../apache/spark/sql/UserDefinedTypeSuite.scala | 16 +++++++++++++++- 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 27ac398063d43..04bf5d9b0f931 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -799,14 +799,15 @@ class DataFrame protected[sql]( * Returns the number of rows in the [[DataFrame]]. * @group action */ - override def count(): Long = groupBy().count().rdd.collect().head.getLong(0) + override def count(): Long = groupBy().count().collect().head.getLong(0) /** * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions. * @group rdd */ override def repartition(numPartitions: Int): DataFrame = { - sqlContext.createDataFrame(rdd.repartition(numPartitions), schema) + sqlContext.createDataFrame( + queryExecution.toRdd.map(_.copy()).repartition(numPartitions), schema) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 73162b22fa9cd..ffe388cfa9532 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -168,6 +168,7 @@ package object debug { case (_: Short, ShortType) => case (_: Boolean, BooleanType) => case (_: Double, DoubleType) => + case (v, udt: UserDefinedType[_]) => typeCheck(v, udt.sqlType) case (d, t) => sys.error(s"Invalid data found: got $d (${d.getClass}) expected $t") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 5f21d990e2e5b..9c098df24c65f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql +import java.io.File + import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql} import org.apache.spark.sql.test.TestSQLContext.implicits._ @@ -91,4 +92,17 @@ class UserDefinedTypeSuite extends QueryTest { sql("SELECT testType(features) from points"), Seq(Row(true), Row(true))) } + + + test("UDTs with Parquet") { + val tempDir = File.createTempFile("parquet", "test") + tempDir.delete() + pointsRDD.saveAsParquetFile(tempDir.getCanonicalPath) + } + + test("Repartition UDTs with Parquet") { + val tempDir = File.createTempFile("parquet", "test") + tempDir.delete() + pointsRDD.repartition(1).saveAsParquetFile(tempDir.getCanonicalPath) + } } From 64d2c01ff1048de83b9b8efce987b55e457298f9 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 24 Feb 2015 11:02:47 -0800 Subject: [PATCH 09/28] [Spark-5967] [UI] Correctly clean JobProgressListener.stageIdToActiveJobIds Patch should be self-explanatory pwendell JoshRosen Author: Tathagata Das Closes #4741 from tdas/SPARK-5967 and squashes the following commits: 653b5bb [Tathagata Das] Fixed the fix and added test e2de972 [Tathagata Das] Clear stages which have no corresponding active jobs. --- .../spark/ui/jobs/JobProgressListener.scala | 3 +++ .../ui/jobs/JobProgressListenerSuite.scala | 22 +++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 0b6fe70bd2062..937d95a934b59 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -203,6 +203,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { for (stageId <- jobData.stageIds) { stageIdToActiveJobIds.get(stageId).foreach { jobsUsingStage => jobsUsingStage.remove(jobEnd.jobId) + if (jobsUsingStage.isEmpty) { + stageIdToActiveJobIds.remove(stageId) + } stageIdToInfo.get(stageId).foreach { stageInfo => if (stageInfo.submissionTime.isEmpty) { // if this stage is pending, it won't complete, so mark it as "skipped": diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 6019282d2fb70..730a4b54f5aa1 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -88,6 +88,28 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc listener.completedStages.map(_.stageId).toSet should be (Set(50, 49, 48, 47, 46)) } + test("test clearing of stageIdToActiveJobs") { + val conf = new SparkConf() + conf.set("spark.ui.retainedStages", 5.toString) + val listener = new JobProgressListener(conf) + val jobId = 0 + val stageIds = 1 to 50 + // Start a job with 50 stages + listener.onJobStart(createJobStartEvent(jobId, stageIds)) + for (stageId <- stageIds) { + listener.onStageSubmitted(createStageStartEvent(stageId)) + } + listener.stageIdToActiveJobIds.size should be > 0 + + // Complete the stages and job + for (stageId <- stageIds) { + listener.onStageCompleted(createStageEndEvent(stageId, failed = false)) + } + listener.onJobEnd(createJobEndEvent(jobId, false)) + assertActiveJobsStateIsEmpty(listener) + listener.stageIdToActiveJobIds.size should be (0) + } + test("test LRU eviction of jobs") { val conf = new SparkConf() conf.set("spark.ui.retainedStages", 5.toString) From 6d2caa576fcdc5c848d1472b09c685b3871e220e Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 24 Feb 2015 11:08:07 -0800 Subject: [PATCH 10/28] [SPARK-5965] Standalone Worker UI displays {{USER_JAR}} For screenshot see: https://issues.apache.org/jira/browse/SPARK-5965 This was caused by 20a6013106b56a1a1cc3e8cda092330ffbe77cc3. Author: Andrew Or Closes #4739 from andrewor14/user-jar-blocker and squashes the following commits: 23c4a9e [Andrew Or] Use right argument --- .../scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 327b905032800..720f13bfa829b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -134,7 +134,7 @@ private[spark] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { def driverRow(driver: DriverRunner): Seq[Node] = { {driver.driverId} - {driver.driverDesc.command.arguments(1)} + {driver.driverDesc.command.arguments(2)} {driver.finalState.getOrElse(DriverState.RUNNING)} {driver.driverDesc.cores.toString} From 105791e35cee694f3b2ac1e06758650fe44e2c71 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 24 Feb 2015 11:38:59 -0800 Subject: [PATCH 11/28] [MLLIB] Change x_i to y_i in Variance's user guide Variance is calculated on labels/responses. Author: Xiangrui Meng Closes #4740 from mengxr/patch-1 and squashes the following commits: 673317b [Xiangrui Meng] [MLLIB] Change x_i to y_i in Variance's user guide --- docs/mllib-decision-tree.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md index d1537def851e7..6675133a810db 100644 --- a/docs/mllib-decision-tree.md +++ b/docs/mllib-decision-tree.md @@ -54,8 +54,8 @@ impurity measure for regression (variance). Variance Regression - $\frac{1}{N} \sum_{i=1}^{N} (x_i - \mu)^2$$y_i$ is label for an instance, - $N$ is the number of instances and $\mu$ is the mean given by $\frac{1}{N} \sum_{i=1}^N x_i$. + $\frac{1}{N} \sum_{i=1}^{N} (y_i - \mu)^2$$y_i$ is label for an instance, + $N$ is the number of instances and $\mu$ is the mean given by $\frac{1}{N} \sum_{i=1}^N y_i$. From c5ba975ee85521f708ebeec81144347cf1b40fba Mon Sep 17 00:00:00 2001 From: Judy Date: Tue, 24 Feb 2015 20:50:16 +0000 Subject: [PATCH 12/28] [Spark-5708] Add Slf4jSink to Spark Metrics Add Slf4jSink to Spark Metrics using Coda Hale's SlfjReporter. This sends metrics to log4j, allowing spark users to reuse log4j pipeline for metrics collection. Reviewed existing unit tests and didn't see any sink-related tests. Please advise on if tests should be added. Author: Judy Author: judynash Closes #4644 from judynash/master and squashes the following commits: 57ef214 [judynash] doc clarification and indent fixes a751a66 [Judy] Spark-5708: Add Slf4jSink to Spark Metrics --- conf/metrics.properties.template | 9 +++ .../apache/spark/metrics/sink/Slf4jSink.scala | 68 +++++++++++++++++++ docs/monitoring.md | 1 + 3 files changed, 78 insertions(+) create mode 100644 core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index 464c14457e53f..2e0cb5db170ac 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -122,6 +122,15 @@ #worker.sink.csv.unit=minutes +# Enable Slf4jSink for all instances by class name +#*.sink.slf4j.class=org.apache.spark.metrics.sink.Slf4jSink + +# Polling period for Slf4JSink +#*.sink.sl4j.period=1 + +#*.sink.sl4j.unit=minutes + + # Enable jvm source for instance master, worker, driver and executor #master.source.jvm.class=org.apache.spark.metrics.source.JvmSource diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala new file mode 100644 index 0000000000000..e8b3074e8f1a6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics.sink + +import java.util.Properties +import java.util.concurrent.TimeUnit + +import com.codahale.metrics.{Slf4jReporter, MetricRegistry} + +import org.apache.spark.SecurityManager +import org.apache.spark.metrics.MetricsSystem + +private[spark] class Slf4jSink( + val property: Properties, + val registry: MetricRegistry, + securityMgr: SecurityManager) + extends Sink { + val SLF4J_DEFAULT_PERIOD = 10 + val SLF4J_DEFAULT_UNIT = "SECONDS" + + val SLF4J_KEY_PERIOD = "period" + val SLF4J_KEY_UNIT = "unit" + + val pollPeriod = Option(property.getProperty(SLF4J_KEY_PERIOD)) match { + case Some(s) => s.toInt + case None => SLF4J_DEFAULT_PERIOD + } + + val pollUnit: TimeUnit = Option(property.getProperty(SLF4J_KEY_UNIT)) match { + case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case None => TimeUnit.valueOf(SLF4J_DEFAULT_UNIT) + } + + MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) + + val reporter: Slf4jReporter = Slf4jReporter.forRegistry(registry) + .convertDurationsTo(TimeUnit.MILLISECONDS) + .convertRatesTo(TimeUnit.SECONDS) + .build() + + override def start() { + reporter.start(pollPeriod, pollUnit) + } + + override def stop() { + reporter.stop() + } + + override def report() { + reporter.report() + } +} + diff --git a/docs/monitoring.md b/docs/monitoring.md index 7a5cadc171d6d..009a344dff4bb 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -176,6 +176,7 @@ Each instance can report to zero or more _sinks_. Sinks are contained in the * `JmxSink`: Registers metrics for viewing in a JMX console. * `MetricsServlet`: Adds a servlet within the existing Spark UI to serve metrics data as JSON data. * `GraphiteSink`: Sends metrics to a Graphite node. +* `Slf4jSink`: Sends metrics to slf4j as log entries. Spark also supports a Ganglia sink which is not included in the default build due to licensing restrictions: From a2b9137923e0ba328da8fff2fbbfcf2abf50b033 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 24 Feb 2015 13:39:29 -0800 Subject: [PATCH 13/28] [SPARK-5952][SQL] Lock when using hive metastore client Author: Michael Armbrust Closes #4746 from marmbrus/hiveLock and squashes the following commits: 8b871cf [Michael Armbrust] [SPARK-5952][SQL] Lock when using hive metastore client --- .../spark/sql/hive/HiveMetastoreCatalog.scala | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f7ad2efc9544e..2cc8d65d3cb79 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -52,6 +52,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with /** Connection to hive metastore. Usages should lock on `this`. */ protected[hive] val client = Hive.get(hive.hiveconf) + /** Usages should lock on `this`. */ protected[hive] lazy val hiveWarehouse = new Warehouse(hive.hiveconf) // TODO: Use this everywhere instead of tuples or databaseName, tableName,. @@ -65,7 +66,9 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with val cacheLoader = new CacheLoader[QualifiedTableName, LogicalPlan]() { override def load(in: QualifiedTableName): LogicalPlan = { logDebug(s"Creating new cached data source for $in") - val table = client.getTable(in.database, in.name) + val table = synchronized { + client.getTable(in.database, in.name) + } val schemaString = table.getProperty("spark.sql.sources.schema") val userSpecifiedSchema = if (schemaString == null) { @@ -134,15 +137,18 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } } - def hiveDefaultTableFilePath(tableName: String): String = { + def hiveDefaultTableFilePath(tableName: String): String = synchronized { val currentDatabase = client.getDatabase(hive.sessionState.getCurrentDatabase) + hiveWarehouse.getTablePath(currentDatabase, tableName).toString } - def tableExists(tableIdentifier: Seq[String]): Boolean = { + def tableExists(tableIdentifier: Seq[String]): Boolean = synchronized { val tableIdent = processTableIdentifier(tableIdentifier) - val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse( - hive.sessionState.getCurrentDatabase) + val databaseName = + tableIdent + .lift(tableIdent.size - 2) + .getOrElse(hive.sessionState.getCurrentDatabase) val tblName = tableIdent.last client.getTable(databaseName, tblName, false) != null } @@ -219,7 +225,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with } } - override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { + override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = synchronized { val dbName = if (!caseSensitive) { if (databaseName.isDefined) Some(databaseName.get.toLowerCase) else None } else { From da505e59274d1c838653c1109db65ad374e65304 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 24 Feb 2015 14:50:00 -0800 Subject: [PATCH 14/28] [SPARK-5973] [PySpark] fix zip with two RDDs with AutoBatchedSerializer Author: Davies Liu Closes #4745 from davies/fix_zip and squashes the following commits: 2124b2c [Davies Liu] Update tests.py b5c828f [Davies Liu] increase the number of records c1e40fd [Davies Liu] fix zip with two RDDs with AutoBatchedSerializer --- python/pyspark/rdd.py | 2 +- python/pyspark/tests.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index ba2347ae76844..d3148de6f41a3 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -1950,7 +1950,7 @@ def batch_as(rdd, batchSize): my_batch = get_batch_size(self._jrdd_deserializer) other_batch = get_batch_size(other._jrdd_deserializer) - if my_batch != other_batch: + if my_batch != other_batch or not my_batch: # use the smallest batchSize for both of them batchSize = min(my_batch, other_batch) if batchSize <= 0: diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 52e82091c9f81..06ba2b461d53e 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -543,6 +543,12 @@ def test_zip_with_different_serializers(self): # regression test for bug in _reserializer() self.assertEqual(cnt, t.zip(rdd).count()) + def test_zip_with_different_object_sizes(self): + # regress test for SPARK-5973 + a = self.sc.parallelize(range(10000)).map(lambda i: '*' * i) + b = self.sc.parallelize(range(10000, 20000)).map(lambda i: '*' * i) + self.assertEqual(10000, a.zip(b).count()) + def test_zip_with_different_number_of_items(self): a = self.sc.parallelize(range(5), 2) # different number of partitions From 2a0fe34891882e0fde1b5722d8227aa99acc0f1f Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 24 Feb 2015 15:13:22 -0800 Subject: [PATCH 15/28] [SPARK-5436] [MLlib] Validate GradientBoostedTrees using runWithValidation One can early stop if the decrease in error rate is lesser than a certain tol or if the error increases if the training data is overfit. This introduces a new method runWithValidation which takes in a pair of RDD's , one for the training data and the other for the validation. Author: MechCoder Closes #4677 from MechCoder/spark-5436 and squashes the following commits: 1bb21d4 [MechCoder] Combine regression and classification tests into a single one e4d799b [MechCoder] Addresses indentation and doc comments b48a70f [MechCoder] COSMIT b928a19 [MechCoder] Move validation while training section under usage tips fad9b6e [MechCoder] Made the following changes 1. Add section to documentation 2. Return corresponding to bestValidationError 3. Allow negative tolerance. 55e5c3b [MechCoder] One liner for prevValidateError 3e74372 [MechCoder] TST: Add test for classification 77549a9 [MechCoder] [SPARK-5436] Validate GradientBoostedTrees using runWithValidation --- docs/mllib-ensembles.md | 11 +++ .../mllib/tree/GradientBoostedTrees.scala | 75 +++++++++++++++++-- .../tree/configuration/BoostingStrategy.scala | 6 +- .../tree/GradientBoostedTreesSuite.scala | 36 +++++++++ 4 files changed, 122 insertions(+), 6 deletions(-) diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md index fb90b7039971c..00040e6073d0d 100644 --- a/docs/mllib-ensembles.md +++ b/docs/mllib-ensembles.md @@ -427,6 +427,17 @@ We omit some decision tree parameters since those are covered in the [decision t * **`algo`**: The algorithm or task (classification vs. regression) is set using the tree [Strategy] parameter. +#### Validation while training + +Gradient boosting can overfit when trained with more trees. In order to prevent overfitting, it is useful to validate while +training. The method runWithValidation has been provided to make use of this option. It takes a pair of RDD's as arguments, the +first one being the training dataset and the second being the validation dataset. + +The training is stopped when the improvement in the validation error is not more than a certain tolerance +(supplied by the `validationTol` argument in `BoostingStrategy`). In practice, the validation error +decreases initially and later increases. There might be cases in which the validation error does not change monotonically, +and the user is advised to set a large enough negative tolerance and examine the validation curve to to tune the number of +iterations. ### Examples diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index 61f6b1313f82e..b4466ff40937f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -60,11 +60,12 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = { val algo = boostingStrategy.treeStrategy.algo algo match { - case Regression => GradientBoostedTrees.boost(input, boostingStrategy) + case Regression => GradientBoostedTrees.boost(input, input, boostingStrategy, validate=false) case Classification => // Map labels to -1, +1 so binary classification can be treated as regression. val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features)) - GradientBoostedTrees.boost(remappedInput, boostingStrategy) + GradientBoostedTrees.boost(remappedInput, + remappedInput, boostingStrategy, validate=false) case _ => throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") } @@ -76,8 +77,46 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy) def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { run(input.rdd) } -} + /** + * Method to validate a gradient boosting model + * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + * @param validationInput Validation dataset: + RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. + Should be different from and follow the same distribution as input. + e.g., these two datasets could be created from an original dataset + by using [[org.apache.spark.rdd.RDD.randomSplit()]] + * @return a gradient boosted trees model that can be used for prediction + */ + def runWithValidation( + input: RDD[LabeledPoint], + validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = { + val algo = boostingStrategy.treeStrategy.algo + algo match { + case Regression => GradientBoostedTrees.boost( + input, validationInput, boostingStrategy, validate=true) + case Classification => + // Map labels to -1, +1 so binary classification can be treated as regression. + val remappedInput = input.map( + x => new LabeledPoint((x.label * 2) - 1, x.features)) + val remappedValidationInput = validationInput.map( + x => new LabeledPoint((x.label * 2) - 1, x.features)) + GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy, + validate=true) + case _ => + throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.") + } + } + + /** + * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#runWithValidation]]. + */ + def runWithValidation( + input: JavaRDD[LabeledPoint], + validationInput: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = { + runWithValidation(input.rdd, validationInput.rdd) + } +} object GradientBoostedTrees extends Logging { @@ -108,12 +147,16 @@ object GradientBoostedTrees extends Logging { /** * Internal method for performing regression using trees as base learners. * @param input training dataset + * @param validationInput validation dataset, ignored if validate is set to false. * @param boostingStrategy boosting parameters + * @param validate whether or not to use the validation dataset. * @return a gradient boosted trees model that can be used for prediction */ private def boost( input: RDD[LabeledPoint], - boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = { + validationInput: RDD[LabeledPoint], + boostingStrategy: BoostingStrategy, + validate: Boolean): GradientBoostedTreesModel = { val timer = new TimeTracker() timer.start("total") @@ -129,6 +172,7 @@ object GradientBoostedTrees extends Logging { val learningRate = boostingStrategy.learningRate // Prepare strategy for individual trees, which use regression with variance impurity. val treeStrategy = boostingStrategy.treeStrategy.copy + val validationTol = boostingStrategy.validationTol treeStrategy.algo = Regression treeStrategy.impurity = Variance treeStrategy.assertValid() @@ -152,13 +196,16 @@ object GradientBoostedTrees extends Logging { baseLearnerWeights(0) = 1.0 val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0)) logDebug("error of gbt = " + loss.computeError(startingModel, input)) + // Note: A model of type regression is used since we require raw prediction timer.stop("building tree 0") + var bestValidateError = if (validate) loss.computeError(startingModel, validationInput) else 0.0 + var bestM = 1 + // psuedo-residual for second iteration data = input.map(point => LabeledPoint(loss.gradient(startingModel, point), point.features)) - var m = 1 while (m < numIterations) { timer.start(s"building tree $m") @@ -177,6 +224,23 @@ object GradientBoostedTrees extends Logging { val partialModel = new GradientBoostedTreesModel( Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1)) logDebug("error of gbt = " + loss.computeError(partialModel, input)) + + if (validate) { + // Stop training early if + // 1. Reduction in error is less than the validationTol or + // 2. If the error increases, that is if the model is overfit. + // We want the model returned corresponding to the best validation error. + val currentValidateError = loss.computeError(partialModel, validationInput) + if (bestValidateError - currentValidateError < validationTol) { + return new GradientBoostedTreesModel( + boostingStrategy.treeStrategy.algo, + baseLearners.slice(0, bestM), + baseLearnerWeights.slice(0, bestM)) + } else if (currentValidateError < bestValidateError) { + bestValidateError = currentValidateError + bestM = m + 1 + } + } // Update data with pseudo-residuals data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point), point.features)) @@ -191,4 +255,5 @@ object GradientBoostedTrees extends Logging { new GradientBoostedTreesModel( boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights) } + } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index ed8e6a796f8c4..664c8df019233 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -34,6 +34,9 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} * weak hypotheses used in the final model. * @param learningRate Learning rate for shrinking the contribution of each estimator. The * learning rate should be between in the interval (0, 1] + * @param validationTol Useful when runWithValidation is used. If the error rate on the + * validation input between two iterations is less than the validationTol + * then stop. Ignored when [[run]] is used. */ @Experimental case class BoostingStrategy( @@ -42,7 +45,8 @@ case class BoostingStrategy( @BeanProperty var loss: Loss, // Optional boosting parameters @BeanProperty var numIterations: Int = 100, - @BeanProperty var learningRate: Double = 0.1) extends Serializable { + @BeanProperty var learningRate: Double = 0.1, + @BeanProperty var validationTol: Double = 1e-5) extends Serializable { /** * Check validity of parameters. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index bde47606eb001..b437aeaaf0547 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -158,6 +158,40 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { } } } + + test("runWithValidation stops early and performs better on a validation dataset") { + // Set numIterations large enough so that it stops early. + val numIterations = 20 + val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2) + val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2) + + val algos = Array(Regression, Regression, Classification) + val losses = Array(SquaredError, AbsoluteError, LogLoss) + (algos zip losses) map { + case (algo, loss) => { + val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2, + categoricalFeaturesInfo = Map.empty) + val boostingStrategy = + new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) + val gbtValidate = new GradientBoostedTrees(boostingStrategy) + .runWithValidation(trainRdd, validateRdd) + assert(gbtValidate.numTrees !== numIterations) + + // Test that it performs better on the validation dataset. + val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy) + val (errorWithoutValidation, errorWithValidation) = { + if (algo == Classification) { + val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) + (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd)) + } else { + (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd)) + } + } + assert(errorWithValidation <= errorWithoutValidation) + } + } + } + } private object GradientBoostedTreesSuite { @@ -166,4 +200,6 @@ private object GradientBoostedTreesSuite { val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75)) val data = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100) + val trainData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120) + val validateData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80) } From f816e73902b8ca28e24bf1f79a70533f75f239db Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 25 Feb 2015 08:34:55 +0800 Subject: [PATCH 16/28] [SPARK-5751] [SQL] [WIP] Revamped HiveThriftServer2Suite for robustness **NOTICE** Do NOT merge this, as we're waiting for #3881 to be merged. `HiveThriftServer2Suite` has been notorious for its flakiness for a while. This was mostly due to spawning and communicate with external server processes. This PR revamps this test suite for better robustness: 1. Fixes a racing condition occurred while using `tail -f` to check log file It's possible that the line we are looking for has already been printed into the log file before we start the `tail -f` process. This PR uses `tail -n +0 -f` to ensure all lines are checked. 2. Retries up to 3 times if the server fails to start In most of the cases, the server fails to start because of port conflict. This PR no longer asks the system to choose an available TCP port, but uses a random port first, and retries up to 3 times if the server fails to start. 3. A server instance is reused among all test cases within a single suite The original `HiveThriftServer2Suite` is splitted into two test suites, `HiveThriftBinaryServerSuite` and `HiveThriftHttpServerSuite`. Each suite starts a `HiveThriftServer2` instance and reuses it for all of its test cases. **TODO** - [ ] Starts the Thrift server in foreground once #3881 is merged (adding `--foreground` flag to `spark-daemon.sh`) [Review on Reviewable](https://reviewable.io/reviews/apache/spark/4720) Author: Cheng Lian Closes #4720 from liancheng/revamp-thrift-server-tests and squashes the following commits: d6c80eb [Cheng Lian] Relaxes server startup timeout 6f14eb1 [Cheng Lian] Revamped HiveThriftServer2Suite for robustness --- .../thriftserver/HiveThriftServer2Suite.scala | 387 ----------------- .../HiveThriftServer2Suites.scala | 403 ++++++++++++++++++ 2 files changed, 403 insertions(+), 387 deletions(-) delete mode 100644 sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala create mode 100644 sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala deleted file mode 100644 index b52a51d11e4ad..0000000000000 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala +++ /dev/null @@ -1,387 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.thriftserver - -import java.io.File -import java.net.ServerSocket -import java.sql.{Date, DriverManager, Statement} -import java.util.concurrent.TimeoutException - -import scala.collection.JavaConversions._ -import scala.collection.mutable.ArrayBuffer -import scala.concurrent.duration._ -import scala.concurrent.{Await, Promise} -import scala.sys.process.{Process, ProcessLogger} -import scala.util.Try - -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hive.jdbc.HiveDriver -import org.apache.hive.service.auth.PlainSaslHelper -import org.apache.hive.service.cli.GetInfoType -import org.apache.hive.service.cli.thrift.TCLIService.Client -import org.apache.hive.service.cli.thrift._ -import org.apache.thrift.protocol.TBinaryProtocol -import org.apache.thrift.transport.TSocket -import org.scalatest.FunSuite - -import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.util.getTempFilePath -import org.apache.spark.sql.hive.HiveShim - -/** - * Tests for the HiveThriftServer2 using JDBC. - * - * NOTE: SPARK_PREPEND_CLASSES is explicitly disabled in this test suite. Assembly jar must be - * rebuilt after changing HiveThriftServer2 related code. - */ -class HiveThriftServer2Suite extends FunSuite with Logging { - Class.forName(classOf[HiveDriver].getCanonicalName) - - object TestData { - def getTestDataFilePath(name: String) = { - Thread.currentThread().getContextClassLoader.getResource(s"data/files/$name") - } - - val smallKv = getTestDataFilePath("small_kv.txt") - val smallKvWithNull = getTestDataFilePath("small_kv_with_null.txt") - } - - def randomListeningPort = { - // Let the system to choose a random available port to avoid collision with other parallel - // builds. - val socket = new ServerSocket(0) - val port = socket.getLocalPort - socket.close() - port - } - - def withJdbcStatement( - serverStartTimeout: FiniteDuration = 1.minute, - httpMode: Boolean = false)( - f: Statement => Unit) { - val port = randomListeningPort - - startThriftServer(port, serverStartTimeout, httpMode) { - val jdbcUri = if (httpMode) { - s"jdbc:hive2://${"localhost"}:$port/" + - "default?hive.server2.transport.mode=http;hive.server2.thrift.http.path=cliservice" - } else { - s"jdbc:hive2://${"localhost"}:$port/" - } - - val user = System.getProperty("user.name") - val connection = DriverManager.getConnection(jdbcUri, user, "") - val statement = connection.createStatement() - - try { - f(statement) - } finally { - statement.close() - connection.close() - } - } - } - - def withCLIServiceClient( - serverStartTimeout: FiniteDuration = 1.minute)( - f: ThriftCLIServiceClient => Unit) { - val port = randomListeningPort - - startThriftServer(port) { - // Transport creation logics below mimics HiveConnection.createBinaryTransport - val rawTransport = new TSocket("localhost", port) - val user = System.getProperty("user.name") - val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport) - val protocol = new TBinaryProtocol(transport) - val client = new ThriftCLIServiceClient(new Client(protocol)) - - transport.open() - - try { - f(client) - } finally { - transport.close() - } - } - } - - def startThriftServer( - port: Int, - serverStartTimeout: FiniteDuration = 1.minute, - httpMode: Boolean = false)( - f: => Unit) { - val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator) - val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator) - - val warehousePath = getTempFilePath("warehouse") - val metastorePath = getTempFilePath("metastore") - val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" - - val command = - if (httpMode) { - s"""$startScript - | --master local - | --hiveconf hive.root.logger=INFO,console - | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri - | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost - | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=http - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT}=$port - | --driver-class-path ${sys.props("java.class.path")} - | --conf spark.ui.enabled=false - """.stripMargin.split("\\s+").toSeq - } else { - s"""$startScript - | --master local - | --hiveconf hive.root.logger=INFO,console - | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri - | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost - | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$port - | --driver-class-path ${sys.props("java.class.path")} - | --conf spark.ui.enabled=false - """.stripMargin.split("\\s+").toSeq - } - - val serverRunning = Promise[Unit]() - val buffer = new ArrayBuffer[String]() - val LOGGING_MARK = - s"starting ${HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$")}, logging to " - var logTailingProcess: Process = null - var logFilePath: String = null - - def captureLogOutput(line: String): Unit = { - buffer += line - if (line.contains("ThriftBinaryCLIService listening on") || - line.contains("Started ThriftHttpCLIService in http")) { - serverRunning.success(()) - } - } - - def captureThriftServerOutput(source: String)(line: String): Unit = { - if (line.startsWith(LOGGING_MARK)) { - logFilePath = line.drop(LOGGING_MARK.length).trim - // Ensure that the log file is created so that the `tail' command won't fail - Try(new File(logFilePath).createNewFile()) - logTailingProcess = Process(s"/usr/bin/env tail -f $logFilePath") - .run(ProcessLogger(captureLogOutput, _ => ())) - } - } - - val env = Seq( - // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths - "SPARK_TESTING" -> "0") - - Process(command, None, env: _*).run(ProcessLogger( - captureThriftServerOutput("stdout"), - captureThriftServerOutput("stderr"))) - - try { - Await.result(serverRunning.future, serverStartTimeout) - f - } catch { - case cause: Exception => - cause match { - case _: TimeoutException => - logError(s"Failed to start Hive Thrift server within $serverStartTimeout", cause) - case _ => - } - logError( - s""" - |===================================== - |HiveThriftServer2Suite failure output - |===================================== - |HiveThriftServer2 command line: ${command.mkString(" ")} - |Binding port: $port - |System user: ${System.getProperty("user.name")} - | - |${buffer.mkString("\n")} - |========================================= - |End HiveThriftServer2Suite failure output - |========================================= - """.stripMargin, cause) - throw cause - } finally { - warehousePath.delete() - metastorePath.delete() - Process(stopScript, None, env: _*).run().exitValue() - // The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while. - Thread.sleep(3.seconds.toMillis) - Option(logTailingProcess).map(_.destroy()) - Option(logFilePath).map(new File(_).delete()) - } - } - - test("Test JDBC query execution") { - withJdbcStatement() { statement => - val queries = Seq( - "SET spark.sql.shuffle.partitions=3", - "DROP TABLE IF EXISTS test", - "CREATE TABLE test(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test", - "CACHE TABLE test") - - queries.foreach(statement.execute) - - assertResult(5, "Row count mismatch") { - val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test") - resultSet.next() - resultSet.getInt(1) - } - } - } - - test("Test JDBC query execution in Http Mode") { - withJdbcStatement(httpMode = true) { statement => - val queries = Seq( - "SET spark.sql.shuffle.partitions=3", - "DROP TABLE IF EXISTS test", - "CREATE TABLE test(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test", - "CACHE TABLE test") - - queries.foreach(statement.execute) - - assertResult(5, "Row count mismatch") { - val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test") - resultSet.next() - resultSet.getInt(1) - } - } - } - - test("SPARK-3004 regression: result set containing NULL") { - withJdbcStatement() { statement => - val queries = Seq( - "DROP TABLE IF EXISTS test_null", - "CREATE TABLE test_null(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKvWithNull}' OVERWRITE INTO TABLE test_null") - - queries.foreach(statement.execute) - - val resultSet = statement.executeQuery("SELECT * FROM test_null WHERE key IS NULL") - - (0 until 5).foreach { _ => - resultSet.next() - assert(resultSet.getInt(1) === 0) - assert(resultSet.wasNull()) - } - - assert(!resultSet.next()) - } - } - - test("GetInfo Thrift API") { - withCLIServiceClient() { client => - val user = System.getProperty("user.name") - val sessionHandle = client.openSession(user, "") - - assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") { - client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue - } - - assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") { - client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue - } - - assertResult(true, "Spark version shouldn't be \"Unknown\"") { - val version = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue - logInfo(s"Spark version: $version") - version != "Unknown" - } - } - } - - test("Checks Hive version") { - withJdbcStatement() { statement => - val resultSet = statement.executeQuery("SET spark.sql.hive.version") - resultSet.next() - assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") - } - } - - test("Checks Hive version in Http Mode") { - withJdbcStatement(httpMode = true) { statement => - val resultSet = statement.executeQuery("SET spark.sql.hive.version") - resultSet.next() - assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") - } - } - - test("SPARK-4292 regression: result set iterator issue") { - withJdbcStatement() { statement => - val queries = Seq( - "DROP TABLE IF EXISTS test_4292", - "CREATE TABLE test_4292(key INT, val STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_4292") - - queries.foreach(statement.execute) - - val resultSet = statement.executeQuery("SELECT key FROM test_4292") - - Seq(238, 86, 311, 27, 165).foreach { key => - resultSet.next() - assert(resultSet.getInt(1) === key) - } - - statement.executeQuery("DROP TABLE IF EXISTS test_4292") - } - } - - test("SPARK-4309 regression: Date type support") { - withJdbcStatement() { statement => - val queries = Seq( - "DROP TABLE IF EXISTS test_date", - "CREATE TABLE test_date(key INT, value STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_date") - - queries.foreach(statement.execute) - - assertResult(Date.valueOf("2011-01-01")) { - val resultSet = statement.executeQuery( - "SELECT CAST('2011-01-01' as date) FROM test_date LIMIT 1") - resultSet.next() - resultSet.getDate(1) - } - } - } - - test("SPARK-4407 regression: Complex type support") { - withJdbcStatement() { statement => - val queries = Seq( - "DROP TABLE IF EXISTS test_map", - "CREATE TABLE test_map(key INT, value STRING)", - s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map") - - queries.foreach(statement.execute) - - assertResult("""{238:"val_238"}""") { - val resultSet = statement.executeQuery("SELECT MAP(key, value) FROM test_map LIMIT 1") - resultSet.next() - resultSet.getString(1) - } - - assertResult("""["238","val_238"]""") { - val resultSet = statement.executeQuery( - "SELECT ARRAY(CAST(key AS STRING), value) FROM test_map LIMIT 1") - resultSet.next() - resultSet.getString(1) - } - } - } -} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala new file mode 100644 index 0000000000000..77ef37253e38f --- /dev/null +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -0,0 +1,403 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.thriftserver + +import java.io.File +import java.sql.{Date, DriverManager, Statement} + +import scala.collection.mutable.ArrayBuffer +import scala.concurrent.duration._ +import scala.concurrent.{Await, Promise} +import scala.sys.process.{Process, ProcessLogger} +import scala.util.{Random, Try} + +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hive.jdbc.HiveDriver +import org.apache.hive.service.auth.PlainSaslHelper +import org.apache.hive.service.cli.GetInfoType +import org.apache.hive.service.cli.thrift.TCLIService.Client +import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient +import org.apache.thrift.protocol.TBinaryProtocol +import org.apache.thrift.transport.TSocket +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.util +import org.apache.spark.sql.hive.HiveShim + +object TestData { + def getTestDataFilePath(name: String) = { + Thread.currentThread().getContextClassLoader.getResource(s"data/files/$name") + } + + val smallKv = getTestDataFilePath("small_kv.txt") + val smallKvWithNull = getTestDataFilePath("small_kv_with_null.txt") +} + +class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { + override def mode = ServerMode.binary + + private def withCLIServiceClient(f: ThriftCLIServiceClient => Unit): Unit = { + // Transport creation logics below mimics HiveConnection.createBinaryTransport + val rawTransport = new TSocket("localhost", serverPort) + val user = System.getProperty("user.name") + val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport) + val protocol = new TBinaryProtocol(transport) + val client = new ThriftCLIServiceClient(new Client(protocol)) + + transport.open() + try f(client) finally transport.close() + } + + test("GetInfo Thrift API") { + withCLIServiceClient { client => + val user = System.getProperty("user.name") + val sessionHandle = client.openSession(user, "") + + assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") { + client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue + } + + assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") { + client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue + } + + assertResult(true, "Spark version shouldn't be \"Unknown\"") { + val version = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue + logInfo(s"Spark version: $version") + version != "Unknown" + } + } + } + + test("JDBC query execution") { + withJdbcStatement { statement => + val queries = Seq( + "SET spark.sql.shuffle.partitions=3", + "DROP TABLE IF EXISTS test", + "CREATE TABLE test(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test", + "CACHE TABLE test") + + queries.foreach(statement.execute) + + assertResult(5, "Row count mismatch") { + val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test") + resultSet.next() + resultSet.getInt(1) + } + } + } + + test("Checks Hive version") { + withJdbcStatement { statement => + val resultSet = statement.executeQuery("SET spark.sql.hive.version") + resultSet.next() + assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") + } + } + + test("SPARK-3004 regression: result set containing NULL") { + withJdbcStatement { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_null", + "CREATE TABLE test_null(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKvWithNull}' OVERWRITE INTO TABLE test_null") + + queries.foreach(statement.execute) + + val resultSet = statement.executeQuery("SELECT * FROM test_null WHERE key IS NULL") + + (0 until 5).foreach { _ => + resultSet.next() + assert(resultSet.getInt(1) === 0) + assert(resultSet.wasNull()) + } + + assert(!resultSet.next()) + } + } + + test("SPARK-4292 regression: result set iterator issue") { + withJdbcStatement { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_4292", + "CREATE TABLE test_4292(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_4292") + + queries.foreach(statement.execute) + + val resultSet = statement.executeQuery("SELECT key FROM test_4292") + + Seq(238, 86, 311, 27, 165).foreach { key => + resultSet.next() + assert(resultSet.getInt(1) === key) + } + + statement.executeQuery("DROP TABLE IF EXISTS test_4292") + } + } + + test("SPARK-4309 regression: Date type support") { + withJdbcStatement { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_date", + "CREATE TABLE test_date(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_date") + + queries.foreach(statement.execute) + + assertResult(Date.valueOf("2011-01-01")) { + val resultSet = statement.executeQuery( + "SELECT CAST('2011-01-01' as date) FROM test_date LIMIT 1") + resultSet.next() + resultSet.getDate(1) + } + } + } + + test("SPARK-4407 regression: Complex type support") { + withJdbcStatement { statement => + val queries = Seq( + "DROP TABLE IF EXISTS test_map", + "CREATE TABLE test_map(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map") + + queries.foreach(statement.execute) + + assertResult("""{238:"val_238"}""") { + val resultSet = statement.executeQuery("SELECT MAP(key, value) FROM test_map LIMIT 1") + resultSet.next() + resultSet.getString(1) + } + + assertResult("""["238","val_238"]""") { + val resultSet = statement.executeQuery( + "SELECT ARRAY(CAST(key AS STRING), value) FROM test_map LIMIT 1") + resultSet.next() + resultSet.getString(1) + } + } + } +} + +class HiveThriftHttpServerSuite extends HiveThriftJdbcTest { + override def mode = ServerMode.http + + test("JDBC query execution") { + withJdbcStatement { statement => + val queries = Seq( + "SET spark.sql.shuffle.partitions=3", + "DROP TABLE IF EXISTS test", + "CREATE TABLE test(key INT, val STRING)", + s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test", + "CACHE TABLE test") + + queries.foreach(statement.execute) + + assertResult(5, "Row count mismatch") { + val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test") + resultSet.next() + resultSet.getInt(1) + } + } + } + + test("Checks Hive version") { + withJdbcStatement { statement => + val resultSet = statement.executeQuery("SET spark.sql.hive.version") + resultSet.next() + assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}") + } + } +} + +object ServerMode extends Enumeration { + val binary, http = Value +} + +abstract class HiveThriftJdbcTest extends HiveThriftServer2Test { + Class.forName(classOf[HiveDriver].getCanonicalName) + + private def jdbcUri = if (mode == ServerMode.http) { + s"""jdbc:hive2://localhost:$serverPort/ + |default? + |hive.server2.transport.mode=http; + |hive.server2.thrift.http.path=cliservice + """.stripMargin.split("\n").mkString.trim + } else { + s"jdbc:hive2://localhost:$serverPort/" + } + + protected def withJdbcStatement(f: Statement => Unit): Unit = { + val connection = DriverManager.getConnection(jdbcUri, user, "") + val statement = connection.createStatement() + + try f(statement) finally { + statement.close() + connection.close() + } + } +} + +abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll with Logging { + def mode: ServerMode.Value + + private val CLASS_NAME = HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$") + private val LOG_FILE_MARK = s"starting $CLASS_NAME, logging to " + + private val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator) + private val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator) + + private var listeningPort: Int = _ + protected def serverPort: Int = listeningPort + + protected def user = System.getProperty("user.name") + + private var warehousePath: File = _ + private var metastorePath: File = _ + private def metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true" + + private var logPath: File = _ + private var logTailingProcess: Process = _ + private var diagnosisBuffer: ArrayBuffer[String] = ArrayBuffer.empty[String] + + private def serverStartCommand(port: Int) = { + val portConf = if (mode == ServerMode.binary) { + ConfVars.HIVE_SERVER2_THRIFT_PORT + } else { + ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT + } + + s"""$startScript + | --master local + | --hiveconf hive.root.logger=INFO,console + | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri + | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath + | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost + | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode + | --hiveconf $portConf=$port + | --driver-class-path ${sys.props("java.class.path")} + | --conf spark.ui.enabled=false + """.stripMargin.split("\\s+").toSeq + } + + private def startThriftServer(port: Int, attempt: Int) = { + warehousePath = util.getTempFilePath("warehouse") + metastorePath = util.getTempFilePath("metastore") + logPath = null + logTailingProcess = null + + val command = serverStartCommand(port) + + diagnosisBuffer ++= + s""" + |### Attempt $attempt ### + |HiveThriftServer2 command line: $command + |Listening port: $port + |System user: $user + """.stripMargin.split("\n") + + logInfo(s"Trying to start HiveThriftServer2: port=$port, mode=$mode, attempt=$attempt") + + logPath = Process(command, None, "SPARK_TESTING" -> "0").lines.collectFirst { + case line if line.contains(LOG_FILE_MARK) => new File(line.drop(LOG_FILE_MARK.length)) + }.getOrElse { + throw new RuntimeException("Failed to find HiveThriftServer2 log file.") + } + + val serverStarted = Promise[Unit]() + + // Ensures that the following "tail" command won't fail. + logPath.createNewFile() + logTailingProcess = + // Using "-n +0" to make sure all lines in the log file are checked. + Process(s"/usr/bin/env tail -n +0 -f ${logPath.getCanonicalPath}").run(ProcessLogger( + (line: String) => { + diagnosisBuffer += line + + if (line.contains("ThriftBinaryCLIService listening on") || + line.contains("Started ThriftHttpCLIService in http")) { + serverStarted.trySuccess(()) + } else if (line.contains("HiveServer2 is stopped")) { + // This log line appears when the server fails to start and terminates gracefully (e.g. + // because of port contention). + serverStarted.tryFailure(new RuntimeException("Failed to start HiveThriftServer2")) + } + })) + + Await.result(serverStarted.future, 2.minute) + } + + private def stopThriftServer(): Unit = { + // The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while. + Process(stopScript, None).run().exitValue() + Thread.sleep(3.seconds.toMillis) + + warehousePath.delete() + warehousePath = null + + metastorePath.delete() + metastorePath = null + + Option(logPath).foreach(_.delete()) + logPath = null + + Option(logTailingProcess).foreach(_.destroy()) + logTailingProcess = null + } + + private def dumpLogs(): Unit = { + logError( + s""" + |===================================== + |HiveThriftServer2Suite failure output + |===================================== + |${diagnosisBuffer.mkString("\n")} + |========================================= + |End HiveThriftServer2Suite failure output + |========================================= + """.stripMargin) + } + + override protected def beforeAll(): Unit = { + // Chooses a random port between 10000 and 19999 + listeningPort = 10000 + Random.nextInt(10000) + diagnosisBuffer.clear() + + // Retries up to 3 times with different port numbers if the server fails to start + (1 to 3).foldLeft(Try(startThriftServer(listeningPort, 0))) { case (started, attempt) => + started.orElse { + listeningPort += 1 + stopThriftServer() + Try(startThriftServer(listeningPort, attempt)) + } + }.recover { + case cause: Throwable => + dumpLogs() + throw cause + }.get + + logInfo(s"HiveThriftServer2 started successfully") + } + + override protected def afterAll(): Unit = { + stopThriftServer() + logInfo("HiveThriftServer2 stopped") + } +} From 53a1ebf33b5c349ae3a40d7eebf357b839b363af Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Feb 2015 18:51:41 -0800 Subject: [PATCH 17/28] [SPARK-5904][SQL] DataFrame Java API test suites. Added a new test suite to make sure Java DF programs can use varargs properly. Also moved all suites into test.org.apache.spark package to make sure the suites also test for method visibility. Author: Reynold Xin Closes #4751 from rxin/df-tests and squashes the following commits: 1e8b8e4 [Reynold Xin] Fixed imports and renamed JavaAPISuite. a6ca53b [Reynold Xin] [SPARK-5904][SQL] DataFrame Java API test suites. --- .../apache/spark/sql/api/java/JavaDsl.java | 120 ------------------ .../spark/sql}/JavaApplySchemaSuite.java | 28 ++-- .../apache/spark/sql/JavaDataFrameSuite.java | 84 ++++++++++++ .../org/apache/spark/sql}/JavaRowSuite.java | 2 +- .../org/apache/spark/sql/JavaUDFSuite.java} | 10 +- .../spark/sql/sources/JavaSaveLoadSuite.java | 3 +- .../org/apache/spark/sql/DataFrameSuite.scala | 4 +- 7 files changed, 108 insertions(+), 143 deletions(-) delete mode 100644 sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java rename sql/core/src/test/java/{org/apache/spark/sql/api/java => test/org/apache/spark/sql}/JavaApplySchemaSuite.java (90%) create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java rename sql/core/src/test/java/{org/apache/spark/sql/api/java => test/org/apache/spark/sql}/JavaRowSuite.java (99%) rename sql/core/src/test/java/{org/apache/spark/sql/api/java/JavaAPISuite.java => test/org/apache/spark/sql/JavaUDFSuite.java} (94%) rename sql/core/src/test/java/{ => test}/org/apache/spark/sql/sources/JavaSaveLoadSuite.java (98%) diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java deleted file mode 100644 index 05233dc5ffc58..0000000000000 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.api.java; - -import com.google.common.collect.ImmutableMap; - -import org.apache.spark.sql.Column; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.types.DataTypes; - -import static org.apache.spark.sql.functions.*; - -/** - * This test doesn't actually run anything. It is here to check the API compatibility for Java. - */ -public class JavaDsl { - - public static void testDataFrame(final DataFrame df) { - DataFrame df1 = df.select("colA"); - df1 = df.select("colA", "colB"); - - df1 = df.select(col("colA"), col("colB"), lit("literal value").$plus(1)); - - df1 = df.filter(col("colA")); - - java.util.Map aggExprs = ImmutableMap.builder() - .put("colA", "sum") - .put("colB", "avg") - .build(); - - df1 = df.agg(aggExprs); - - df1 = df.groupBy("groupCol").agg(aggExprs); - - df1 = df.join(df1, col("key1").$eq$eq$eq(col("key2")), "outer"); - - df.orderBy("colA"); - df.orderBy("colA", "colB", "colC"); - df.orderBy(col("colA").desc()); - df.orderBy(col("colA").desc(), col("colB").asc()); - - df.sort("colA"); - df.sort("colA", "colB", "colC"); - df.sort(col("colA").desc()); - df.sort(col("colA").desc(), col("colB").asc()); - - df.as("b"); - - df.limit(5); - - df.unionAll(df1); - df.intersect(df1); - df.except(df1); - - df.sample(true, 0.1, 234); - - df.head(); - df.head(5); - df.first(); - df.count(); - } - - public static void testColumn(final Column c) { - c.asc(); - c.desc(); - - c.endsWith("abcd"); - c.startsWith("afgasdf"); - - c.like("asdf%"); - c.rlike("wef%asdf"); - - c.as("newcol"); - - c.cast("int"); - c.cast(DataTypes.IntegerType); - } - - public static void testDsl() { - // Creating a column. - Column c = col("abcd"); - Column c1 = column("abcd"); - - // Literals - Column l1 = lit(1); - Column l2 = lit(1.0); - Column l3 = lit("abcd"); - - // Functions - Column a = upper(c); - a = lower(c); - a = sqrt(c); - a = abs(c); - - // Aggregates - a = min(c); - a = max(c); - a = sum(c); - a = sumDistinct(c); - a = countDistinct(c, a); - a = avg(c); - a = first(c); - a = last(c); - } -} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java similarity index 90% rename from sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java rename to sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index 643b891ab1b63..c344a9b095c52 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java; +package test.org.apache.spark.sql; import java.io.Serializable; import java.util.ArrayList; @@ -39,18 +39,18 @@ // see http://stackoverflow.com/questions/758570/. public class JavaApplySchemaSuite implements Serializable { private transient JavaSparkContext javaCtx; - private transient SQLContext javaSqlCtx; + private transient SQLContext sqlContext; @Before public void setUp() { - javaSqlCtx = TestSQLContext$.MODULE$; - javaCtx = new JavaSparkContext(javaSqlCtx.sparkContext()); + sqlContext = TestSQLContext$.MODULE$; + javaCtx = new JavaSparkContext(sqlContext.sparkContext()); } @After public void tearDown() { javaCtx = null; - javaSqlCtx = null; + sqlContext = null; } public static class Person implements Serializable { @@ -98,9 +98,9 @@ public Row call(Person person) throws Exception { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - DataFrame df = javaSqlCtx.applySchema(rowRDD, schema); + DataFrame df = sqlContext.applySchema(rowRDD, schema); df.registerTempTable("people"); - Row[] actual = javaSqlCtx.sql("SELECT * FROM people").collect(); + Row[] actual = sqlContext.sql("SELECT * FROM people").collect(); List expected = new ArrayList(2); expected.add(RowFactory.create("Michael", 29)); @@ -109,8 +109,6 @@ public Row call(Person person) throws Exception { Assert.assertEquals(expected, Arrays.asList(actual)); } - - @Test public void dataFrameRDDOperations() { List personList = new ArrayList(2); @@ -135,9 +133,9 @@ public Row call(Person person) throws Exception { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - DataFrame df = javaSqlCtx.applySchema(rowRDD, schema); + DataFrame df = sqlContext.applySchema(rowRDD, schema); df.registerTempTable("people"); - List actual = javaSqlCtx.sql("SELECT * FROM people").toJavaRDD().map(new Function() { + List actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function() { public String call(Row row) { return row.getString(0) + "_" + row.get(1).toString(); @@ -189,18 +187,18 @@ public void applySchemaToJSON() { null, "this is another simple string.")); - DataFrame df1 = javaSqlCtx.jsonRDD(jsonRDD); + DataFrame df1 = sqlContext.jsonRDD(jsonRDD); StructType actualSchema1 = df1.schema(); Assert.assertEquals(expectedSchema, actualSchema1); df1.registerTempTable("jsonTable1"); - List actual1 = javaSqlCtx.sql("select * from jsonTable1").collectAsList(); + List actual1 = sqlContext.sql("select * from jsonTable1").collectAsList(); Assert.assertEquals(expectedResult, actual1); - DataFrame df2 = javaSqlCtx.jsonRDD(jsonRDD, expectedSchema); + DataFrame df2 = sqlContext.jsonRDD(jsonRDD, expectedSchema); StructType actualSchema2 = df2.schema(); Assert.assertEquals(expectedSchema, actualSchema2); df2.registerTempTable("jsonTable2"); - List actual2 = javaSqlCtx.sql("select * from jsonTable2").collectAsList(); + List actual2 = sqlContext.sql("select * from jsonTable2").collectAsList(); Assert.assertEquals(expectedResult, actual2); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java new file mode 100644 index 0000000000000..c1c51f80d6586 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql; + +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.sql.*; +import org.apache.spark.sql.test.TestSQLContext$; +import static org.apache.spark.sql.functions.*; + + +public class JavaDataFrameSuite { + private transient SQLContext context; + + @Before + public void setUp() { + // Trigger static initializer of TestData + TestData$.MODULE$.testData(); + context = TestSQLContext$.MODULE$; + } + + @After + public void tearDown() { + context = null; + } + + @Test + public void testExecution() { + DataFrame df = context.table("testData").filter("key = 1"); + Assert.assertEquals(df.select("key").collect()[0].get(0), 1); + } + + /** + * See SPARK-5904. Abstract vararg methods defined in Scala do not work in Java. + */ + @Test + public void testVarargMethods() { + DataFrame df = context.table("testData"); + + df.toDF("key1", "value1"); + + df.select("key", "value"); + df.select(col("key"), col("value")); + df.selectExpr("key", "value + 1"); + + df.sort("key", "value"); + df.sort(col("key"), col("value")); + df.orderBy("key", "value"); + df.orderBy(col("key"), col("value")); + + df.groupBy("key", "value").agg(col("key"), col("value"), sum("value")); + df.groupBy(col("key"), col("value")).agg(col("key"), col("value"), sum("value")); + df.agg(first("key"), sum("value")); + + df.groupBy().avg("key"); + df.groupBy().mean("key"); + df.groupBy().max("key"); + df.groupBy().min("key"); + df.groupBy().sum("key"); + + // Varargs in column expressions + df.groupBy().agg(countDistinct("key", "value")); + df.groupBy().agg(countDistinct(col("key"), col("value"))); + df.select(coalesce(col("key"))); + } +} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java similarity index 99% rename from sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java rename to sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java index fbfcd3f59d910..4ce1d1dddb26a 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java; +package test.org.apache.spark.sql; import java.math.BigDecimal; import java.sql.Date; diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java similarity index 94% rename from sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java rename to sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index a21a15409080c..79d92734ff375 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -15,24 +15,26 @@ * limitations under the License. */ -package org.apache.spark.sql.api.java; +package test.org.apache.spark.sql; import java.io.Serializable; -import org.apache.spark.sql.test.TestSQLContext$; import org.junit.After; import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.api.java.UDF1; +import org.apache.spark.sql.api.java.UDF2; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.test.TestSQLContext$; import org.apache.spark.sql.types.DataTypes; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; // see http://stackoverflow.com/questions/758570/. -public class JavaAPISuite implements Serializable { +public class JavaUDFSuite implements Serializable { private transient JavaSparkContext sc; private transient SQLContext sqlContext; diff --git a/sql/core/src/test/java/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java similarity index 98% rename from sql/core/src/test/java/org/apache/spark/sql/sources/JavaSaveLoadSuite.java rename to sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java index 311f1bdd07510..b76f7d421f643 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/sources/JavaSaveLoadSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java @@ -14,7 +14,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.sources; + +package test.org.apache.spark.sql.sources; import java.io.File; import java.io.IOException; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index e71e9bee3a6d8..30e77e4ef30f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -411,7 +411,7 @@ class DataFrameSuite extends QueryTest { ) } - test("addColumn") { + test("withColumn") { val df = testData.toDF().withColumn("newCol", col("key") + 1) checkAnswer( df, @@ -421,7 +421,7 @@ class DataFrameSuite extends QueryTest { assert(df.schema.map(_.name).toSeq === Seq("key", "value", "newCol")) } - test("renameColumn") { + test("withColumnRenamed") { val df = testData.toDF().withColumn("newCol", col("key") + 1) .withColumnRenamed("value", "valueRenamed") checkAnswer( From fba11c2f55dd81e4f6230e7edca3c7b2e01ccd9d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 24 Feb 2015 18:59:23 -0800 Subject: [PATCH 18/28] [SPARK-5985][SQL] DataFrame sortBy -> orderBy in Python. Also added desc/asc function for constructing sorting expressions more conveniently. And added a small fix to lift alias out of cast expression. Author: Reynold Xin Closes #4752 from rxin/SPARK-5985 and squashes the following commits: aeda5ae [Reynold Xin] Added Experimental flag to ColumnName. 047ad03 [Reynold Xin] Lift alias out of cast. c9cf17c [Reynold Xin] [SPARK-5985][SQL] DataFrame sortBy -> orderBy in Python. --- python/pyspark/sql/dataframe.py | 11 +++++-- python/pyspark/sql/functions.py | 3 ++ .../scala/org/apache/spark/sql/Column.scala | 13 +++++++-- .../org/apache/spark/sql/functions.scala | 29 +++++++++++++++++++ .../spark/sql/ColumnExpressionSuite.scala | 4 +++ .../org/apache/spark/sql/DataFrameSuite.scala | 4 +++ 6 files changed, 59 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 010c38f93b9cf..6f746d136b22d 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -504,13 +504,18 @@ def join(self, other, joinExprs=None, joinType=None): return DataFrame(jdf, self.sql_ctx) def sort(self, *cols): - """ Return a new :class:`DataFrame` sorted by the specified column. + """ Return a new :class:`DataFrame` sorted by the specified column(s). :param cols: The columns or expressions used for sorting >>> df.sort(df.age.desc()).collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] - >>> df.sortBy(df.age.desc()).collect() + >>> df.orderBy(df.age.desc()).collect() + [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] + >>> from pyspark.sql.functions import * + >>> df.sort(asc("age")).collect() + [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] + >>> df.orderBy(desc("age"), "name").collect() [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')] """ if not cols: @@ -520,7 +525,7 @@ def sort(self, *cols): jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols)) return DataFrame(jdf, self.sql_ctx) - sortBy = sort + orderBy = sort def head(self, n=None): """ Return the first `n` rows or the first row if n is None. diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index fc61162f0b827..8aa44765205c1 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -48,6 +48,9 @@ def _(col): 'lit': 'Creates a :class:`Column` of literal value.', 'col': 'Returns a :class:`Column` based on the given column name.', 'column': 'Returns a :class:`Column` based on the given column name.', + 'asc': 'Returns a sort expression based on the ascending order of the given column name.', + 'desc': 'Returns a sort expression based on the descending order of the given column name.', + 'upper': 'Converts a string expression to upper case.', 'lower': 'Converts a string expression to upper case.', 'sqrt': 'Computes the square root of the specified float value.', diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 980754322e6c8..a2cc9a9b93eb8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -600,7 +600,11 @@ class Column(protected[sql] val expr: Expression) { * * @group expr_ops */ - def cast(to: DataType): Column = Cast(expr, to) + def cast(to: DataType): Column = expr match { + // Lift alias out of cast so we can support col.as("name").cast(IntegerType) + case Alias(childExpr, name) => Alias(Cast(childExpr, to), name)() + case _ => Cast(expr, to) + } /** * Casts the column to a different data type, using the canonical string representation @@ -613,7 +617,7 @@ class Column(protected[sql] val expr: Expression) { * * @group expr_ops */ - def cast(to: String): Column = Cast(expr, to.toLowerCase match { + def cast(to: String): Column = cast(to.toLowerCase match { case "string" | "str" => StringType case "boolean" => BooleanType case "byte" => ByteType @@ -671,6 +675,11 @@ class Column(protected[sql] val expr: Expression) { } +/** + * :: Experimental :: + * A convenient class used for constructing schema. + */ +@Experimental class ColumnName(name: String) extends Column(name) { /** Creates a new AttributeReference of type boolean */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 2a1e086891423..4fdbfc6d22c9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.types._ * * @groupname udf_funcs UDF functions * @groupname agg_funcs Aggregate functions + * @groupname sort_funcs Sorting functions * @groupname normal_funcs Non-aggregate functions * @groupname Ungrouped Support functions for DataFrames. */ @@ -96,6 +97,33 @@ object functions { } ////////////////////////////////////////////////////////////////////////////////////////////// + // Sort functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Returns a sort expression based on ascending order of the column. + * {{ + * // Sort by dept in ascending order, and then age in descending order. + * df.sort(asc("dept"), desc("age")) + * }} + * + * @group sort_funcs + */ + def asc(columnName: String): Column = Column(columnName).asc + + /** + * Returns a sort expression based on the descending order of the column. + * {{ + * // Sort by dept in ascending order, and then age in descending order. + * df.sort(asc("dept"), desc("age")) + * }} + * + * @group sort_funcs + */ + def desc(columnName: String): Column = Column(columnName).desc + + ////////////////////////////////////////////////////////////////////////////////////////////// + // Aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// /** @@ -263,6 +291,7 @@ object functions { def max(columnName: String): Column = max(Column(columnName)) ////////////////////////////////////////////////////////////////////////////////////////////// + // Non-aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 928b0deb61921..37c02aaa5460b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -309,4 +309,8 @@ class ColumnExpressionSuite extends QueryTest { (1 to 100).map(n => Row(null)) ) } + + test("lift alias out of cast") { + assert(col("1234").as("name").cast("int").expr === col("1234").cast("int").as("name").expr) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 30e77e4ef30f2..c392a553c03f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -239,6 +239,10 @@ class DataFrameSuite extends QueryTest { testData2.orderBy('a.asc, 'b.asc), Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) + checkAnswer( + testData2.orderBy(asc("a"), desc("b")), + Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) + checkAnswer( testData2.orderBy('a.asc, 'b.desc), Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1))) From 922b43b3cc1cca04e0313bf9e31c5f944ac06d1f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 24 Feb 2015 19:10:37 -0800 Subject: [PATCH 19/28] [SPARK-5993][Streaming][Build] Fix assembly jar location of kafka-assembly Published Kafka-assembly JAR was empty in 1.3.0-RC1 This is because the maven build generated two Jars- 1. an empty JAR file (since kafka-assembly has no code of its own) 2. a assembly JAR file containing everything in a different location as 1 The maven publishing plugin uploaded 1 and not 2. Instead if 2 is not configure to generate in a different location, there is only 1 jar containing everything, which gets published. Author: Tathagata Das Closes #4753 from tdas/SPARK-5993 and squashes the following commits: c390db8 [Tathagata Das] Fix assembly jar location of kafka-assembly --- external/kafka-assembly/pom.xml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index 503fc129dc4f2..8daa7ed608f6a 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -33,9 +33,6 @@ streaming-kafka-assembly - scala-${scala.binary.version} - spark-streaming-kafka-assembly-${project.version}.jar - ${project.build.directory}/${spark.jar.dir}/${spark.jar.basename} @@ -61,7 +58,6 @@ maven-shade-plugin false - ${spark.jar} *:* From 769e092bdc51582372093f76dbaece27149cc4ea Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 24 Feb 2015 19:51:36 -0800 Subject: [PATCH 20/28] [SPARK-5286][SQL] SPARK-5286 followup https://issues.apache.org/jira/browse/SPARK-5286 Author: Yin Huai Closes #4755 from yhuai/SPARK-5286-throwable and squashes the following commits: 4c0c450 [Yin Huai] Catch Throwable instead of Exception. --- .../org/apache/spark/sql/hive/execution/commands.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index c88d0e6b79491..9934a5d3c30a2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -63,10 +63,10 @@ case class DropTable( } catch { // This table's metadata is not in case _: org.apache.hadoop.hive.ql.metadata.InvalidTableException => - // Other exceptions can be caused by users providing wrong parameters in OPTIONS + // Other Throwables can be caused by users providing wrong parameters in OPTIONS // (e.g. invalid paths). We catch it and log a warning message. - // Users should be able to drop such kinds of tables regardless if there is an exception. - case e: Exception => log.warn(s"${e.getMessage}") + // Users should be able to drop such kinds of tables regardless if there is an error. + case e: Throwable => log.warn(s"${e.getMessage}") } hiveContext.invalidateTable(tableName) hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") From d641fbb39c90b1d734cc55396ca43d7e98788975 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 24 Feb 2015 20:51:55 -0800 Subject: [PATCH 21/28] [SPARK-5994] [SQL] Python DataFrame documentation fixes select empty should NOT be the same as select. make sure selectExpr is behaving the same. join param documentation link to source doesn't work in jekyll generated file cross reference of columns (i.e. enabling linking) show(): move df example before df.show() move tests in SQLContext out of docstring otherwise doc is too long Column.desc and .asc doesn't have any documentation in documentation, sort functions.*) Author: Davies Liu Closes #4756 from davies/df_docs and squashes the following commits: f30502c [Davies Liu] fix doc 32f0d46 [Davies Liu] fix DataFrame docs --- docs/_config.yml | 1 + python/docs/pyspark.sql.rst | 3 - python/pyspark/sql/context.py | 182 ++++++-------------------------- python/pyspark/sql/dataframe.py | 56 +++++----- python/pyspark/sql/functions.py | 1 + python/pyspark/sql/tests.py | 68 +++++++++++- python/pyspark/sql/types.py | 2 +- 7 files changed, 130 insertions(+), 183 deletions(-) diff --git a/docs/_config.yml b/docs/_config.yml index e2db274e1f619..0652927a8ce9b 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -10,6 +10,7 @@ kramdown: include: - _static + - _modules # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. diff --git a/python/docs/pyspark.sql.rst b/python/docs/pyspark.sql.rst index e03379e521a07..2e3f69b9a562a 100644 --- a/python/docs/pyspark.sql.rst +++ b/python/docs/pyspark.sql.rst @@ -7,7 +7,6 @@ Module Context .. automodule:: pyspark.sql :members: :undoc-members: - :show-inheritance: pyspark.sql.types module @@ -15,7 +14,6 @@ pyspark.sql.types module .. automodule:: pyspark.sql.types :members: :undoc-members: - :show-inheritance: pyspark.sql.functions module @@ -23,4 +21,3 @@ pyspark.sql.functions module .. automodule:: pyspark.sql.functions :members: :undoc-members: - :show-inheritance: diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 125933c9d3ae0..5d7aeb664cadf 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -129,6 +129,7 @@ def registerFunction(self, name, f, returnType=StringType()): >>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x)) >>> sqlCtx.sql("SELECT stringLengthString('test')").collect() [Row(c0=u'4')] + >>> from pyspark.sql.types import IntegerType >>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlCtx.sql("SELECT stringLengthInt('test')").collect() @@ -197,31 +198,6 @@ def inferSchema(self, rdd, samplingRatio=None): >>> df = sqlCtx.inferSchema(rdd) >>> df.collect()[0] Row(field1=1, field2=u'row1') - - >>> NestedRow = Row("f1", "f2") - >>> nestedRdd1 = sc.parallelize([ - ... NestedRow(array('i', [1, 2]), {"row1": 1.0}), - ... NestedRow(array('i', [2, 3]), {"row2": 2.0})]) - >>> df = sqlCtx.inferSchema(nestedRdd1) - >>> df.collect() - [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})] - - >>> nestedRdd2 = sc.parallelize([ - ... NestedRow([[1, 2], [2, 3]], [1, 2]), - ... NestedRow([[2, 3], [3, 4]], [2, 3])]) - >>> df = sqlCtx.inferSchema(nestedRdd2) - >>> df.collect() - [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] - - >>> from collections import namedtuple - >>> CustomRow = namedtuple('CustomRow', 'field1 field2') - >>> rdd = sc.parallelize( - ... [CustomRow(field1=1, field2="row1"), - ... CustomRow(field1=2, field2="row2"), - ... CustomRow(field1=3, field2="row3")]) - >>> df = sqlCtx.inferSchema(rdd) - >>> df.collect()[0] - Row(field1=1, field2=u'row1') """ if isinstance(rdd, DataFrame): @@ -252,56 +228,8 @@ def applySchema(self, rdd, schema): >>> schema = StructType([StructField("field1", IntegerType(), False), ... StructField("field2", StringType(), False)]) >>> df = sqlCtx.applySchema(rdd2, schema) - >>> sqlCtx.registerDataFrameAsTable(df, "table1") - >>> df2 = sqlCtx.sql("SELECT * from table1") - >>> df2.collect() - [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] - - >>> from datetime import date, datetime - >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0, - ... date(2010, 1, 1), - ... datetime(2010, 1, 1, 1, 1, 1), - ... {"a": 1}, (2,), [1, 2, 3], None)]) - >>> schema = StructType([ - ... StructField("byte1", ByteType(), False), - ... StructField("byte2", ByteType(), False), - ... StructField("short1", ShortType(), False), - ... StructField("short2", ShortType(), False), - ... StructField("int1", IntegerType(), False), - ... StructField("float1", FloatType(), False), - ... StructField("date1", DateType(), False), - ... StructField("time1", TimestampType(), False), - ... StructField("map1", - ... MapType(StringType(), IntegerType(), False), False), - ... StructField("struct1", - ... StructType([StructField("b", ShortType(), False)]), False), - ... StructField("list1", ArrayType(ByteType(), False), False), - ... StructField("null1", DoubleType(), True)]) - >>> df = sqlCtx.applySchema(rdd, schema) - >>> results = df.map( - ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1, - ... x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1)) - >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE - (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1), - datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) - - >>> df.registerTempTable("table2") - >>> sqlCtx.sql( - ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + - ... "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " + - ... "float1 + 1.5 as float1 FROM table2").collect() - [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int1=2147483646, float1=2.5)] - - >>> from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type - >>> rdd = sc.parallelize([(127, -32768, 1.0, - ... datetime(2010, 1, 1, 1, 1, 1), - ... {"a": 1}, (2,), [1, 2, 3])]) - >>> abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]" - >>> schema = _parse_schema_abstract(abstract) - >>> typedSchema = _infer_schema_type(rdd.first(), schema) - >>> df = sqlCtx.applySchema(rdd, typedSchema) >>> df.collect() - [Row(byte1=127, short1=-32768, float1=1.0, time1=..., list1=[1, 2, 3])] + [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] """ if isinstance(rdd, DataFrame): @@ -459,46 +387,28 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): >>> import tempfile, shutil >>> jsonFile = tempfile.mkdtemp() >>> shutil.rmtree(jsonFile) - >>> ofn = open(jsonFile, 'w') - >>> for json in jsonStrings: - ... print>>ofn, json - >>> ofn.close() + >>> with open(jsonFile, 'w') as f: + ... f.writelines(jsonStrings) >>> df1 = sqlCtx.jsonFile(jsonFile) - >>> sqlCtx.registerDataFrameAsTable(df1, "table1") - >>> df2 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table1") - >>> for r in df2.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema) - >>> sqlCtx.registerDataFrameAsTable(df3, "table2") - >>> df4 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table2") - >>> for r in df4.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + >>> df1.printSchema() + root + |-- field1: long (nullable = true) + |-- field2: string (nullable = true) + |-- field3: struct (nullable = true) + | |-- field4: long (nullable = true) >>> from pyspark.sql.types import * >>> schema = StructType([ - ... StructField("field2", StringType(), True), + ... StructField("field2", StringType()), ... StructField("field3", - ... StructType([ - ... StructField("field5", - ... ArrayType(IntegerType(), False), True)]), False)]) - >>> df5 = sqlCtx.jsonFile(jsonFile, schema) - >>> sqlCtx.registerDataFrameAsTable(df5, "table3") - >>> df6 = sqlCtx.sql( - ... "SELECT field2 AS f1, field3.field5 as f2, " - ... "field3.field5[0] as f3 from table3") - >>> df6.collect() - [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] + ... StructType([StructField("field5", ArrayType(IntegerType()))]))]) + >>> df2 = sqlCtx.jsonFile(jsonFile, schema) + >>> df2.printSchema() + root + |-- field2: string (nullable = true) + |-- field3: struct (nullable = true) + | |-- field5: array (nullable = true) + | | |-- element: integer (containsNull = true) """ if schema is None: df = self._ssql_ctx.jsonFile(path, samplingRatio) @@ -517,48 +427,23 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): determine the schema. >>> df1 = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerDataFrameAsTable(df1, "table1") - >>> df2 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table1") - >>> for r in df2.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - - >>> df3 = sqlCtx.jsonRDD(json, df1.schema) - >>> sqlCtx.registerDataFrameAsTable(df3, "table2") - >>> df4 = sqlCtx.sql( - ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " - ... "field6 as f4 from table2") - >>> for r in df4.collect(): - ... print r - Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) - Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) - Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) + >>> df1.first() + Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) + + >>> df2 = sqlCtx.jsonRDD(json, df1.schema) + >>> df2.first() + Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) >>> from pyspark.sql.types import * >>> schema = StructType([ - ... StructField("field2", StringType(), True), + ... StructField("field2", StringType()), ... StructField("field3", - ... StructType([ - ... StructField("field5", - ... ArrayType(IntegerType(), False), True)]), False)]) - >>> df5 = sqlCtx.jsonRDD(json, schema) - >>> sqlCtx.registerDataFrameAsTable(df5, "table3") - >>> df6 = sqlCtx.sql( - ... "SELECT field2 AS f1, field3.field5 as f2, " - ... "field3.field5[0] as f3 from table3") - >>> df6.collect() - [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)] - - >>> sqlCtx.jsonRDD(sc.parallelize(['{}', - ... '{"key0": {"key1": "value1"}}'])).collect() - [Row(key0=None), Row(key0=Row(key1=u'value1'))] - >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}', - ... '{"key0": {"key1": "value1"}}'])).collect() - [Row(key0=None), Row(key0=Row(key1=u'value1'))] + ... StructType([StructField("field5", ArrayType(IntegerType()))])) + ... ]) + >>> df3 = sqlCtx.jsonRDD(json, schema) + >>> df3.first() + Row(field2=u'row1', field3=Row(field5=None)) + """ def func(iterator): @@ -848,7 +733,8 @@ def _test(): globs['jsonStrings'] = jsonStrings globs['json'] = sc.parallelize(jsonStrings) (failure_count, test_count) = doctest.testmod( - pyspark.sql.context, globs=globs, optionflags=doctest.ELLIPSIS) + pyspark.sql.context, globs=globs, + optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE) globs['sc'].stop() if failure_count: exit(-1) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 6f746d136b22d..6d42410020b64 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -96,7 +96,7 @@ def applySchema(it): return self._lazy_rdd def toJSON(self, use_unicode=False): - """Convert a DataFrame into a MappedRDD of JSON documents; one document per row. + """Convert a :class:`DataFrame` into a MappedRDD of JSON documents; one document per row. >>> df.toJSON().first() '{"age":2,"name":"Alice"}' @@ -108,7 +108,7 @@ def saveAsParquetFile(self, path): """Save the contents as a Parquet file, preserving the schema. Files that are written out using this method can be read back in as - a DataFrame using the L{SQLContext.parquetFile} method. + a :class:`DataFrame` using the L{SQLContext.parquetFile} method. >>> import tempfile, shutil >>> parquetFile = tempfile.mkdtemp() @@ -139,7 +139,7 @@ def registerAsTable(self, name): self.registerTempTable(name) def insertInto(self, tableName, overwrite=False): - """Inserts the contents of this DataFrame into the specified table. + """Inserts the contents of this :class:`DataFrame` into the specified table. Optionally overwriting any existing data. """ @@ -165,7 +165,7 @@ def _java_save_mode(self, mode): return jmode def saveAsTable(self, tableName, source=None, mode="append", **options): - """Saves the contents of the DataFrame to a data source as a table. + """Saves the contents of the :class:`DataFrame` to a data source as a table. The data source is specified by the `source` and a set of `options`. If `source` is not specified, the default data source configured by @@ -174,12 +174,13 @@ def saveAsTable(self, tableName, source=None, mode="append", **options): Additionally, mode is used to specify the behavior of the saveAsTable operation when table already exists in the data source. There are four modes: - * append: Contents of this DataFrame are expected to be appended to existing table. - * overwrite: Data in the existing table is expected to be overwritten by the contents of \ - this DataFrame. + * append: Contents of this :class:`DataFrame` are expected to be appended \ + to existing table. + * overwrite: Data in the existing table is expected to be overwritten by \ + the contents of this DataFrame. * error: An exception is expected to be thrown. - * ignore: The save operation is expected to not save the contents of the DataFrame and \ - to not change the existing table. + * ignore: The save operation is expected to not save the contents of the \ + :class:`DataFrame` and to not change the existing table. """ if source is None: source = self.sql_ctx.getConf("spark.sql.sources.default", @@ -190,7 +191,7 @@ def saveAsTable(self, tableName, source=None, mode="append", **options): self._jdf.saveAsTable(tableName, source, jmode, joptions) def save(self, path=None, source=None, mode="append", **options): - """Saves the contents of the DataFrame to a data source. + """Saves the contents of the :class:`DataFrame` to a data source. The data source is specified by the `source` and a set of `options`. If `source` is not specified, the default data source configured by @@ -199,11 +200,11 @@ def save(self, path=None, source=None, mode="append", **options): Additionally, mode is used to specify the behavior of the save operation when data already exists in the data source. There are four modes: - * append: Contents of this DataFrame are expected to be appended to existing data. + * append: Contents of this :class:`DataFrame` are expected to be appended to existing data. * overwrite: Existing data is expected to be overwritten by the contents of this DataFrame. * error: An exception is expected to be thrown. - * ignore: The save operation is expected to not save the contents of the DataFrame and \ - to not change the existing data. + * ignore: The save operation is expected to not save the contents of \ + the :class:`DataFrame` and to not change the existing data. """ if path is not None: options["path"] = path @@ -217,7 +218,7 @@ def save(self, path=None, source=None, mode="append", **options): @property def schema(self): - """Returns the schema of this DataFrame (represented by + """Returns the schema of this :class:`DataFrame` (represented by a L{StructType}). >>> df.schema @@ -275,12 +276,12 @@ def show(self): """ Print the first 20 rows. + >>> df + DataFrame[age: int, name: string] >>> df.show() age name 2 Alice 5 Bob - >>> df - DataFrame[age: int, name: string] """ print self._jdf.showString().encode('utf8', 'ignore') @@ -481,8 +482,8 @@ def columns(self): def join(self, other, joinExprs=None, joinType=None): """ - Join with another DataFrame, using the given join expression. - The following performs a full outer join between `df1` and `df2`:: + Join with another :class:`DataFrame`, using the given join expression. + The following performs a full outer join between `df1` and `df2`. :param other: Right side of the join :param joinExprs: Join expression @@ -582,8 +583,6 @@ def __getattr__(self, name): def select(self, *cols): """ Selecting a set of expressions. - >>> df.select().collect() - [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] >>> df.select('*').collect() [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] >>> df.select('name', 'age').collect() @@ -591,8 +590,6 @@ def select(self, *cols): >>> df.select(df.name, (df.age + 10).alias('age')).collect() [Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)] """ - if not cols: - cols = ["*"] jcols = ListConverter().convert([_to_java_column(c) for c in cols], self._sc._gateway._gateway_client) jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols)) @@ -612,7 +609,7 @@ def selectExpr(self, *expr): def filter(self, condition): """ Filtering rows using the given condition, which could be - Column expression or string of SQL expression. + :class:`Column` expression or string of SQL expression. where() is an alias for filter(). @@ -666,7 +663,7 @@ def agg(self, *exprs): return self.groupBy().agg(*exprs) def unionAll(self, other): - """ Return a new DataFrame containing union of rows in this + """ Return a new :class:`DataFrame` containing union of rows in this frame and another frame. This is equivalent to `UNION ALL` in SQL. @@ -919,9 +916,10 @@ class Column(object): """ A column in a DataFrame. - `Column` instances can be created by:: + :class:`Column` instances can be created by:: # 1. Select a column out of a DataFrame + df.colName df["colName"] @@ -975,7 +973,7 @@ def __init__(self, jc): def substr(self, startPos, length): """ - Return a Column which is a substring of the column + Return a :class:`Column` which is a substring of the column :param startPos: start position (int or Column) :param length: length of the substring (int or Column) @@ -996,8 +994,10 @@ def substr(self, startPos, length): __getslice__ = substr # order - asc = _unary_op("asc") - desc = _unary_op("desc") + asc = _unary_op("asc", "Returns a sort expression based on the" + " ascending order of the given column name.") + desc = _unary_op("desc", "Returns a sort expression based on the" + " descending order of the given column name.") isNull = _unary_op("isNull", "True if the current expression is null.") isNotNull = _unary_op("isNotNull", "True if the current expression is not null.") diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 8aa44765205c1..5873f09ae3275 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -72,6 +72,7 @@ def _(col): globals()[_name] = _create_function(_name, _doc) del _name, _doc __all__ += _functions.keys() +__all__.sort() def countDistinct(col, *cols): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 39071e7e35ca1..83899ad4b1b12 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -36,9 +36,9 @@ else: import unittest -from pyspark.sql import SQLContext, HiveContext, Column -from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \ - UserDefinedType, DoubleType, LongType, StringType, _infer_type +from pyspark.sql import SQLContext, HiveContext, Column, Row +from pyspark.sql.types import * +from pyspark.sql.types import UserDefinedType, _infer_type from pyspark.tests import ReusedPySparkTestCase @@ -204,6 +204,68 @@ def test_infer_schema(self): result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") self.assertEqual(1, result.head()[0]) + def test_infer_nested_schema(self): + NestedRow = Row("f1", "f2") + nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}), + NestedRow([2, 3], {"row2": 2.0})]) + df = self.sqlCtx.inferSchema(nestedRdd1) + self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0]) + + nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]), + NestedRow([[2, 3], [3, 4]], [2, 3])]) + df = self.sqlCtx.inferSchema(nestedRdd2) + self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0]) + + from collections import namedtuple + CustomRow = namedtuple('CustomRow', 'field1 field2') + rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"), + CustomRow(field1=2, field2="row2"), + CustomRow(field1=3, field2="row3")]) + df = self.sqlCtx.inferSchema(rdd) + self.assertEquals(Row(field1=1, field2=u'row1'), df.first()) + + def test_apply_schema(self): + from datetime import date, datetime + rdd = self.sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0, + date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1), + {"a": 1}, (2,), [1, 2, 3], None)]) + schema = StructType([ + StructField("byte1", ByteType(), False), + StructField("byte2", ByteType(), False), + StructField("short1", ShortType(), False), + StructField("short2", ShortType(), False), + StructField("int1", IntegerType(), False), + StructField("float1", FloatType(), False), + StructField("date1", DateType(), False), + StructField("time1", TimestampType(), False), + StructField("map1", MapType(StringType(), IntegerType(), False), False), + StructField("struct1", StructType([StructField("b", ShortType(), False)]), False), + StructField("list1", ArrayType(ByteType(), False), False), + StructField("null1", DoubleType(), True)]) + df = self.sqlCtx.applySchema(rdd, schema) + results = df.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1, + x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1)) + r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1), + datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + self.assertEqual(r, results.first()) + + df.registerTempTable("table2") + r = self.sqlCtx.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + + "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " + + "float1 + 1.5 as float1 FROM table2").first() + + self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r)) + + from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type + rdd = self.sc.parallelize([(127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), + {"a": 1}, (2,), [1, 2, 3])]) + abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]" + schema = _parse_schema_abstract(abstract) + typedSchema = _infer_schema_type(rdd.first(), schema) + df = self.sqlCtx.applySchema(rdd, typedSchema) + r = (127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), {"a": 1}, Row(b=2), [1, 2, 3]) + self.assertEqual(r, tuple(df.first())) + def test_struct_in_map(self): d = [Row(m={Row(i=1): Row(s="")})] df = self.sc.parallelize(d).toDF() diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index b6e41cf0b29ff..0f5dc2be6dab8 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -28,7 +28,7 @@ __all__ = [ "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", - "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType", ] + "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType"] class DataType(object): From d51ed263ee791967380de6b9c892985ce87f6fcb Mon Sep 17 00:00:00 2001 From: prabs Date: Wed, 25 Feb 2015 14:37:35 +0000 Subject: [PATCH 22/28] [SPARK-5666][streaming][MQTT streaming] some trivial fixes modified to adhere to accepted coding standards as pointed by tdas in PR #3844 Author: prabs Author: Prabeesh K Closes #4178 from prabeesh/master and squashes the following commits: bd2cb49 [Prabeesh K] adress the comment ccc0765 [prabs] adress the comment 46f9619 [prabs] adress the comment c035bdc [prabs] adress the comment 22dd7f7 [prabs] address the comments 0cc67bd [prabs] adress the comment 838c38e [prabs] adress the comment cd57029 [prabs] address the comments 66919a3 [Prabeesh K] changed MqttDefaultFilePersistence to MemoryPersistence 5857989 [prabs] modified to adhere to accepted coding standards --- .../examples/streaming/MQTTWordCount.scala | 49 +++++++++++-------- .../streaming/mqtt/MQTTInputDStream.scala | 26 +++++----- .../spark/streaming/mqtt/MQTTUtils.scala | 3 +- .../streaming/mqtt/MQTTStreamSuite.scala | 12 ++--- 4 files changed, 50 insertions(+), 40 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala index 6ff0c47793a25..f40caad322f59 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala @@ -17,8 +17,8 @@ package org.apache.spark.examples.streaming -import org.eclipse.paho.client.mqttv3.{MqttClient, MqttClientPersistence, MqttException, MqttMessage, MqttTopic} -import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence +import org.eclipse.paho.client.mqttv3._ +import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} @@ -31,8 +31,6 @@ import org.apache.spark.SparkConf */ object MQTTPublisher { - var client: MqttClient = _ - def main(args: Array[String]) { if (args.length < 2) { System.err.println("Usage: MQTTPublisher ") @@ -42,25 +40,36 @@ object MQTTPublisher { StreamingExamples.setStreamingLogLevels() val Seq(brokerUrl, topic) = args.toSeq + + var client: MqttClient = null try { - var peristance:MqttClientPersistence =new MqttDefaultFilePersistence("/tmp") - client = new MqttClient(brokerUrl, MqttClient.generateClientId(), peristance) + val persistence = new MemoryPersistence() + client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence) + + client.connect() + + val msgtopic = client.getTopic(topic) + val msgContent = "hello mqtt demo for spark streaming" + val message = new MqttMessage(msgContent.getBytes("utf-8")) + + while (true) { + try { + msgtopic.publish(message) + println(s"Published data. topic: {msgtopic.getName()}; Message: {message}") + } catch { + case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT => + Thread.sleep(10) + println("Queue is full, wait for to consume data from the message queue") + } + } } catch { case e: MqttException => println("Exception Caught: " + e) + } finally { + if (client != null) { + client.disconnect() + } } - - client.connect() - - val msgtopic: MqttTopic = client.getTopic(topic) - val msg: String = "hello mqtt demo for spark streaming" - - while (true) { - val message: MqttMessage = new MqttMessage(String.valueOf(msg).getBytes("utf-8")) - msgtopic.publish(message) - println("Published data. topic: " + msgtopic.getName() + " Message: " + message) - } - client.disconnect() } } @@ -96,9 +105,9 @@ object MQTTWordCount { val sparkConf = new SparkConf().setAppName("MQTTWordCount") val ssc = new StreamingContext(sparkConf, Seconds(2)) val lines = MQTTUtils.createStream(ssc, brokerUrl, topic, StorageLevel.MEMORY_ONLY_SER_2) - - val words = lines.flatMap(x => x.toString.split(" ")) + val words = lines.flatMap(x => x.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + wordCounts.print() ssc.start() ssc.awaitTermination() diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala index 1ef91dd49284f..3c0ef94cb0fab 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala @@ -17,23 +17,23 @@ package org.apache.spark.streaming.mqtt +import java.io.IOException +import java.util.concurrent.Executors +import java.util.Properties + +import scala.collection.JavaConversions._ import scala.collection.Map import scala.collection.mutable.HashMap -import scala.collection.JavaConversions._ import scala.reflect.ClassTag -import java.util.Properties -import java.util.concurrent.Executors -import java.io.IOException - +import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken import org.eclipse.paho.client.mqttv3.MqttCallback import org.eclipse.paho.client.mqttv3.MqttClient import org.eclipse.paho.client.mqttv3.MqttClientPersistence -import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence -import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken import org.eclipse.paho.client.mqttv3.MqttException import org.eclipse.paho.client.mqttv3.MqttMessage import org.eclipse.paho.client.mqttv3.MqttTopic +import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel @@ -82,18 +82,18 @@ class MQTTReceiver( val client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence) // Callback automatically triggers as and when new message arrives on specified topic - val callback: MqttCallback = new MqttCallback() { + val callback = new MqttCallback() { // Handles Mqtt message - override def messageArrived(arg0: String, arg1: MqttMessage) { - store(new String(arg1.getPayload(),"utf-8")) + override def messageArrived(topic: String, message: MqttMessage) { + store(new String(message.getPayload(),"utf-8")) } - override def deliveryComplete(arg0: IMqttDeliveryToken) { + override def deliveryComplete(token: IMqttDeliveryToken) { } - override def connectionLost(arg0: Throwable) { - restart("Connection lost ", arg0) + override def connectionLost(cause: Throwable) { + restart("Connection lost ", cause) } } diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala index c5ffe51f9986c..1142d0f56ba34 100644 --- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala +++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.mqtt +import scala.reflect.ClassTag + import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext, JavaDStream} -import scala.reflect.ClassTag import org.apache.spark.streaming.dstream.{ReceiverInputDStream, DStream} object MQTTUtils { diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala index 19c9271af77be..0f3298af6234a 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala @@ -42,8 +42,8 @@ import org.apache.spark.util.Utils class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { private val batchDuration = Milliseconds(500) - private val master: String = "local[2]" - private val framework: String = this.getClass.getSimpleName + private val master = "local[2]" + private val framework = this.getClass.getSimpleName private val freePort = findFreePort() private val brokerUri = "//localhost:" + freePort private val topic = "def" @@ -69,7 +69,7 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { test("mqtt input stream") { val sendMessage = "MQTT demo for spark streaming" - val receiveStream: ReceiverInputDStream[String] = + val receiveStream = MQTTUtils.createStream(ssc, "tcp:" + brokerUri, topic, StorageLevel.MEMORY_ONLY) @volatile var receiveMessage: List[String] = List() receiveStream.foreachRDD { rdd => @@ -123,12 +123,12 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter { def publishData(data: String): Unit = { var client: MqttClient = null try { - val persistence: MqttClientPersistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) + val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath) client = new MqttClient("tcp:" + brokerUri, MqttClient.generateClientId(), persistence) client.connect() if (client.isConnected) { - val msgTopic: MqttTopic = client.getTopic(topic) - val message: MqttMessage = new MqttMessage(data.getBytes("utf-8")) + val msgTopic = client.getTopic(topic) + val message = new MqttMessage(data.getBytes("utf-8")) message.setQos(1) message.setRetained(true) From 5b8480e0359d5af8bdf570f115acb0b1b8997735 Mon Sep 17 00:00:00 2001 From: Benedikt Linse Date: Wed, 25 Feb 2015 14:46:17 +0000 Subject: [PATCH 23/28] [GraphX] fixing 3 typos in the graphx programming guide Corrected 3 Typos in the GraphX programming guide. I hope this is the correct way to contribute. Author: Benedikt Linse Closes #4766 from 1123/master and squashes the following commits: 8a63812 [Benedikt Linse] fixing 3 typos in the graphx programming guide --- docs/graphx-programming-guide.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 826f6d8f371c7..28bdf81ca0ca5 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -538,7 +538,7 @@ val joinedGraph = graph.joinVertices(uniqueCosts, ## Neighborhood Aggregation -A key step in may graph analytics tasks is aggregating information about the neighborhood of each +A key step in many graph analytics tasks is aggregating information about the neighborhood of each vertex. For example, we might want to know the number of followers each user has or the average age of the the followers of each user. Many iterative graph algorithms (e.g., PageRank, Shortest Path, and @@ -634,7 +634,7 @@ avgAgeOfOlderFollowers.collect.foreach(println(_)) ### Map Reduce Triplets Transition Guide (Legacy) -In earlier versions of GraphX we neighborhood aggregation was accomplished using the +In earlier versions of GraphX neighborhood aggregation was accomplished using the [`mapReduceTriplets`][Graph.mapReduceTriplets] operator: {% highlight scala %} @@ -682,8 +682,8 @@ val result = graph.aggregateMessages[String](msgFun, reduceFun) ### Computing Degree Information A common aggregation task is computing the degree of each vertex: the number of edges adjacent to -each vertex. In the context of directed graphs it often necessary to know the in-degree, out- -degree, and the total degree of each vertex. The [`GraphOps`][GraphOps] class contains a +each vertex. In the context of directed graphs it is often necessary to know the in-degree, +out-degree, and the total degree of each vertex. The [`GraphOps`][GraphOps] class contains a collection of operators to compute the degrees of each vertex. For example in the following we compute the max in, out, and total degrees: From dd077abf2e2949fdfec31074b760b587f00efcf2 Mon Sep 17 00:00:00 2001 From: guliangliang Date: Wed, 25 Feb 2015 14:48:02 +0000 Subject: [PATCH 24/28] [SPARK-5771] Number of Cores in Completed Applications of Standalone Master Web Page always be 0 if sc.stop() is called In Standalone mode, the number of cores in Completed Applications of the Master Web Page will always be zero, if sc.stop() is called. But the number will always be right, if sc.stop() is not called. The reason maybe: after sc.stop() is called, the function removeExecutor of class ApplicationInfo will be called, thus reduce the variable coresGranted to zero. The variable coresGranted is used to display the number of Cores on the Web Page. Author: guliangliang Closes #4567 from marsishandsome/Spark5771 and squashes the following commits: 694796e [guliangliang] remove duplicate code a20e390 [guliangliang] change to Cores Using & Requested 0c19c95 [guliangliang] change Cores to Cores (max) cfbd97d [guliangliang] [SPARK-5771] Number of Cores in Completed Applications of Standalone Master Web Page always be 0 if sc.stop() is called --- .../spark/deploy/master/ApplicationInfo.scala | 4 +-- .../spark/deploy/master/ui/MasterPage.scala | 31 +++++++++++++++---- 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index ede0a9dbefb8d..a962dc4af2f6c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -90,9 +90,9 @@ private[spark] class ApplicationInfo( } } - private val myMaxCores = desc.maxCores.getOrElse(defaultCores) + val requestedCores = desc.maxCores.getOrElse(defaultCores) - def coresLeft: Int = myMaxCores - coresGranted + def coresLeft: Int = requestedCores - coresGranted private var _retryCount = 0 diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index fd514f07664a9..9dd96493ee48d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -50,12 +50,16 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { val workers = state.workers.sortBy(_.id) val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers) - val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time", - "User", "State", "Duration") + val activeAppHeaders = Seq("Application ID", "Name", "Cores in Use", + "Cores Requested", "Memory per Node", "Submitted Time", "User", "State", "Duration") val activeApps = state.activeApps.sortBy(_.startTime).reverse - val activeAppsTable = UIUtils.listingTable(appHeaders, appRow, activeApps) + val activeAppsTable = UIUtils.listingTable(activeAppHeaders, activeAppRow, activeApps) + + val completedAppHeaders = Seq("Application ID", "Name", "Cores Requested", "Memory per Node", + "Submitted Time", "User", "State", "Duration") val completedApps = state.completedApps.sortBy(_.endTime).reverse - val completedAppsTable = UIUtils.listingTable(appHeaders, appRow, completedApps) + val completedAppsTable = UIUtils.listingTable(completedAppHeaders, completeAppRow, + completedApps) val driverHeaders = Seq("Submission ID", "Submitted Time", "Worker", "State", "Cores", "Memory", "Main Class") @@ -162,7 +166,7 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } - private def appRow(app: ApplicationInfo): Seq[Node] = { + private def appRow(app: ApplicationInfo, active: Boolean): Seq[Node] = { {app.id} @@ -170,8 +174,15 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { {app.desc.name} + { + if (active) { + + {app.coresGranted} + + } + } - {app.coresGranted} + {app.requestedCores} {Utils.megabytesToString(app.desc.memoryPerSlave)} @@ -183,6 +194,14 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } + private def activeAppRow(app: ApplicationInfo): Seq[Node] = { + appRow(app, active = true) + } + + private def completeAppRow(app: ApplicationInfo): Seq[Node] = { + appRow(app, active = false) + } + private def driverRow(driver: DriverInfo): Seq[Node] = { {driver.id} From f84c799ea0b82abca6a4fad39532c2515743b632 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 25 Feb 2015 10:13:40 -0800 Subject: [PATCH 25/28] [SPARK-5996][SQL] Fix specialized outbound conversions Author: Michael Armbrust Closes #4757 from marmbrus/udtConversions and squashes the following commits: 3714aad [Michael Armbrust] [SPARK-5996][SQL] Fix specialized outbound conversions --- .../apache/spark/sql/execution/LocalTableScan.scala | 7 +++++-- .../apache/spark/sql/execution/basicOperators.scala | 8 +++++--- .../org/apache/spark/sql/UserDefinedTypeSuite.scala | 10 ++++++++++ 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index d6d8258f46a9a..d3a18b37d52b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Attribute @@ -30,7 +31,9 @@ case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNo override def execute() = rdd - override def executeCollect() = rows.toArray + override def executeCollect() = + rows.map(ScalaReflection.convertRowToScala(_, schema)).toArray - override def executeTake(limit: Int) = rows.take(limit).toArray + override def executeTake(limit: Int) = + rows.map(ScalaReflection.convertRowToScala(_, schema)).take(limit).toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 4dc506c21ab9e..710268584cff1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -134,13 +134,15 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) val ord = new RowOrdering(sortOrder, child.output) + private def collectData() = child.execute().map(_.copy()).takeOrdered(limit)(ord) + // TODO: Is this copying for no reason? - override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ord) - .map(ScalaReflection.convertRowToScala(_, this.schema)) + override def executeCollect() = + collectData().map(ScalaReflection.convertRowToScala(_, this.schema)) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. - override def execute() = sparkContext.makeRDD(executeCollect(), 1) + override def execute() = sparkContext.makeRDD(collectData(), 1) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 9c098df24c65f..47fdb5543235c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -22,6 +22,7 @@ import java.io.File import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql} import org.apache.spark.sql.test.TestSQLContext.implicits._ @@ -105,4 +106,13 @@ class UserDefinedTypeSuite extends QueryTest { tempDir.delete() pointsRDD.repartition(1).saveAsParquetFile(tempDir.getCanonicalPath) } + + // Tests to make sure that all operators correctly convert types on the way out. + test("Local UDTs") { + val df = Seq((1, new MyDenseVector(Array(0.1, 1.0)))).toDF("int", "vec") + df.collect()(0).getAs[MyDenseVector](1) + df.take(1)(0).getAs[MyDenseVector](1) + df.limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) + df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) + } } From 7d8e6a2e44e13a6d6cdcd98a0d0c33b243ef0dc2 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 25 Feb 2015 12:20:44 -0800 Subject: [PATCH 26/28] SPARK-5930 [DOCS] Documented default of spark.shuffle.io.retryWait is confusing Clarify default max wait in spark.shuffle.io.retryWait docs CC andrewor14 Author: Sean Owen Closes #4769 from srowen/SPARK-5930 and squashes the following commits: ae2792b [Sean Owen] Clarify default max wait in spark.shuffle.io.retryWait docs --- docs/configuration.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/configuration.md b/docs/configuration.md index c8db338cb6f89..81298514a7cb2 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -955,7 +955,9 @@ Apart from these, the following properties are also available, and may be useful 5 (Netty only) Seconds to wait between retries of fetches. The maximum delay caused by retrying - is simply maxRetries * retryWait, by default 15 seconds. + is simply maxRetries * retryWait. The default maximum delay is therefore + 15 seconds, because the default value of maxRetries is 3, and the default + retryWait here is 5 seconds. From a777c65da9bc636e5cf5426e15a2e76d6b21b744 Mon Sep 17 00:00:00 2001 From: Milan Straka Date: Wed, 25 Feb 2015 21:33:34 +0000 Subject: [PATCH 27/28] [SPARK-5970][core] Register directory created in getOrCreateLocalRootDirs for automatic deletion. As documented in createDirectory, the result of createDirectory is not registered for automatic removal. Currently there are 4 directories left in `/tmp` after just running `pyspark`. Author: Milan Straka Closes #4759 from foxik/remove-tmp-dirs and squashes the following commits: 280450d [Milan Straka] Use createTempDir in getOrCreateLocalRootDirs... --- core/src/main/scala/org/apache/spark/util/Utils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index df21ed37e76b1..4803ff9403b1d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -696,7 +696,7 @@ private[spark] object Utils extends Logging { try { val rootDir = new File(root) if (rootDir.exists || rootDir.mkdirs()) { - val dir = createDirectory(root) + val dir = createTempDir(root) chmod700(dir) Some(dir.getAbsolutePath) } else { From 9f603fce78fcc997926e9a72dec44d48cbc396fc Mon Sep 17 00:00:00 2001 From: Brennon York Date: Wed, 25 Feb 2015 14:11:12 -0800 Subject: [PATCH 28/28] [SPARK-1955][GraphX]: VertexRDD can incorrectly assume index sharing Fixes the issue whereby when VertexRDD's are `diff`ed, `innerJoin`ed, or `leftJoin`ed and have different partition sizes they fail under the `zipPartitions` method. This fix tests whether the partitions are equal or not and, if not, will repartition the other to match the partition size of the calling VertexRDD. Author: Brennon York Closes #4705 from brennonyork/SPARK-1955 and squashes the following commits: 0882590 [Brennon York] updated to properly handle differently-partitioned vertexRDDs --- .../org/apache/spark/graphx/impl/VertexRDDImpl.scala | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala index 6dad167fa7411..904be213147dc 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala @@ -104,8 +104,14 @@ class VertexRDDImpl[VD] private[graphx] ( this.mapVertexPartitions(_.map(f)) override def diff(other: VertexRDD[VD]): VertexRDD[VD] = { + val otherPartition = other match { + case other: VertexRDD[_] if this.partitioner == other.partitioner => + other.partitionsRDD + case _ => + VertexRDD(other.partitionBy(this.partitioner.get)).partitionsRDD + } val newPartitionsRDD = partitionsRDD.zipPartitions( - other.partitionsRDD, preservesPartitioning = true + otherPartition, preservesPartitioning = true ) { (thisIter, otherIter) => val thisPart = thisIter.next() val otherPart = otherIter.next() @@ -133,7 +139,7 @@ class VertexRDDImpl[VD] private[graphx] ( // Test if the other vertex is a VertexRDD to choose the optimal join strategy. // If the other set is a VertexRDD then we use the much more efficient leftZipJoin other match { - case other: VertexRDD[_] => + case other: VertexRDD[_] if this.partitioner == other.partitioner => leftZipJoin(other)(f) case _ => this.withPartitionsRDD[VD3]( @@ -162,7 +168,7 @@ class VertexRDDImpl[VD] private[graphx] ( // Test if the other vertex is a VertexRDD to choose the optimal join strategy. // If the other set is a VertexRDD then we use the much more efficient innerZipJoin other match { - case other: VertexRDD[_] => + case other: VertexRDD[_] if this.partitioner == other.partitioner => innerZipJoin(other)(f) case _ => this.withPartitionsRDD(